Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
8c48dee0
Commit
8c48dee0
authored
Dec 18, 2017
by
rusty1s
Browse files
comments
parent
5ba5c620
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
1 deletion
+8
-1
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+8
-1
No files found.
torch_scatter/functions/scatter.py
View file @
8c48dee0
...
@@ -56,10 +56,13 @@ class _Scatter(Function):
...
@@ -56,10 +56,13 @@ class _Scatter(Function):
assert
not
self
.
needs_input_grad
[
1
],
'Can
\'
t differentiate the index'
assert
not
self
.
needs_input_grad
[
1
],
'Can
\'
t differentiate the index'
self
.
mark_dirty
(
data
[
0
])
# Mark output as dirty.
self
.
mark_dirty
(
data
[
0
])
# Mark output as dirty.
self
.
len
=
len
(
data
)
# Save number of arguments for backward step
self
.
len
=
len
(
data
)
# Save number of arguments for backward step
.
_scatter
(
self
.
name
,
self
.
dim
,
*
data
)
_scatter
(
self
.
name
,
self
.
dim
,
*
data
)
# `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`. In addition, we need to save the
# `output_index` for the backward pass.
if
_has_output_index
(
self
.
name
):
if
_has_output_index
(
self
.
name
):
self
.
save_for_backward
(
data
[
1
],
data
[
3
])
self
.
save_for_backward
(
data
[
1
],
data
[
3
])
return
data
[
0
],
data
[
3
]
return
data
[
0
],
data
[
3
]
...
@@ -73,6 +76,8 @@ class _Scatter(Function):
...
@@ -73,6 +76,8 @@ class _Scatter(Function):
if
self
.
needs_input_grad
[
0
]:
if
self
.
needs_input_grad
[
0
]:
grad_output
=
data
[
0
]
grad_output
=
data
[
0
]
# Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used.
if
self
.
needs_input_grad
[
2
]
and
not
_has_output_index
(
self
.
name
):
if
self
.
needs_input_grad
[
2
]
and
not
_has_output_index
(
self
.
name
):
index
,
=
self
.
saved_variables
index
,
=
self
.
saved_variables
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
...
@@ -82,6 +87,8 @@ class _Scatter(Function):
...
@@ -82,6 +87,8 @@ class _Scatter(Function):
data
=
(
index
.
data
,
data
[
0
],
grad_index
.
data
)
data
=
(
index
.
data
,
data
[
0
],
grad_index
.
data
)
grad_input
=
_index_backward
(
self
.
dim
,
*
data
)
grad_input
=
_index_backward
(
self
.
dim
,
*
data
)
# Return and fill with empty grads for none-differentiable passed
# arguments in forward pass.
return
(
grad_output
,
None
,
grad_input
)
+
(
None
,
)
*
(
self
.
len
-
3
)
return
(
grad_output
,
None
,
grad_input
)
+
(
None
,
)
*
(
self
.
len
-
3
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment