Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
fdba2556
Commit
fdba2556
authored
Dec 16, 2017
by
rusty1s
Browse files
mean grad fix
parent
45d8a4ca
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
21 deletions
+28
-21
test/test_mean.py
test/test_mean.py
+16
-18
torch_scatter/functions/__init__.py
torch_scatter/functions/__init__.py
+8
-2
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+4
-1
No files found.
test/test_mean.py
View file @
fdba2556
...
@@ -6,30 +6,28 @@ from torch_scatter import scatter_mean_, scatter_mean
...
@@ -6,30 +6,28 @@ from torch_scatter import scatter_mean_, scatter_mean
from
.utils
import
tensor_strs
,
Tensor
from
.utils
import
tensor_strs
,
Tensor
# @pytest.mark.parametrize('str', tensor_strs)
@
pytest
.
mark
.
parametrize
(
'str'
,
tensor_strs
)
# def test_scatter_add(str):
def
test_scatter_mean
(
str
):
def
test_scatter_mean
():
input
=
[[
2
,
0
,
8
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]]
input
=
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]]
index
=
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]]
index
=
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]]
input
=
torch
.
Float
Tensor
(
input
)
input
=
Tensor
(
str
,
input
)
index
=
torch
.
LongTensor
(
index
)
index
=
torch
.
LongTensor
(
index
)
output
=
input
.
new
(
2
,
6
).
fill_
(
0
)
output
=
input
.
new
(
2
,
6
).
fill_
(
0
)
#
expected_output = [[0, 0, 4, 3,
3
, 0], [
2
, 4,
4
, 0, 0, 0]]
expected_output
=
[[
0
,
0
,
4
,
3
,
5
,
0
],
[
1
,
4
,
2
,
0
,
0
,
0
]]
scatter_mean_
(
output
,
index
,
input
,
dim
=
1
)
scatter_mean_
(
output
,
index
,
input
,
dim
=
1
)
print
(
output
)
assert
output
.
tolist
()
==
expected_output
# assert output.tolist() == expected_output
#
output = scatter_
add
(index, input, dim=1)
output
=
scatter_
mean
(
index
,
input
,
dim
=
1
)
#
assert output.tolist(), expected_output
assert
output
.
tolist
(),
expected_output
#
output = Variable(output).fill_(0)
output
=
Variable
(
output
).
fill_
(
0
)
#
index = Variable(index)
index
=
Variable
(
index
)
#
input = Variable(input, requires_grad=True)
input
=
Variable
(
input
,
requires_grad
=
True
)
#
scatter_
add
_(output, index, input, dim=1)
scatter_
mean
_
(
output
,
index
,
input
,
dim
=
1
)
#
grad_output = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]
grad_output
=
[[
0
,
1
,
2
,
3
,
4
,
5
],
[
0
,
1
,
2
,
3
,
4
,
5
]]
#
grad_output = Tensor(str, grad_output)
grad_output
=
Tensor
(
str
,
grad_output
)
#
output.backward(grad_output)
output
.
backward
(
grad_output
)
#
assert index.data.tolist() == input.grad.data.tolist()
assert
index
.
data
.
tolist
()
==
input
.
grad
.
data
.
tolist
()
torch_scatter/functions/__init__.py
View file @
fdba2556
import
torch
from
torch.autograd
import
Variable
from
.scatter
import
scatter
from
.scatter
import
scatter
from
.utils
import
gen_output
from
.utils
import
gen_output
...
@@ -43,10 +46,13 @@ def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
...
@@ -43,10 +46,13 @@ def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
def
scatter_mean_
(
output
,
index
,
input
,
dim
=
0
):
def
scatter_mean_
(
output
,
index
,
input
,
dim
=
0
):
output_count
=
output
.
new
(
output
.
size
()).
fill_
(
0
)
if
torch
.
is_tensor
(
input
):
output_count
=
output
.
new
(
output
.
size
()).
fill_
(
0
)
else
:
output_count
=
Variable
(
output
.
data
.
new
(
output
.
size
()).
fill_
(
0
))
scatter
(
'mean'
,
dim
,
output
,
index
,
input
,
output_count
)
scatter
(
'mean'
,
dim
,
output
,
index
,
input
,
output_count
)
output_count
[
output_count
==
0
]
=
1
output
/=
output_count
output
/=
output_count
output
[
output
!=
output
]
=
0
return
output
return
output
...
...
torch_scatter/functions/scatter.py
View file @
fdba2556
...
@@ -34,7 +34,10 @@ class _Scatter(Function):
...
@@ -34,7 +34,10 @@ class _Scatter(Function):
if
self
.
needs_input_grad
[
2
]:
if
self
.
needs_input_grad
[
2
]:
grad_input
=
grad
.
gather
(
self
.
dim
,
index
.
data
)
grad_input
=
grad
.
gather
(
self
.
dim
,
index
.
data
)
return
grad_output
,
None
,
grad_input
if
len
(
grad
)
==
3
:
return
grad_output
,
None
,
grad_input
else
:
return
grad_output
,
None
,
grad_input
,
None
def
scatter
(
name
,
dim
,
*
data
):
def
scatter
(
name
,
dim
,
*
data
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment