Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
102b7542
Commit
102b7542
authored
Dec 22, 2017
by
rusty1s
Browse files
typos
parent
46c2a5cb
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
17 deletions
+17
-17
torch_scatter/functions/max.py
torch_scatter/functions/max.py
+2
-2
torch_scatter/functions/mean.py
torch_scatter/functions/mean.py
+4
-4
torch_scatter/functions/min.py
torch_scatter/functions/min.py
+2
-2
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+9
-9
No files found.
torch_scatter/functions/max.py
View file @
102b7542
...
@@ -50,8 +50,8 @@ def scatter_max_(output, index, input, dim=0):
...
@@ -50,8 +50,8 @@ def scatter_max_(output, index, input, dim=0):
[torch.LongTensor of size 2x6]
[torch.LongTensor of size 2x6]
)
)
"""
"""
arg
_output
=
gen_filled_tensor
(
index
,
output
.
size
(),
fill_value
=-
1
)
arg
=
gen_filled_tensor
(
index
,
output
.
size
(),
fill_value
=-
1
)
return
scatter
(
'max'
,
dim
,
output
,
index
,
input
,
arg
_output
)
return
scatter
(
'max'
,
dim
,
output
,
index
,
input
,
arg
)
def
scatter_max
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
def
scatter_max
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
...
...
torch_scatter/functions/mean.py
View file @
102b7542
...
@@ -44,10 +44,10 @@ def scatter_mean_(output, index, input, dim=0):
...
@@ -44,10 +44,10 @@ def scatter_mean_(output, index, input, dim=0):
1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
[torch.FloatTensor of size 2x6]
[torch.FloatTensor of size 2x6]
"""
"""
num_outpu
t
=
gen_filled_tensor
(
output
,
output
.
size
(),
fill_value
=
0
)
coun
t
=
gen_filled_tensor
(
output
,
output
.
size
(),
fill_value
=
0
)
scatter
(
'mean'
,
dim
,
output
,
index
,
input
,
num_outpu
t
)
scatter
(
'mean'
,
dim
,
output
,
index
,
input
,
coun
t
)
num_output
[
num_outpu
t
==
0
]
=
1
count
[
coun
t
==
0
]
=
1
output
/=
num_outpu
t
output
/=
coun
t
return
output
return
output
...
...
torch_scatter/functions/min.py
View file @
102b7542
...
@@ -50,8 +50,8 @@ def scatter_min_(output, index, input, dim=0):
...
@@ -50,8 +50,8 @@ def scatter_min_(output, index, input, dim=0):
[torch.LongTensor of size 2x6]
[torch.LongTensor of size 2x6]
)
)
"""
"""
arg
_output
=
gen_filled_tensor
(
index
,
output
.
size
(),
fill_value
=-
1
)
arg
=
gen_filled_tensor
(
index
,
output
.
size
(),
fill_value
=-
1
)
return
scatter
(
'min'
,
dim
,
output
,
index
,
input
,
arg
_output
)
return
scatter
(
'min'
,
dim
,
output
,
index
,
input
,
arg
)
def
scatter_min
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
def
scatter_min
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
...
...
torch_scatter/functions/scatter.py
View file @
102b7542
...
@@ -6,7 +6,7 @@ from torch.autograd import Function
...
@@ -6,7 +6,7 @@ from torch.autograd import Function
from
.._ext
import
ffi
from
.._ext
import
ffi
def has_arg(name):
    """Return ``True`` iff the scatter operation *name* produces an
    additional ``arg`` tensor (only ``'max'`` and ``'min'`` do)."""
    return name == 'max' or name == 'min'
...
@@ -36,7 +36,7 @@ def _scatter(name, dim, *data):
...
@@ -36,7 +36,7 @@ def _scatter(name, dim, *data):
cuda
=
'cuda_'
if
data
[
0
].
is_cuda
else
''
cuda
=
'cuda_'
if
data
[
0
].
is_cuda
else
''
func
=
getattr
(
ffi
,
'scatter_{}_{}{}'
.
format
(
name
,
cuda
,
typename
))
func
=
getattr
(
ffi
,
'scatter_{}_{}{}'
.
format
(
name
,
cuda
,
typename
))
func
(
dim
,
*
data
)
func
(
dim
,
*
data
)
return
(
data
[
0
],
data
[
3
])
if
has_arg
_output
(
name
)
else
data
[
0
]
return
(
data
[
0
],
data
[
3
])
if
has_arg
(
name
)
else
data
[
0
]
def
index_backward
(
dim
,
index
,
grad
,
arg
):
def
index_backward
(
dim
,
index
,
grad
,
arg
):
...
@@ -63,9 +63,9 @@ class _Scatter(Function):
...
@@ -63,9 +63,9 @@ class _Scatter(Function):
_scatter
(
self
.
name
,
self
.
dim
,
*
data
)
_scatter
(
self
.
name
,
self
.
dim
,
*
data
)
# `scatter_min` and `scatter_max` additionally return the `argmax`
# `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`.
In addition
, we need to save the
# respectively `argmin`.
Therefore
, we need to save the
`arg` for the
#
`arg_output` for the
backward pass.
# backward pass.
if
has_arg
_output
(
self
.
name
):
if
has_arg
(
self
.
name
):
self
.
save_for_backward
(
data
[
1
],
data
[
3
])
self
.
save_for_backward
(
data
[
1
],
data
[
3
])
return
data
[
0
],
data
[
3
]
return
data
[
0
],
data
[
3
]
else
:
else
:
...
@@ -80,13 +80,13 @@ class _Scatter(Function):
...
@@ -80,13 +80,13 @@ class _Scatter(Function):
# Different grad computation of `input` if `scatter_max` or
# Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used.
# `scatter_min` was used.
if
self
.
needs_input_grad
[
2
]
and
not
has_arg
_output
(
self
.
name
):
if
self
.
needs_input_grad
[
2
]
and
not
has_arg
(
self
.
name
):
index
,
=
self
.
saved_variables
index
,
=
self
.
saved_variables
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
if
self
.
needs_input_grad
[
2
]
and
has_arg
_output
(
self
.
name
):
if
self
.
needs_input_grad
[
2
]
and
has_arg
(
self
.
name
):
index
,
arg
_grad
=
self
.
saved_variables
index
,
arg
=
self
.
saved_variables
data
=
(
index
.
data
,
data
[
0
],
arg
_grad
.
data
)
data
=
(
index
.
data
,
data
[
0
],
arg
.
data
)
grad_input
=
index_backward
(
self
.
dim
,
*
data
)
grad_input
=
index_backward
(
self
.
dim
,
*
data
)
# Return and fill with empty grads for none-differentiable passed
# Return and fill with empty grads for none-differentiable passed
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment