Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
e3940621
Commit
e3940621
authored
Dec 16, 2017
by
rusty1s
Browse files
grad test
parent
29e28847
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
14 deletions
+4
-14
test/test_add.py
test/test_add.py
+4
-14
No files found.
test/test_add.py
View file @
e3940621
...
...
@@ -24,18 +24,8 @@ def test_scatter_add():
input
=
Variable
(
input
,
requires_grad
=
True
)
scatter_add_
(
output
,
index
,
input
,
dim
=
1
)
c
=
output
.
sum
()
c
.
backward
(
)
grad_
output
=
[[
0
,
1
,
2
,
3
,
4
,
5
],
[
0
,
1
,
2
,
3
,
4
,
5
]]
grad_output
=
torch
.
FloatTensor
(
grad_output
)
# # a = input * 2
# # b = output * 2
# a = input * 2
# b = output * 2
# ScatterAdd(1)(b, index, a)
# # b.scatter_add_(1, index, a)
# c = b.sum()
# c.backward()
# print(input.grad)
# print(output.grad)
output
.
backward
(
grad_output
)
assert_equal
(
index
.
data
.
tolist
(),
input
.
grad
.
data
.
tolist
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment