Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
0a9f541c
Commit
0a9f541c
authored
Jul 05, 2019
by
rusty1s
Browse files
fix backward for max min
parent
e6821e37
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
6 deletions
+12
-6
setup.py
setup.py
+1
-1
torch_scatter/__init__.py
torch_scatter/__init__.py
+1
-1
torch_scatter/max.py
torch_scatter/max.py
+5
-2
torch_scatter/min.py
torch_scatter/min.py
+5
-2
No files found.
setup.py
View file @
0a9f541c
...
...
@@ -20,7 +20,7 @@ if CUDA_HOME is not None:
[
'cuda/scatter.cpp'
,
'cuda/scatter_kernel.cu'
])
]
__version__
=
'1.3.
0
'
__version__
=
'1.3.
1
'
url
=
'https://github.com/rusty1s/pytorch_scatter'
install_requires
=
[]
...
...
torch_scatter/__init__.py
View file @
0a9f541c
...
...
@@ -7,7 +7,7 @@ from .std import scatter_std
from
.max
import
scatter_max
from
.min
import
scatter_min
__version__
=
'1.3.
0
'
__version__
=
'1.3.
1
'
__all__
=
[
'scatter_add'
,
...
...
torch_scatter/max.py
View file @
0a9f541c
...
...
@@ -24,8 +24,11 @@ class ScatterMax(Function):
grad_src
=
None
if
ctx
.
needs_input_grad
[
1
]:
grad_src
=
grad_out
.
new_zeros
(
index
.
size
())
grad_src
.
scatter_
(
ctx
.
dim
,
arg
.
detach
(),
grad_out
)
size
=
list
(
index
.
size
())
size
[
ctx
.
dim
]
+=
1
grad_src
=
grad_out
.
new_zeros
(
size
)
grad_src
.
scatter_
(
ctx
.
dim
,
arg
.
detach
()
+
1
,
grad_out
)
grad_src
=
grad_src
.
narrow
(
ctx
.
dim
,
1
,
index
.
size
(
ctx
.
dim
))
return
None
,
grad_src
,
None
,
None
...
...
torch_scatter/min.py
View file @
0a9f541c
...
...
@@ -24,8 +24,11 @@ class ScatterMin(Function):
grad_src
=
None
if
ctx
.
needs_input_grad
[
1
]:
grad_src
=
grad_out
.
new_zeros
(
index
.
size
())
grad_src
.
scatter_
(
ctx
.
dim
,
arg
.
detach
(),
grad_out
)
size
=
list
(
index
.
size
())
size
[
ctx
.
dim
]
+=
1
grad_src
=
grad_out
.
new_zeros
(
size
)
grad_src
.
scatter_
(
ctx
.
dim
,
arg
.
detach
()
+
1
,
grad_out
)
grad_src
=
grad_src
.
narrow
(
ctx
.
dim
,
1
,
index
.
size
(
ctx
.
dim
))
return
None
,
grad_src
,
None
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment