Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
c648e063
"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "f10b28be876fea9e4abeb5e11bbc05b613fbc606"
Commit
c648e063
authored
Dec 22, 2017
by
rusty1s
Browse files
docs done
parent
1b66c684
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
82 additions
and
2 deletions
+82
-2
torch_scatter/functions/mean.py
torch_scatter/functions/mean.py
+82
-2
No files found.
torch_scatter/functions/mean.py
View file @
c648e063
...
@@ -3,8 +3,47 @@ from .utils import gen_filled_tensor, gen_output
...
@@ -3,8 +3,47 @@ from .utils import gen_filled_tensor, gen_output
def
scatter_mean_
(
output
,
index
,
input
,
dim
=
0
):
def
scatter_mean_
(
output
,
index
,
input
,
dim
=
0
):
"""If multiple indices reference the same location, their
r
"""Averages all values from the :attr:`input` tensor into :attr:`output`
**contributions average**."""
at the indices specified in the :attr:`index` tensor along an given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add_`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \mathrm{output}_i + \frac{1}{N_j} \cdot
\sum_j \mathrm{input}_j
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i` and
:math:`N_j` indicates the number of indices referecing :math:`j`.
Args:
output (Tensor): The destination tensor
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_mean_
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = torch.zeros(2, 6)
scatter_mean_(output, index, input, dim=1)
print(output)
.. testoutput::
0.0000 0.0000 4.0000 3.0000 1.5000 0.0000
1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
[torch.FloatTensor of size 2x6]
"""
num_output
=
gen_filled_tensor
(
output
,
output
.
size
(),
fill_value
=
0
)
num_output
=
gen_filled_tensor
(
output
,
output
.
size
(),
fill_value
=
0
)
scatter
(
'mean'
,
dim
,
output
,
index
,
input
,
num_output
)
scatter
(
'mean'
,
dim
,
output
,
index
,
input
,
num_output
)
num_output
[
num_output
==
0
]
=
1
num_output
[
num_output
==
0
]
=
1
...
@@ -13,5 +52,46 @@ def scatter_mean_(output, index, input, dim=0):
...
@@ -13,5 +52,46 @@ def scatter_mean_(output, index, input, dim=0):
def
scatter_mean
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
def
scatter_mean
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
r
"""Averages all values from the :attr:`input` tensor at the indices
specified in the :attr:`index` tensor along an given axis :attr:`dim`
(`cf.` :meth:`~torch_scatter.scatter_mean_` and
:meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \mathrm{fill\_value} + \frac{1}{N_j} \cdot
\sum_j \mathrm{input}_j
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i` and
:math:`N_j` indicates the number of indices referecing :math:`j`.
Args:
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
size (int, optional): Output size at dimension :attr:`dim`
fill_value (int, optional): Initial filling of output tensor
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_mean
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = scatter_mean(index, input, dim=1)
print(output)
.. testoutput::
0.0000 0.0000 4.0000 3.0000 1.5000 0.0000
1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
[torch.FloatTensor of size 2x6]
"""
output
=
gen_output
(
index
,
input
,
dim
,
size
,
fill_value
)
output
=
gen_output
(
index
,
input
,
dim
,
size
,
fill_value
)
return
scatter_mean_
(
output
,
index
,
input
,
dim
)
return
scatter_mean_
(
output
,
index
,
input
,
dim
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment