Commit 50a5dae8 authored by rusty1s

scatter min

parent d367c0b5
[
{
"name": "add",
"index": [2, 0, 1, 1, 0],
"input": [1, 2, 3, 4, 5],
"dim": 0,
"fill_value": 0,
"grad": [4, 8, 6],
"expected": [6, 4, 8, 8, 4]
},
{
"name": "sub",
"index": [2, 0, 1, 1, 0],
"input": [1, 2, 3, 4, 5],
"dim": 0,
"fill_value": 0,
"grad": [4, 8, 6],
"expected": [-6, -4, -8, -8, -4]
},
{
"name": "mean",
"index": [2, 0, 1, 1, 0],
"input": [1, 2, 3, 4, 5],
"dim": 0,
"fill_value": 0,
"grad": [4, 8, 6],
"expected": [6, 2, 4, 4, 2]
},
{
"name": "max",
"index": [2, 0, 1, 1, 0],
"input": [1, 2, 3, 4, 5],
"dim": 0,
"fill_value": 0,
"grad": [4, 8, 6],
"expected": [6, 0, 0, 8, 4]
},
{
"name": "min",
"index": [2, 0, 1, 1, 0],
"input": [1, 2, 3, 4, 5],
"dim": 0,
"fill_value": 3,
"grad": [4, 8, 6],
"expected": [6, 4, 8, 0, 0]
},
{
"name": "mul",
"index": [2, 0, 1, 1, 0],
"input": [1, 2, 3, 4, 5],
"dim": 0,
"fill_value": 2,
"grad": [4, 8, 6],
"expected": [12, 40, 64, 48, 16]
}
]
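A reading aid (an editor's sketch, not part of the commit): in these one-dimensional fixtures, `grad` is the gradient fed into the output and `expected` is the gradient that should arrive at `input`. Hand-checking the `mul` entry with the modern `torch.tensor` API, spelling the scatter out per output slot:

    import torch

    src = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
    # index = [2, 0, 1, 1, 0], fill_value = 2:
    # out_i = 2 * (product of src_j with index_j == i)
    out = torch.stack([
        2 * src[1] * src[4],  # slot 0
        2 * src[2] * src[3],  # slot 1
        2 * src[0],           # slot 2
    ])
    out.backward(torch.tensor([4., 8., 6.]))
    assert src.grad.tolist() == [12., 40., 64., 48., 16.]  # == "expected"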
[
{
"name": "add",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"fill_value": 0,
"expected": [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
},
{
"name": "add",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 0,
"expected": [[6, 5], [6, 8]]
},
{
"name": "sub",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"fill_value": 9,
"expected": [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]]
},
{
"name": "sub",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 2], [4, 2], [1, 3]],
"dim": 0,
"fill_value": 9,
"expected": [[3, 4], [3, 5]]
},
{
"name": "mul",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"fill_value": 1,
"expected": [[1, 1, 4, 3, 2, 0], [0, 4, 3, 1, 1, 1]]
},
{
"name": "mul",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 1,
"expected": [[5, 6], [8, 15]]
},
{
"name": "div",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]],
"dim": 1,
"fill_value": 1,
"expected": [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]]
},
{
"name": "div",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[4, 2], [2, 1], [4, 2], [1, 2]],
"dim": 0,
"fill_value": 1,
"expected": [[0.25, 0.25], [0.125, 0.5]]
},
{
"name": "mean",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"fill_value": 0,
"expected": [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]]
},
{
"name": "mean",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 0,
"expected": [[3, 2.5], [3, 4]]
},
{
"name": "max",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"fill_value": 0,
"expected": [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
"expected_arg": [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
},
{
"name": "max",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 0,
"expected": [[5, 3], [4, 5]],
"expected_arg": [[0, 3], [2, 1]]
},
{
"name": "min",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"fill_value": 9,
"expected": [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
"expected_arg": [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]]
},
{
"name": "min",
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
"dim": 0,
"fill_value": 9,
"expected": [[1, 2], [2, 3]],
"expected_arg": [[3, 0], [1, 2]]
}
]
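Similarly (again an editor's sketch): in the two-dimensional fixtures, `expected` is the forward output and `expected_arg` holds the source position of each winner, with -1 marking slots no element won. Row 0 of the first `max` entry, replayed with a plain loop:

    import torch

    src = torch.tensor([2., 0., 1., 4., 3.])
    index = torch.tensor([4, 5, 4, 2, 3])
    out = torch.zeros(6)                           # fill_value = 0
    arg = torch.full((6,), -1, dtype=torch.long)
    for j in range(5):
        i = int(index[j])
        if src[j] > out[i]:                        # fill_value takes part in the max
            out[i], arg[i] = src[j], j
    assert out.tolist() == [0., 0., 4., 3., 2., 0.]
    assert arg.tolist() == [-1, -1, 3, 4, 0, 1]    # slot 5: src value 0 does not beat fill 0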
@@ -5,7 +5,7 @@ import torch
from torch.autograd import gradcheck
import torch_scatter
-from .utils import devices
+from .utils import dtypes, devices, tensor

funcs = ['add', 'sub', 'mul', 'div', 'mean']
indices = [2, 0, 1, 1, 0]
@@ -20,3 +20,35 @@ def test_backward(func, device):
    op = getattr(torch_scatter, 'scatter_{}'.format(func))
    data = (src, index)
    assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True
tests = [{
'name': 'max',
'src': [1, 2, 3, 4, 5],
'index': [2, 0, 1, 1, 0],
'dim': 0,
'fill_value': 0,
'grad': [4, 8, 6],
'expected': [6, 0, 0, 8, 4]
}, {
'name': 'min',
'src': [1, 2, 3, 4, 5],
'index': [2, 0, 1, 1, 0],
'dim': 0,
'fill_value': 3,
'grad': [4, 8, 6],
'expected': [6, 4, 8, 0, 0]
}]
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_arg_backward(test, dtype, device):
src = tensor(test['src'], dtype, device)
src.requires_grad_()
index = tensor(test['index'], torch.long, device)
grad = tensor(test['grad'], dtype, device)
op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
out, _ = op(src, index, test['dim'], fill_value=test['fill_value'])
out.backward(grad)
assert src.grad.tolist() == test['expected']
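For orientation (an editor's sketch over the same fixture data): the `arg` tensor recorded in the forward pass is exactly what routes gradients here, and plain PyTorch reproduces the `max` case, since taking a maximum over each slot's sources sends the gradient to the winning position only:

    import torch

    src = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
    # index = [2, 0, 1, 1, 0]: slot 0 <- src[1], src[4]; slot 1 <- src[2], src[3]; slot 2 <- src[0]
    out = torch.stack([src[[1, 4]].max(), src[[2, 3]].max(), src[[0]].max()])
    out.backward(torch.tensor([4., 8., 6.]))
    assert src.grad.tolist() == [6., 0., 0., 8., 4.]  # matches test['expected']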
@@ -76,6 +76,38 @@ tests = [{
    'dim': 0,
    'fill_value': 0,
    'expected': [[3, 2.5], [3, 4]]
}, {
'name': 'max',
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
'dim': -1,
'fill_value': 0,
'expected': [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
'expected_arg': [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
}, {
'name': 'max',
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
'dim': 0,
'fill_value': 0,
'expected': [[5, 3], [4, 5]],
'expected_arg': [[0, 3], [2, 1]]
}, {
'name': 'min',
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
'dim': -1,
'fill_value': 9,
'expected': [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
'expected_arg': [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]]
}, {
'name': 'min',
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
'index': [[0, 0], [1, 1], [1, 1], [0, 0]],
'dim': 0,
'fill_value': 9,
'expected': [[1, 2], [2, 3]],
'expected_arg': [[3, 0], [1, 2]]
}]
@@ -86,6 +118,10 @@ def test_forward(test, dtype, device):
    expected = tensor(test['expected'], dtype, device)
    op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
-   output = op(src, index, test['dim'], fill_value=test['fill_value'])
-   assert output.tolist() == expected.tolist()
+   out = op(src, index, test['dim'], fill_value=test['fill_value'])
+   if isinstance(out, tuple):
+       assert out[0].tolist() == expected.tolist()
+       assert out[1].tolist() == test['expected_arg']
+   else:
+       assert out.tolist() == expected.tolist()
@@ -3,10 +3,12 @@ from .sub import scatter_sub
from .mul import scatter_mul
from .div import scatter_div
from .mean import scatter_mean
+from .max import scatter_max
+from .min import scatter_min

__version__ = '1.0.0'

__all__ = [
    'scatter_add', 'scatter_sub', 'scatter_mul', 'scatter_div', 'scatter_mean',
-   '__version__'
+   'scatter_max', 'scatter_min', '__version__'
]
from .scatter import Scatter, scatter
from .utils import gen_output
class ScatterDiv(Scatter): # pragma: no cover
def __init__(self, dim):
super(ScatterDiv, self).__init__('div', dim)
def save_for_backward_step(self, *data):
output, index, input = data
self.save_for_backward(output, index, input)
def backward_step(self, *data):
grad, output, index, input = data
return (output.data / grad).gather(self.dim, index.data) * input.data
def scatter_div_(output, index, input, dim=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/div.svg?sanitize=true
:align: center
:width: 400px
|
Divides all values from the :attr:`input` tensor into :attr:`output` at
the indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions divide** (`cf.` :meth:`~torch_scatter.scatter_add_`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \mathrm{output}_i \cdot \prod_j
\frac{1}{\mathrm{input}_j}
where the product is over :math:`j` such that :math:`\mathrm{index}_j = i`.
Args:
output (Tensor): The destination tensor
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_div_
input = torch.Tensor([[2, 1, 2, 4, 3], [1, 2, 2, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = torch.ones(2, 6)
scatter_div_(output, index, input, dim=1)
print(output)
.. testoutput::
1.0000 1.0000 0.2500 0.3333 0.2500 1.0000
0.5000 0.2500 0.1667 1.0000 1.0000 1.0000
[torch.FloatTensor of size 2x6]
"""
return scatter(ScatterDiv, 'div', dim, output, index, input)
def scatter_div(index, input, dim=0, size=None, fill_value=1):
r"""Divides all values from the :attr:`input` tensor at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`
(`cf.` :meth:`~torch_scatter.scatter_div_` and
:meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \mathrm{fill\_value} \cdot \prod_j
\frac{1}{\mathrm{input}_j}
where the product is over :math:`j` such that :math:`\mathrm{index}_j = i`.
Args:
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
size (int, optional): Output size at dimension :attr:`dim`
fill_value (int, optional): Initial filling of output tensor
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_div
input = torch.Tensor([[2, 1, 2, 4, 3], [1, 2, 2, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = scatter_div(index, input, dim=1)
print(output)
.. testoutput::
1.0000 1.0000 0.2500 0.3333 0.2500 1.0000
0.5000 0.2500 0.1667 1.0000 1.0000 1.0000
[torch.FloatTensor of size 2x6]
"""
output = gen_output(index, input, dim, size, fill_value)
return scatter_div_(output, index, input, dim)
from itertools import chain
from .._ext import ffi
def scatter(name, dim, *data):
# data = output, index, input, additional data
a, b, c = data[:3]
# Assert index dimension is valid.
assert dim >= 0 and dim < b.dim(), 'Index dimension is out of bounds'
# Assert same dimensionality across all inputs.
assert b.dim() == c.dim(), ('Index tensor must have same dimensions as '
'input tensor')
assert a.dim() == c.dim(), ('Input tensor must have same dimensions as '
'output tensor')
# Assert same tensor length across index and input.
assert b.numel() == c.numel(), ('Index tensor must have same size as '
'input tensor')
# Assert same tensor sizes across input and output apart from `dim`.
for d in chain(range(dim), range(dim + 1, a.dim())):
assert a.size(d) == c.size(d), (
'Input tensor must have same size as output tensor apart from the '
'specified dimension')
typename = type(data[0]).__name__.replace('Tensor', '')
cuda = 'cuda_' if data[0].is_cuda else ''
func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
func(dim, *data)
if len(data) <= 3:
return data[0]
return (data[0], ) + tuple(data[3:])
def index_backward(dim, index, grad, arg): # pragma: no cover
typename = type(grad).__name__.replace('Tensor', '')
cuda = 'cuda_' if grad.is_cuda else ''
func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
output = grad.new(index.size()).fill_(0)
func(dim, output, index, grad, arg)
return output
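The C/CUDA kernel dispatched here is not part of this commit; as a reading aid, a one-dimensional pure-PyTorch sketch of the routing it performs (my reading of how `arg` is produced by scatter_max_/scatter_min_; the name index_backward_ref is illustrative):

    def index_backward_ref(index, grad, arg):
        # grad and arg have output size; the result has input (index) size.
        out = grad.new(index.size()).fill_(0)
        for i in range(grad.numel()):
            j = int(arg[i])
            if j >= 0:  # -1 means no source element won output slot i
                out[j] = grad[i]
        return out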
from .scatter import Scatter, scatter
from .ffi import index_backward
from .utils import gen_filled_tensor, gen_output
class ScatterMax(Scatter):
def __init__(self, dim):
super(ScatterMax, self).__init__('max', dim)
def save_for_backward_step(self, *data):
output, index, input, arg = data
self.save_for_backward(index, arg)
def backward_step(self, *data): # pragma: no cover
grad, index, arg = data
return index_backward(self.dim, index.data, grad, arg.data)
def scatter_max_(output, index, input, dim=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/max.svg?sanitize=true
:align: center
:width: 400px
|
Maximizes all values from the :attr:`input` tensor into :attr:`output` at
the indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions maximize** (`cf.` :meth:`~torch_scatter.scatter_add_`).
The second return value is the index location in :attr:`input` of each
maximum value found (argmax).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \max(\mathrm{output}_i, \max_j(\mathrm{input}_j))
where max is over :math:`j` such that :math:`\mathrm{index}_j = i`.
Args:
output (Tensor): The destination tensor
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
:rtype: (:class:`Tensor`, :class:`LongTensor`)
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_max_
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = torch.zeros(2, 6)
output = scatter_max_(output, index, input, dim=1)
print(output)
.. testoutput::
(
0 0 4 3 2 0
2 4 3 0 0 0
[torch.FloatTensor of size 2x6]
,
-1 -1 3 4 0 1
1 4 3 -1 -1 -1
[torch.LongTensor of size 2x6]
)
"""
arg = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter(ScatterMax, 'max', dim, output, index, input, arg)
def scatter_max(index, input, dim=0, size=None, fill_value=0):
r"""Maximizes all values from the :attr:`input` tensor at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`
(`cf.` :meth:`~torch_scatter.scatter_max_` and
:meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \max(\mathrm{fill\_value},
\max_j(\mathrm{input}_j))
where max is over :math:`j` such that :math:`\mathrm{index}_j = i`.
Args:
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
size (int, optional): Output size at dimension :attr:`dim`
fill_value (int, optional): Initial filling of output tensor
:rtype: (:class:`Tensor`, :class:`LongTensor`)
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_max
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = scatter_max(index, input, dim=1)
print(output)
.. testoutput::
(
0 0 4 3 2 0
2 4 3 0 0 0
[torch.FloatTensor of size 2x6]
,
-1 -1 3 4 0 1
1 4 3 -1 -1 -1
[torch.LongTensor of size 2x6]
)
"""
output = gen_output(index, input, dim, size, fill_value)
return scatter_max_(output, index, input, dim)
from __future__ import division
from .scatter import Scatter, scatter
from .utils import gen_filled_tensor, gen_output
class ScatterMean(Scatter):
def __init__(self, dim):
super(ScatterMean, self).__init__('mean', dim)
def save_for_backward_step(self, *data):
output, index, input, count = data
self.save_for_backward(index)
def backward_step(self, *data): # pragma: no cover
grad, index = data
return grad.gather(self.dim, index.data)
def scatter_mean_(output, index, input, dim=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/mean.svg?sanitize=true
:align: center
:width: 400px
|
Averages all values from the :attr:`input` tensor into :attr:`output` at
the indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add_`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \mathrm{output}_i + \frac{1}{N_i} \cdot
\sum_j \mathrm{input}_j
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i` and
:math:`N_i` indicates the number of indices referencing :math:`i`.
Args:
output (Tensor): The destination tensor
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_mean_
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = torch.zeros(2, 6)
scatter_mean_(output, index, input, dim=1)
print(output)
.. testoutput::
0.0000 0.0000 4.0000 3.0000 1.5000 0.0000
1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
[torch.FloatTensor of size 2x6]
"""
init = gen_filled_tensor(output, output.size(), fill_value=0)
count = gen_filled_tensor(output, output.size(), fill_value=0)
scatter(ScatterMean, 'mean', dim, init, index, input, count)
count[count == 0] = 1
init /= count
output += init
return output
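The helper above decomposes the mean into a scatter-sum plus a per-slot count. A one-dimensional equivalent of those steps using only core PyTorch ops (an editor's sketch, modern tensor API):

    import torch

    src = torch.tensor([1., 2., 3., 4., 5.])
    index = torch.tensor([2, 0, 1, 1, 0])
    output = torch.zeros(3)  # fill_value = 0

    sums = torch.zeros(3).index_add_(0, index, src)
    count = torch.zeros(3).index_add_(0, index, torch.ones_like(src))
    count[count == 0] = 1    # empty slots: avoid division by zero
    output += sums / count
    assert output.tolist() == [3.5, 3.5, 1.0]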
def scatter_mean(index, input, dim=0, size=None, fill_value=0):
r"""Averages all values from the :attr:`input` tensor at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`
(`cf.` :meth:`~torch_scatter.scatter_mean_` and
:meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \mathrm{fill\_value} + \frac{1}{N_i} \cdot
\sum_j \mathrm{input}_j
where sum is over :math:`j` such that :math:`\mathrm{index}_j = i` and
:math:`N_i` indicates the number of indices referencing :math:`i`.
Args:
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
size (int, optional): Output size at dimension :attr:`dim`
fill_value (int, optional): Initial filling of output tensor
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_mean
input = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = scatter_mean(index, input, dim=1)
print(output)
.. testoutput::
0.0000 0.0000 4.0000 3.0000 1.5000 0.0000
1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
[torch.FloatTensor of size 2x6]
"""
output = gen_output(index, input, dim, size, fill_value)
return scatter_mean_(output, index, input, dim)
from .scatter import Scatter, scatter
from .ffi import index_backward
from .utils import gen_filled_tensor, gen_output
class ScatterMin(Scatter):
def __init__(self, dim):
super(ScatterMin, self).__init__('min', dim)
def save_for_backward_step(self, *data):
output, index, input, arg = data
self.save_for_backward(index, arg)
def backward_step(self, *data): # pragma: no cover
grad, index, arg = data
return index_backward(self.dim, index.data, grad, arg.data)
def scatter_min_(output, index, input, dim=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/min.svg?sanitize=true
:align: center
:width: 400px
|
Minimizes all values from the :attr:`input` tensor into :attr:`output` at
the indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions minimize** (`cf.` :meth:`~torch_scatter.scatter_add_`).
The second return value is the index location in :attr:`input` of each
minimum value found (argmin).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \min(\mathrm{output}_i, \min_j(\mathrm{input}_j))
where min is over :math:`j` such that :math:`\mathrm{index}_j = i`.
Args:
output (Tensor): The destination tensor
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
:rtype: (:class:`Tensor`, :class:`LongTensor`)
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_min_
input = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
index = torch.LongTensor([[ 4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = torch.zeros(2, 6)
output = scatter_min_(output, index, input, dim=1)
print(output)
.. testoutput::
(
0 0 -4 -3 -2 0
-2 -4 -3 0 0 0
[torch.FloatTensor of size 2x6]
,
-1 -1 3 4 0 1
1 4 3 -1 -1 -1
[torch.LongTensor of size 2x6]
)
"""
arg = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter(ScatterMin, 'min', dim, output, index, input, arg)
def scatter_min(index, input, dim=0, size=None, fill_value=0):
r"""Minimizes all values from the :attr:`input` tensor at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`
(`cf.` :meth:`~torch_scatter.scatter_min_` and
:meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \min(\mathrm{fill\_value},
\min_j(\mathrm{input}_j))
where min is over :math:`j` such that :math:`\mathrm{index}_j = i`.
Args:
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
size (int, optional): Output size at dimension :attr:`dim`
fill_value (int, optional): Initial filling of output tensor
:rtype: (:class:`Tensor`, :class:`LongTensor`)
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_min
input = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
index = torch.LongTensor([[ 4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = scatter_min(index, input, dim=1)
print(output)
.. testoutput::
(
0 0 -4 -3 -2 0
-2 -4 -3 0 0 0
[torch.FloatTensor of size 2x6]
,
-1 -1 3 4 0 1
1 4 3 -1 -1 -1
[torch.LongTensor of size 2x6]
)
"""
output = gen_output(index, input, dim, size, fill_value)
return scatter_min_(output, index, input, dim)
from .scatter import Scatter, scatter
from .utils import gen_output
class ScatterMul(Scatter):
def __init__(self, dim):
super(ScatterMul, self).__init__('mul', dim)
def save_for_backward_step(self, *data):
output, index, input = data
self.save_for_backward(output, index, input)
def backward_step(self, *data): # pragma: no cover
grad, output, index, input = data
return (grad * output.data).gather(self.dim, index.data) / input.data
def scatter_mul_(output, index, input, dim=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/mul.svg?sanitize=true
:align: center
:width: 400px
|
Multiplies all values from the :attr:`input` tensor into :attr:`output` at
the indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions multiply** (`cf.` :meth:`~torch_scatter.scatter_add_`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \mathrm{output}_i \cdot \prod_j \mathrm{input}_j
where the product is over :math:`j` such that :math:`\mathrm{index}_j = i`.
Args:
output (Tensor): The destination tensor
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_mul_
input = torch.Tensor([[2, 0, 3, 4, 3], [2, 3, 4, 2, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = torch.ones(2, 6)
scatter_mul_(output, index, input, dim=1)
print(output)
.. testoutput::
1 1 4 3 6 0
6 4 8 1 1 1
[torch.FloatTensor of size 2x6]
"""
return scatter(ScatterMul, 'mul', dim, output, index, input)
def scatter_mul(index, input, dim=0, size=None, fill_value=1):
r"""Multiplies all values from the :attr:`input` tensor at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`
(`cf.` :meth:`~torch_scatter.scatter_mul_` and
:meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{output}_i = \mathrm{fill\_value} \cdot \prod_j \mathrm{input}_j
where the product is over :math:`j` such that :math:`\mathrm{index}_j = i`.
Args:
index (LongTensor): The indices of elements to scatter
input (Tensor): The source tensor
dim (int, optional): The axis along which to index
size (int, optional): Output size at dimension :attr:`dim`
fill_value (int, optional): Initial filling of output tensor
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_mul
input = torch.Tensor([[2, 0, 3, 4, 3], [2, 3, 4, 2, 4]])
index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
output = scatter_mul(index, input, dim=1)
print(output)
.. testoutput::
1 1 4 3 6 0
6 4 8 1 1 1
[torch.FloatTensor of size 2x6]
"""
output = gen_output(index, input, dim, size, fill_value)
return scatter_mul_(output, index, input, dim)
import torch
from torch.autograd import Function
from .ffi import scatter as ffi_scatter
class Scatter(Function):
def __init__(self, name, dim):
super(Scatter, self).__init__()
self.name = name
self.dim = dim
def save_for_backward_step(self, *data): # pragma: no cover
raise NotImplementedError
def forward(self, *data):
assert not self.needs_input_grad[1], 'Can\'t differentiate the index'
self.mark_dirty(data[0]) # Mark output as dirty.
self.len = len(data) # Save number of arguments for backward step.
output = ffi_scatter(self.name, self.dim, *data)
self.save_for_backward_step(*data)
return output
def backward(self, *data): # pragma: no cover
grad_output = grad_input = None
if self.needs_input_grad[0]:
grad_output = data[0]
# Call grad computation of `input` for the specific scatter operation.
if self.needs_input_grad[2]:
grad_input = self.backward_step(data[0], *self.saved_variables)
# Return and fill with empty grads for non-differentiable arguments.
return (grad_output, None, grad_input) + (None, ) * (self.len - 3)
def backward_step(self, *data): # pragma: no cover
raise NotImplementedError
def scatter(Clx, name, dim, *data):
    # Plain tensors skip autograd and call the C extension directly;
    # Variables are routed through the corresponding autograd Function.
    if torch.is_tensor(data[0]):
        return ffi_scatter(name, dim, *data)
    else:
        return Clx(dim)(*data)
import torch
from torch.autograd import Variable
def gen_filled_tensor(input, size, fill_value):
if torch.is_tensor(input):
return input.new(size).fill_(fill_value)
else:
return Variable(input.data.new(size).fill_(fill_value))
def gen_output(index, input, dim, dim_size, fill_value):
if dim_size is None:
dim_size = index.max() + 1
dim_size = dim_size if torch.is_tensor(input) else dim_size.data[0]
size = list(index.size())
size[dim] = dim_size
return gen_filled_tensor(input, torch.Size(size), fill_value)
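A quick usage sketch of the size inference (an editor's example written against the pre-0.4 Tensor API used throughout this commit): when `dim_size` is omitted, the output grows to `index.max() + 1` along `dim`.

    import torch

    index = torch.LongTensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
    input = torch.ones(2, 5)
    out = gen_output(index, input, dim=1, dim_size=None, fill_value=7)
    print(out.size())  # torch.Size([2, 6]) -- 6 == index.max() + 1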
from torch.autograd import Function
from .utils.ffi import get_func
from .utils.gen import gen
class ScatterMax(Function):
@staticmethod
def forward(ctx, out, src, index, dim):
arg = index.new_full(out.size(), -1)
func = get_func('scatter_max', src)
func(dim, out, index, src, arg)
ctx.mark_dirty(out)
ctx.dim = dim
ctx.save_for_backward(index, arg)
return out, arg
@staticmethod
def backward(ctx, grad_out, grad_arg):
index, arg = ctx.saved_variables
grad_src = None
if ctx.needs_input_grad[1]:
grad_src = grad_out.new_zeros(index.size())
func = get_func('index_backward', grad_out)
func(ctx.dim, grad_src, index, grad_out, arg)
return None, grad_src, None, None
def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/max.svg?sanitize=true
:align: center
:width: 400px
|
Maximizes all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions maximize** (`cf.` :meth:`~torch_scatter.scatter_add`).
The second return tensor contains the index location in :attr:`src` of
each maximum value (the argmax).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \max(\mathrm{out}_i, \max_j(\mathrm{src}_j))
where :math:`\max` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
:rtype: (:class:`Tensor`, :class:`LongTensor`)
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_max
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out = scatter_max(src, index, out=out)
print(out)
.. testoutput::
(
0 0 4 3 2 0
2 4 3 0 0 0
[torch.FloatTensor of size 2x6]
,
-1 -1 3 4 0 1
1 4 3 -1 -1 -1
[torch.LongTensor of size 2x6]
)
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
return ScatterMax.apply(out, src, index, dim)
from torch.autograd import Function
from .utils.ffi import get_func
from .utils.gen import gen
class ScatterMin(Function):
@staticmethod
def forward(ctx, out, src, index, dim):
arg = index.new_full(out.size(), -1)
func = get_func('scatter_min', src)
func(dim, out, index, src, arg)
ctx.mark_dirty(out)
ctx.dim = dim
ctx.save_for_backward(index, arg)
return out, arg
@staticmethod
def backward(ctx, grad_out, grad_arg):
index, arg = ctx.saved_variables
grad_src = None
if ctx.needs_input_grad[1]:
grad_src = grad_out.new_zeros(index.size())
func = get_func('index_backward', grad_out)
func(ctx.dim, grad_src, index, grad_out, arg)
return None, grad_src, None, None
def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/min.svg?sanitize=true
:align: center
:width: 400px
|
Minimizes all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions minimize** (`cf.` :meth:`~torch_scatter.scatter_add`).
The second return tensor contains the index location in :attr:`src` of
each minimum value (the argmin).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \min(\mathrm{out}_i, \min_j(\mathrm{src}_j))
where :math:`\min` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
:rtype: (:class:`Tensor`, :class:`LongTensor`)
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_min
src = torch.tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
index = torch.tensor([[ 4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out = scatter_min(src, index, out=out)
print(out)
.. testoutput::
(
0 0 -4 -3 -2 0
-2 -4 -3 0 0 0
[torch.FloatTensor of size 2x6]
,
-1 -1 3 4 0 1
1 4 3 -1 -1 -1
[torch.LongTensor of size 2x6]
)
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
return ScatterMin.apply(out, src, index, dim)