Commit 0a9f541c authored by rusty1s's avatar rusty1s
Browse files

fix backward for max min

parent e6821e37
...@@ -20,7 +20,7 @@ if CUDA_HOME is not None: ...@@ -20,7 +20,7 @@ if CUDA_HOME is not None:
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu']) ['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
] ]
__version__ = '1.3.0' __version__ = '1.3.1'
url = 'https://github.com/rusty1s/pytorch_scatter' url = 'https://github.com/rusty1s/pytorch_scatter'
install_requires = [] install_requires = []
......
...@@ -7,7 +7,7 @@ from .std import scatter_std ...@@ -7,7 +7,7 @@ from .std import scatter_std
from .max import scatter_max from .max import scatter_max
from .min import scatter_min from .min import scatter_min
__version__ = '1.3.0' __version__ = '1.3.1'
__all__ = [ __all__ = [
'scatter_add', 'scatter_add',
......
...@@ -24,8 +24,11 @@ class ScatterMax(Function): ...@@ -24,8 +24,11 @@ class ScatterMax(Function):
grad_src = None grad_src = None
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
grad_src = grad_out.new_zeros(index.size()) size = list(index.size())
grad_src.scatter_(ctx.dim, arg.detach(), grad_out) size[ctx.dim] += 1
grad_src = grad_out.new_zeros(size)
grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))
return None, grad_src, None, None return None, grad_src, None, None
......
...@@ -24,8 +24,11 @@ class ScatterMin(Function): ...@@ -24,8 +24,11 @@ class ScatterMin(Function):
grad_src = None grad_src = None
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
grad_src = grad_out.new_zeros(index.size()) size = list(index.size())
grad_src.scatter_(ctx.dim, arg.detach(), grad_out) size[ctx.dim] += 1
grad_src = grad_out.new_zeros(size)
grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))
return None, grad_src, None, None return None, grad_src, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment