Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
4b654c29
Commit
4b654c29
authored
Apr 28, 2018
by
rusty1s
Browse files
scatter mul
parent
0b71aadc
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
109 additions
and
5 deletions
+109
-5
test/test_backward.py
test/test_backward.py
+1
-1
test/test_forward.py
test/test_forward.py
+15
-1
torch_scatter/__init__.py
torch_scatter/__init__.py
+4
-1
torch_scatter/mean.py
torch_scatter/mean.py
+1
-2
torch_scatter/mul.py
torch_scatter/mul.py
+88
-0
No files found.
test/test_backward.py
View file @
4b654c29
...
@@ -7,7 +7,7 @@ import torch_scatter
...
@@ -7,7 +7,7 @@ import torch_scatter
from
.utils
import
devices
from
.utils
import
devices
funcs
=
[
'add'
,
'sub'
,
'mean'
]
funcs
=
[
'add'
,
'sub'
,
'mul'
,
'mean'
]
indices
=
[
2
,
0
,
1
,
1
,
0
]
indices
=
[
2
,
0
,
1
,
1
,
0
]
...
...
test/test_forward.py
View file @
4b654c29
...
@@ -34,11 +34,25 @@ tests = [{
...
@@ -34,11 +34,25 @@ tests = [{
'dim'
:
0
,
'dim'
:
0
,
'fill_value'
:
9
,
'fill_value'
:
9
,
'expected'
:
[[
3
,
4
],
[
3
,
5
]]
'expected'
:
[[
3
,
4
],
[
3
,
5
]]
},
{
'name'
:
'mul'
,
'src'
:
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]],
'index'
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
'dim'
:
-
1
,
'fill_value'
:
1
,
'expected'
:
[[
1
,
1
,
4
,
3
,
2
,
0
],
[
0
,
4
,
3
,
1
,
1
,
1
]]
},
{
'name'
:
'mul'
,
'src'
:
[[
5
,
2
],
[
2
,
5
],
[
4
,
3
],
[
1
,
3
]],
'index'
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
'dim'
:
0
,
'fill_value'
:
1
,
'expected'
:
[[
5
,
6
],
[
8
,
15
]]
},
{
},
{
'name'
:
'mean'
,
'name'
:
'mean'
,
'src'
:
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]],
'src'
:
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]],
'index'
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
'index'
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
'dim'
:
1
,
'dim'
:
-
1
,
'fill_value'
:
0
,
'fill_value'
:
0
,
'expected'
:
[[
0
,
0
,
4
,
3
,
1.5
,
0
],
[
1
,
4
,
2
,
0
,
0
,
0
]]
'expected'
:
[[
0
,
0
,
4
,
3
,
1.5
,
0
],
[
1
,
4
,
2
,
0
,
0
,
0
]]
},
{
},
{
...
...
torch_scatter/__init__.py
View file @
4b654c29
from
.add
import
scatter_add
from
.add
import
scatter_add
from
.sub
import
scatter_sub
from
.sub
import
scatter_sub
from
.mul
import
scatter_mul
from
.mean
import
scatter_mean
from
.mean
import
scatter_mean
__version__
=
'1.0.0'
__version__
=
'1.0.0'
__all__
=
[
'scatter_add'
,
'scatter_sub'
,
'scatter_mean'
,
'__version__'
]
__all__
=
[
'scatter_add'
,
'scatter_sub'
,
'scatter_mul'
,
'scatter_mean'
,
'__version__'
]
torch_scatter/mean.py
View file @
4b654c29
...
@@ -7,14 +7,13 @@ from .utils.gen import gen
...
@@ -7,14 +7,13 @@ from .utils.gen import gen
class
ScatterMean
(
Function
):
class
ScatterMean
(
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
out
,
src
,
index
,
dim
):
def
forward
(
ctx
,
out
,
src
,
index
,
dim
):
ctx
.
mark_dirty
(
out
)
count
=
src
.
new_zeros
(
out
.
size
())
count
=
src
.
new_zeros
(
out
.
size
())
func
=
get_func
(
'scatter_mean'
,
src
)
func
=
get_func
(
'scatter_mean'
,
src
)
func
(
dim
,
out
,
index
,
src
,
count
)
func
(
dim
,
out
,
index
,
src
,
count
)
count
[
count
==
0
]
=
1
count
[
count
==
0
]
=
1
out
/=
count
out
/=
count
ctx
.
mark_dirty
(
out
)
ctx
.
save_for_backward
(
index
,
count
)
ctx
.
save_for_backward
(
index
,
count
)
return
out
return
out
...
...
torch_scatter/mul.py
0 → 100644
View file @
4b654c29
from
torch.autograd
import
Function
from
.utils.ffi
import
get_func
from
.utils.gen
import
gen
class
ScatterMul
(
Function
):
@
staticmethod
def
forward
(
ctx
,
out
,
src
,
index
,
dim
):
func
=
get_func
(
'scatter_mul'
,
src
)
func
(
dim
,
out
,
index
,
src
)
ctx
.
dim
=
dim
ctx
.
mark_dirty
(
out
)
ctx
.
save_for_backward
(
out
,
src
,
index
)
return
out
@
staticmethod
def
backward
(
ctx
,
grad_out
):
out
,
src
,
index
=
ctx
.
saved_variables
grad_src
=
None
if
ctx
.
needs_input_grad
[
1
]:
grad_src
=
(
grad_out
*
out
)[
index
]
/
src
return
None
,
grad_src
,
None
,
None
def
scatter_mul
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
1
):
r
"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/mul.svg?sanitize=true
:align: center
:width: 400px
|
Multiplies all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along an given axis
:attr:`dim`.If multiple indices reference the same location, their
**contributions multiply** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \mathrm{out}_i \cdot \prod_j \mathrm{src}_j
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_mean
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out = scatter_mean(src, index, out=out)
print(out)
.. testoutput::
0.0000 0.0000 4.0000 3.0000 1.5000 0.0000
1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
[torch.FloatTensor of size 2x6]
"""
src
,
out
,
index
,
dim
=
gen
(
src
,
index
,
dim
,
out
,
dim_size
,
fill_value
)
return
ScatterMul
.
apply
(
out
,
src
,
index
,
dim
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment