Commit 0a9f541c authored by rusty1s's avatar rusty1s
Browse files

fix backward for max min

parent e6821e37
......@@ -20,7 +20,7 @@ if CUDA_HOME is not None:
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
]
__version__ = '1.3.0'
__version__ = '1.3.1'
url = 'https://github.com/rusty1s/pytorch_scatter'
install_requires = []
......
......@@ -7,7 +7,7 @@ from .std import scatter_std
from .max import scatter_max
from .min import scatter_min
__version__ = '1.3.0'
__version__ = '1.3.1'
__all__ = [
'scatter_add',
......
......@@ -24,8 +24,11 @@ class ScatterMax(Function):
grad_src = None
if ctx.needs_input_grad[1]:
grad_src = grad_out.new_zeros(index.size())
grad_src.scatter_(ctx.dim, arg.detach(), grad_out)
size = list(index.size())
size[ctx.dim] += 1
grad_src = grad_out.new_zeros(size)
grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))
return None, grad_src, None, None
......
......@@ -24,8 +24,11 @@ class ScatterMin(Function):
grad_src = None
if ctx.needs_input_grad[1]:
grad_src = grad_out.new_zeros(index.size())
grad_src.scatter_(ctx.dim, arg.detach(), grad_out)
size = list(index.size())
size[ctx.dim] += 1
grad_src = grad_out.new_zeros(size)
grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))
return None, grad_src, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment