Commit af375340 authored by rusty1s's avatar rusty1s
Browse files

fixed div grad

parent 61eb4d03
...@@ -12,7 +12,7 @@ class ScatterDiv(Scatter): # pragma: no cover ...@@ -12,7 +12,7 @@ class ScatterDiv(Scatter): # pragma: no cover
def backward_step(self, *data): def backward_step(self, *data):
grad, output, index, input = data grad, output, index, input = data
return (grad / output.data).gather(self.dim, index.data) * input.data return (output.data / grad).gather(self.dim, index.data) * input.data
def scatter_div_(output, index, input, dim=0): def scatter_div_(output, index, input, dim=0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment