torch-scatter, commit 68f4609c
authored Jan 12, 2021 by rusty1s

added scatter_mul back in

parent 12722728
Showing 4 changed files with 62 additions and 4 deletions:
csrc/scatter.cpp            +38  -0
test/test_scatter.py        +10  -0
torch_scatter/__init__.py    +3  -2
torch_scatter/scatter.py    +11  -2
csrc/scatter.cpp

@@ -70,6 +70,37 @@ public:
   }
 };
 
+class ScatterMul : public torch::autograd::Function<ScatterMul> {
+public:
+  static variable_list forward(AutogradContext *ctx, Variable src,
+                               Variable index, int64_t dim,
+                               torch::optional<Variable> optional_out,
+                               torch::optional<int64_t> dim_size) {
+    dim = dim < 0 ? src.dim() + dim : dim;
+    ctx->saved_data["dim"] = dim;
+    ctx->saved_data["src_shape"] = src.sizes();
+    index = broadcast(index, src, dim);
+    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "mul");
+    auto out = std::get<0>(result);
+    ctx->save_for_backward({src, index, out});
+    if (optional_out.has_value())
+      ctx->mark_dirty({optional_out.value()});
+    return {out};
+  }
+
+  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto grad_out = grad_outs[0];
+    auto saved = ctx->get_saved_variables();
+    auto src = saved[0];
+    auto index = saved[1];
+    auto out = saved[2];
+    auto dim = ctx->saved_data["dim"].toInt();
+    auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
+    auto grad_in =
+        torch::gather(grad_out * out, dim, index, false).div_(src);
+    return {grad_in, Variable(), Variable(), Variable(), Variable()};
+  }
+};
+
 class ScatterMean : public torch::autograd::Function<ScatterMean> {
 public:
   static variable_list forward(AutogradContext *ctx, Variable src, ...
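
The backward pass above follows from differentiating a product with respect to one of its factors: for an output entry

    o_j = \prod_{i \,:\, \mathrm{index}_i = j} s_i, \qquad
    \frac{\partial o_j}{\partial s_i} = \prod_{k \ne i \,:\, \mathrm{index}_k = j} s_k = \frac{o_j}{s_i},

so the chain rule gives grad_in[i] = grad_out[index[i]] * out[index[i]] / src[i], which is exactly the torch::gather(grad_out * out, dim, index, false).div_(src) line. Note that the division leaves the gradient undefined wherever src contains zeros.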
@@ -197,6 +228,12 @@ torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
   return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
 }
 
+torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
+                          torch::optional<torch::Tensor> optional_out,
+                          torch::optional<int64_t> dim_size) {
+  return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
+}
+
 torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
                            torch::optional<torch::Tensor> optional_out,
                            torch::optional<int64_t> dim_size) {
...
@@ -221,6 +258,7 @@ scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
 static auto registry =
     torch::RegisterOperators()
         .op("torch_scatter::scatter_sum", &scatter_sum)
+        .op("torch_scatter::scatter_mul", &scatter_mul)
         .op("torch_scatter::scatter_mean", &scatter_mean)
         .op("torch_scatter::scatter_min", &scatter_min)
         .op("torch_scatter::scatter_max", &scatter_max);
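
Registering the function through torch::RegisterOperators is what makes it reachable from Python and TorchScript as torch.ops.torch_scatter.scatter_mul, which the scatter.py change below relies on. A minimal sketch of calling the raw op directly, assuming the compiled extension is built and importable:

    import torch
    import torch_scatter  # importing the package loads the extension and registers the ops

    src = torch.tensor([1., 2., 3., 4.])
    index = torch.tensor([0, 0, 1, 1])
    # Positional signature mirrors the C++ function:
    # (src, index, dim, optional_out, dim_size)
    out = torch.ops.torch_scatter.scatter_mul(src, index, 0, None, None)
    print(out)  # tensor([ 2., 12.])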
test/test_scatter.py
@@ -7,6 +7,8 @@ import torch_scatter
 from .utils import reductions, tensor, dtypes, devices
 
+reductions = reductions + ['mul']
+
 tests = [
     {
         'src': [1, 3, 2, 4, 5, 6],
...
@@ -14,6 +16,7 @@ tests = [
         'dim': 0,
         'sum': [3, 12, 0, 6],
         'add': [3, 12, 0, 6],
+        'mul': [2, 60, 1, 6],
         'mean': [1.5, 4, 0, 6],
         'min': [1, 3, 0, 6],
         'arg_min': [0, 1, 6, 5],
...
@@ -26,6 +29,7 @@ tests = [
         'dim': 0,
         'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
         'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
+        'mul': [[1 * 3, 2 * 4], [5 * 7 * 9, 6 * 8 * 10], [1, 1], [11, 12]],
         'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
         'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
         'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]],
...
@@ -38,6 +42,7 @@ tests = [
         'dim': 1,
         'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
         'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
+        'mul': [[1 * 3, 5 * 7 * 9, 1, 11], [2 * 4 * 6, 8 * 10, 12, 1]],
         'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
         'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
         'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]],
...
@@ -50,6 +55,7 @@ tests = [
         'dim': 1,
         'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
         'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
+        'mul': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 11 * 13]]],
         'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
         'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
         'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]],
...
@@ -62,6 +68,7 @@ tests = [
         'dim': 1,
         'sum': [[4], [6]],
         'add': [[4], [6]],
+        'mul': [[3], [8]],
         'mean': [[2], [3]],
         'min': [[1], [2]],
         'arg_min': [[0], [0]],
...
@@ -74,6 +81,7 @@ tests = [
         'dim': 1,
         'sum': [[[4, 4]], [[6, 6]]],
         'add': [[[4, 4]], [[6, 6]]],
+        'mul': [[[3, 3]], [[8, 8]]],
         'mean': [[[2, 2]], [[3, 3]]],
         'min': [[[1, 1]], [[2, 2]]],
         'arg_min': [[[0, 0]], [[0, 0]]],
...
@@ -125,6 +133,8 @@ def test_out(test, reduce, dtype, device):
     if reduce == 'sum' or reduce == 'add':
         expected = expected - 2
+    elif reduce == 'mul':
+        expected = out  # We can not really test this here.
     elif reduce == 'mean':
         expected = out  # We can not really test this here.
     elif reduce == 'min':
...
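
The new test_out branch mirrors the existing mean case: test_out passes a pre-filled out tensor, and for sum/add the fill value simply offsets every entry (hence expected - 2), but for mul the fill value is multiplied into each bucket's product, so the reference table cannot be corrected by a uniform adjustment and the test falls back to comparing out against itself. A sketch of the interaction, assuming a fill value of 2 as the sum branch suggests:

    import torch
    from torch_scatter import scatter_mul

    src = torch.tensor([3., 4.])
    index = torch.tensor([0, 0])
    out = torch.full((2,), 2.)  # pre-filled output, as in test_out

    scatter_mul(src, index, dim=0, out=out)
    print(out)  # tensor([24.,  2.])  (2 * 3 * 4 in bucket 0; bucket 1 untouched)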
torch_scatter/__init__.py
@@ -58,8 +58,8 @@ if torch.cuda.is_available() and torch.version.cuda:  # pragma: no cover
         f'{major}.{minor}. Please reinstall the torch_scatter that '
         f'matches your PyTorch install.')
 
-from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min,
-                      scatter_max, scatter)  # noqa
+from .scatter import (scatter_sum, scatter_add, scatter_mul, scatter_mean,
+                      scatter_min, scatter_max, scatter)  # noqa
 from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
                           segment_min_csr, segment_max_csr, segment_csr,
                           gather_csr)  # noqa
...
@@ -72,6 +72,7 @@ from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
 __all__ = [
     'scatter_sum',
     'scatter_add',
+    'scatter_mul',
     'scatter_mean',
     'scatter_min',
     'scatter_max',
...
torch_scatter/scatter.py
@@ -31,6 +31,13 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     return scatter_sum(src, index, dim, out, dim_size)
 
 
+@torch.jit.script
+def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
+                out: Optional[torch.Tensor] = None,
+                dim_size: Optional[int] = None) -> torch.Tensor:
+    return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)
+
+
 @torch.jit.script
 def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                  out: Optional[torch.Tensor] = None,
...
@@ -127,8 +134,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
         with size :attr:`dim_size` at dimension :attr:`dim`.
         If :attr:`dim_size` is not given, a minimal sized output tensor
         according to :obj:`index.max() + 1` is returned.
-    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
-        :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
+    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`,
+        :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
     :rtype: :class:`Tensor`
...
@@ -150,6 +157,8 @@ def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     """
     if reduce == 'sum' or reduce == 'add':
         return scatter_sum(src, index, dim, out, dim_size)
+    if reduce == 'mul':
+        return scatter_mul(src, index, dim, out, dim_size)
     elif reduce == 'mean':
         return scatter_mean(src, index, dim, out, dim_size)
     elif reduce == 'min':
...
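
Because scatter_mul is declared with @torch.jit.script and dispatches to the custom op registered in csrc/scatter.cpp, it stays usable inside scripted code as well as through the generic scatter(..., reduce='mul') entry point. A minimal sketch:

    import torch
    from torch_scatter import scatter_mul

    @torch.jit.script
    def grouped_product(src: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # scatter_mul is itself a TorchScript function backed by the
        # torch_scatter::scatter_mul custom op.
        return scatter_mul(src, index, dim=0)

    src = torch.tensor([3., 4., 5.])
    index = torch.tensor([0, 0, 1])
    print(grouped_product(src, index))  # tensor([12.,  5.])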