Commit 1697b6bb authored by rusty1s's avatar rusty1s
Browse files

bugfix

parent 0f1dc7bc
......@@ -6,5 +6,21 @@
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"expected": [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
},
{
"name": "sub",
"output": [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]],
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"expected": [[0, 0, -4, -3, -3, 0], [-2, -4, -4, 0, 0, 0]]
},
{
"name": "mul",
"output": [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]],
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"expected": [[1, 1, 4, 3, 2, 0], [0, 4, 3, 1, 1, 1]]
}
]
......@@ -51,7 +51,7 @@ def scatter_sub_(output, index, input, dim=0):
-2 -4 -4 0 0 0
[torch.FloatTensor of size 2x6]
"""
return output.scatter_add_(dim, index, -input)
return output.scatter_add_(dim, index, -1 * input)
def scatter_sub(index, input, dim=0, size=None, fill_value=0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment