Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
f655f536
Commit
f655f536
authored
Dec 18, 2017
by
rusty1s
Browse files
typo
parent
c541e366
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
7 deletions
+7
-7
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+7
-7
No files found.
torch_scatter/functions/scatter.py
View file @
f655f536
...
...
@@ -6,7 +6,7 @@ from torch.autograd import Function
from
.._ext
import
ffi
def
_
has_output_index
(
name
):
def
has_output_index
(
name
):
return
name
in
[
'max'
,
'min'
]
...
...
@@ -35,10 +35,10 @@ def _scatter(name, dim, *data):
typename
=
type
(
data
[
0
]).
__name__
.
replace
(
'Tensor'
,
''
)
func
=
getattr
(
ffi
,
'scatter_{}_{}'
.
format
(
name
,
typename
))
func
(
dim
,
*
data
)
return
(
data
[
0
],
data
[
3
])
if
_
has_output_index
(
name
)
else
data
[
0
]
return
(
data
[
0
],
data
[
3
])
if
has_output_index
(
name
)
else
data
[
0
]
def
_
index_backward
(
dim
,
index
,
grad
,
grad_index
):
def
index_backward
(
dim
,
index
,
grad
,
grad_index
):
typename
=
type
(
grad
).
__name__
.
replace
(
'Tensor'
,
''
)
func
=
getattr
(
ffi
,
'index_backward_{}'
.
format
(
typename
))
output
=
grad
.
new
(
index
.
size
()).
fill_
(
0
)
...
...
@@ -63,7 +63,7 @@ class _Scatter(Function):
# `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`. In addition, we need to save the
# `output_index` for the backward pass.
if
_
has_output_index
(
self
.
name
):
if
has_output_index
(
self
.
name
):
self
.
save_for_backward
(
data
[
1
],
data
[
3
])
return
data
[
0
],
data
[
3
]
else
:
...
...
@@ -78,14 +78,14 @@ class _Scatter(Function):
# Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used.
if
self
.
needs_input_grad
[
2
]
and
not
_
has_output_index
(
self
.
name
):
if
self
.
needs_input_grad
[
2
]
and
not
has_output_index
(
self
.
name
):
index
,
=
self
.
saved_variables
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
if
self
.
needs_input_grad
[
2
]
and
_
has_output_index
(
self
.
name
):
if
self
.
needs_input_grad
[
2
]
and
has_output_index
(
self
.
name
):
index
,
grad_index
=
self
.
saved_variables
data
=
(
index
.
data
,
data
[
0
],
grad_index
.
data
)
grad_input
=
_
index_backward
(
self
.
dim
,
*
data
)
grad_input
=
index_backward
(
self
.
dim
,
*
data
)
# Return and fill with empty grads for none-differentiable passed
# arguments in forward pass.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment