Commit 7fd9091c authored by rusty1s

update code and tests

parent 5be6d63a
@@ -3,10 +3,5 @@ source=torch_scatter
 [report]
 exclude_lines =
     pragma: no cover
-    cuda
-    forward
-    backward
-    apply
+    torch.jit.script
     raise
-    min_value
-    max_value
import torch

from torch_scatter import scatter_logsumexp


def test_logsumexp():
    src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100])
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_logsumexp(src, index)

    out0 = torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1)
    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2]), dim=-1)
    out2 = torch.logsumexp(torch.tensor(7, dtype=torch.float), dim=-1)
    out3 = torch.logsumexp(torch.tensor([], dtype=torch.float), dim=-1)
    out4 = torch.tensor(-1, dtype=torch.float)

    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
    assert torch.allclose(out, expected)
-from itertools import product
-import pytest
 import torch
-from torch_scatter.composite import scatter_log_softmax, scatter_softmax
-from test.utils import devices, tensor, grad_dtypes
+from torch_scatter import scatter_log_softmax, scatter_softmax
-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
-def test_softmax(dtype, device):
-    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
-    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+def test_softmax():
+    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
+    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
     out = scatter_softmax(src, index)
-    out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
-    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
-    out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
-    out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
-                         dim=-1)
+    out0 = torch.softmax(torch.tensor([0.2, 0.2]), dim=-1)
+    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
+    out2 = torch.softmax(torch.tensor([7], dtype=torch.float), dim=-1)
+    out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1)
     expected = torch.stack([
         out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
-    ], dim=0).to(device)
+    ], dim=0)
     assert torch.allclose(out, expected)
-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
-def test_softmax_broadcasting(dtype, device):
-    src = torch.randn(10, 5, dtype=dtype, device=device)
-    index = tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
-    out = scatter_softmax(src, index, dim=0).view(5, 2, 5)
-    out = out.sum(dim=1)
-    assert torch.allclose(out, torch.ones_like(out))
-@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
-def test_log_softmax(dtype, device):
-    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
-    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+def test_log_softmax():
+    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
+    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
     out = scatter_log_softmax(src, index)
-    out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
-    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
-    out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
-    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
-                             dim=-1)
+    out0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1)
+    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
+    out2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1)
+    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')]), dim=-1)
     expected = torch.stack([
         out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
-    ], dim=0).to(device)
+    ], dim=0)
     assert torch.allclose(out, expected)
import torch

from torch_scatter import scatter_std


def test_std():
    src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=torch.float)
    index = torch.tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.long)

    out = scatter_std(src, index, dim=-1, unbiased=True)
    std = src.std(dim=-1, unbiased=True)[0]
    expected = torch.tensor([[std, 0], [0, std]])
    assert torch.allclose(out, expected)
from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
import torch_scatter
from .utils import grad_dtypes as dtypes, devices, tensor
funcs = ['add', 'sub', 'mul', 'div', 'mean']
indices = [2, 0, 1, 1, 0]
@pytest.mark.parametrize('func,device', product(funcs, devices))
def test_backward(func, device):
    index = torch.tensor(indices, dtype=torch.long, device=device)
    src = torch.rand((index.size(0), 2), dtype=torch.double, device=device)
    src.requires_grad_()

    op = getattr(torch_scatter, 'scatter_{}'.format(func))
    data = (src, index, 0)
    assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True
tests = [{
'name': 'max',
'src': [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
'index': [2, 0, 1, 1, 0],
'dim': 0,
'fill_value': 0,
'grad': [[4, 4], [8, 8], [6, 6]],
'expected': [[6, 6], [0, 0], [0, 0], [8, 8], [4, 4]],
}, {
'name': 'min',
'src': [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
'index': [2, 0, 1, 1, 0],
'dim': 0,
'fill_value': 3,
'grad': [[4, 4], [8, 8], [6, 6]],
'expected': [[6, 6], [4, 4], [8, 8], [0, 0], [0, 0]],
}]
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_arg_backward(test, dtype, device):
    src = tensor(test['src'], dtype, device)
    src.requires_grad_()
    index = tensor(test['index'], torch.long, device)
    grad = tensor(test['grad'], dtype, device)

    op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
    out, _ = op(src, index, test['dim'], fill_value=test['fill_value'])
    out.backward(grad)
    assert src.grad.tolist() == test['expected']
@@ -14,16 +14,6 @@ def test_broadcasting(device):
     out = scatter_add(src, index, dim=2, dim_size=H)
     assert out.size() == (B, C, H, W)
-    src = torch.randn((B, 1, H, W), device=device)
-    index = torch.randint(0, H, (B, C, H, W)).to(device, torch.long)
-    out = scatter_add(src, index, dim=2, dim_size=H)
-    assert out.size() == (B, C, H, W)
-    src = torch.randn((B, 1, H, W), device=device)
-    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
-    out = scatter_add(src, index, dim=2, dim_size=H)
-    assert out.size() == (B, 1, H, W)
     src = torch.randn((B, C, H, W), device=device)
     index = torch.randint(0, H, (H, )).to(device, torch.long)
     out = scatter_add(src, index, dim=2, dim_size=H)
...
from itertools import product
import pytest
import torch
import torch_scatter
from .utils import dtypes, devices, tensor
tests = [{
'name': 'add',
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
'dim': -1,
'fill_value': 0,
'expected': [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]],
}, {
'name': 'add',
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
'index': [0, 1, 1, 0],
'dim': 0,
'fill_value': 0,
'expected': [[6, 5], [6, 8]],
}, {
'name': 'sub',
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
'dim': -1,
'fill_value': 9,
'expected': [[9, 9, 5, 6, 6, 9], [7, 5, 5, 9, 9, 9]],
}, {
'name': 'sub',
'src': [[5, 2], [2, 2], [4, 2], [1, 3]],
'index': [0, 1, 1, 0],
'dim': 0,
'fill_value': 9,
'expected': [[3, 4], [3, 5]],
}, {
'name': 'mul',
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
'dim': -1,
'fill_value': 1,
'expected': [[1, 1, 4, 3, 2, 0], [0, 4, 3, 1, 1, 1]],
}, {
'name': 'mul',
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
'index': [0, 1, 1, 0],
'dim': 0,
'fill_value': 1,
'expected': [[5, 6], [8, 15]],
}, {
'name': 'div',
'src': [[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]],
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
'dim': -1,
'fill_value': 1,
'expected': [[1, 1, 0.25, 0.5, 0.5, 1], [0.5, 0.25, 0.5, 1, 1, 1]],
}, {
'name': 'div',
'src': [[4, 2], [2, 1], [4, 2], [1, 2]],
'index': [0, 1, 1, 0],
'dim': 0,
'fill_value': 1,
'expected': [[0.25, 0.25], [0.125, 0.5]],
}, {
'name': 'mean',
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
'dim': -1,
'fill_value': 0,
'expected': [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]],
}, {
'name': 'mean',
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
'index': [0, 1, 1, 0],
'dim': 0,
'fill_value': 0,
'expected': [[3, 2.5], [3, 4]],
}, {
'name': 'max',
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
'dim': -1,
'fill_value': 0,
'expected': [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
'expected_arg': [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]],
}, {
'name': 'max',
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
'index': [0, 1, 1, 0],
'dim': 0,
'fill_value': 0,
'expected': [[5, 3], [4, 5]],
'expected_arg': [[0, 3], [2, 1]],
}, {
'name': 'min',
'src': [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
'index': [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
'dim': -1,
'fill_value': 9,
'expected': [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
'expected_arg': [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]],
}, {
'name': 'min',
'src': [[5, 2], [2, 5], [4, 3], [1, 3]],
'index': [0, 1, 1, 0],
'dim': 0,
'fill_value': 9,
'expected': [[1, 2], [2, 3]],
'expected_arg': [[3, 0], [1, 2]],
}]
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_forward(test, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    expected = tensor(test['expected'], dtype, device)

    op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
    out = op(src, index, test['dim'], fill_value=test['fill_value'])

    if isinstance(out, tuple):
        assert out[0].tolist() == expected.tolist()
        assert out[1].tolist() == test['expected_arg']
    else:
        assert out.tolist() == expected.tolist()
from itertools import product
import torch
import pytest
from torch_scatter import scatter_logsumexp
from .utils import devices, tensor, grad_dtypes
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_logsumexp(dtype, device):
    src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_logsumexp(src, index)

    out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
    out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
    out4 = torch.tensor(-1, dtype=dtype)

    expected = torch.stack([out0, out1, out2, out3, out4], dim=0).to(device)
    assert torch.allclose(out, expected)
import torch
from torch_scatter import scatter_max, scatter_min
def test_max_fill_value():
    src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
    index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
    out, _ = scatter_max(src, index)

    v = torch.finfo(torch.float).min
    assert out.tolist() == [[v, v, 4, 3, 2, 0], [2, 4, 3, v, v, v]]


def test_min_fill_value():
    src = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
    index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
    out, _ = scatter_min(src, index)

    v = torch.finfo(torch.float).max
    assert out.tolist() == [[v, v, -4, -3, -2, 0], [-2, -4, -3, v, v, v]]
+from itertools import product
 import pytest
 import torch
-from torch_scatter import scatter_max
+import torch_scatter
+from .utils import reductions, tensor, dtypes
+tests = [
+    {
+        'src': [1, 2, 3, 4, 5, 6],
+        'index': [0, 0, 1, 1, 1, 3],
+        'indptr': [0, 2, 5, 5, 6],
+        'dim': 0,
+        'sum': [3, 12, 0, 6],
+        'add': [3, 12, 0, 6],
+        'mean': [1.5, 4, 0, 6],
+        'min': [1, 3, 0, 6],
+        'max': [2, 5, 0, 6],
+    },
+]
 @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUS')
-def test_multi_gpu():
+@pytest.mark.parametrize('test,reduce,dtype', product(tests, reductions,
+                                                      dtypes))
+def test_forward(test, reduce, dtype):
     device = torch.device('cuda:1')
-    src = torch.tensor([2.0, 3.0, 4.0, 5.0], device=device)
-    index = torch.tensor([0, 0, 1, 1], device=device)
-    assert scatter_max(src, index)[0].tolist() == [3, 5]
+    src = tensor(test['src'], dtype, device)
+    index = tensor(test['index'], torch.long, device)
+    indptr = tensor(test['indptr'], torch.long, device)
+    dim = test['dim']
+    expected = tensor(test[reduce], dtype, device)
+    out = torch_scatter.scatter(src, index, dim, reduce=reduce)
+    assert torch.all(out == expected)
+    out = torch_scatter.segment_coo(src, index, reduce=reduce)
+    assert torch.all(out == expected)
+    out = torch_scatter.segment_csr(src, indptr, reduce=reduce)
+    assert torch.all(out == expected)
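# Illustrative sketch (not part of the diff above): the COO `index` and the
# CSR `indptr` in the test describe the same four groups, which is why
# `scatter`, `segment_coo` and `segment_csr` are all checked against the same
# expected values.
import torch
import torch_scatter

src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 0, 1, 1, 1, 3])  # group id of every element
indptr = torch.tensor([0, 2, 5, 5, 6])    # group i spans src[indptr[i]:indptr[i + 1]]
out_coo = torch_scatter.segment_coo(src, index, reduce='sum')
out_csr = torch_scatter.segment_csr(src, indptr, reduce='sum')
# Both should equal tensor([3., 12., 0., 6.]), matching the 'sum' entry above.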
@@ -5,9 +5,7 @@ import torch
 from torch.autograd import gradcheck
 import torch_scatter
-from .utils import tensor, dtypes, devices
-reductions = ['sum', 'add', 'mean', 'min', 'max']
+from .utils import reductions, tensor, dtypes, devices
 tests = [
     {
...
@@ -5,9 +5,7 @@ import torch
 from torch.autograd import gradcheck
 import torch_scatter
-from .utils import tensor, dtypes, devices
-reductions = ['sum', 'add', 'mean', 'min', 'max']
+from .utils import reductions, tensor, dtypes, devices
 tests = [
     {
...
from itertools import product
import pytest
import torch
from torch_scatter import scatter_std
from .utils import grad_dtypes as dtypes, devices, tensor
biases = [True, False]
@pytest.mark.parametrize('dtype,device,bias', product(dtypes, devices, biases))
def test_std(dtype, device, bias):
    src = tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype, device)
    index = tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], torch.long, device)

    out = scatter_std(src, index, dim=-1, unbiased=bias)
    std = src.std(dim=-1, unbiased=bias)[0].item()
    expected = tensor([[std, 0], [0, std]], dtype, device)
    assert torch.allclose(out, expected)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_empty_std(dtype, device):
    out = torch.zeros(1, 5, dtype=dtype, device=device)
    src = tensor([], dtype, device).view(0, 5)
    index = tensor([], torch.long, device).view(0, 5)
    out = scatter_std(src, index, dim=0, out=out)
    assert out.tolist() == [[0, 0, 0, 0, 0]]
 import torch
+reductions = ['sum', 'add', 'mean', 'min', 'max']
 dtypes = [torch.float, torch.double, torch.int, torch.long]
 grad_dtypes = [torch.float, torch.double]
...
@@ -6,6 +6,8 @@ from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
 from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
                           segment_min_coo, segment_max_coo, segment_coo,
                           gather_coo)
+from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
+                        scatter_log_softmax)
 __version__ = '2.0.0'
@@ -30,6 +32,10 @@ __all__ = [
     'segment_max_coo',
     'segment_coo',
     'gather_coo',
+    'scatter_std',
+    'scatter_logsumexp',
+    'scatter_softmax',
+    'scatter_log_softmax',
     'torch_scatter',
     '__version__',
 ]
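# Usage sketch: with the two additions above, the composite reductions can be
# imported directly from the package root (assuming this revision of
# torch_scatter is installed), which is what the updated tests rely on.
from torch_scatter import (scatter_std, scatter_logsumexp, scatter_softmax,
                           scatter_log_softmax)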
from torch_scatter.utils.gen import gen
def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/add.svg?sanitize=true
:align: center
:width: 400px
|
Sums all values from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`src` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`. If
multiple indices reference the same location, their **contributions add**.
Formally, if :attr:`src` and :attr:`index` are n-dimensional tensors with
size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
:attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions
do not match.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_add
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out = scatter_add(src, index, out=out)
print(out)
.. testoutput::
tensor([[0., 0., 4., 3., 3., 0.],
[2., 4., 4., 0., 0., 0.]])
"""
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
    return out.scatter_add_(dim, index, src)
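# Minimal usage sketch of the one-dimensional case described in the docstring
# above (assuming the scatter_add defined there): every src value is added
# into out at its index, so out[0] = 1 + 3 and out[1] = 2 + 4.
import torch
from torch_scatter import scatter_add

src = torch.tensor([1., 2., 3., 4.])
index = torch.tensor([0, 1, 0, 1])
out = scatter_add(src, index)  # tensor([4., 6.])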
+from .std import scatter_std
+from .logsumexp import scatter_logsumexp
 from .softmax import scatter_log_softmax, scatter_softmax
 __all__ = [
+    'scatter_std',
+    'scatter_logsumexp',
     'scatter_softmax',
     'scatter_log_softmax',
 ]
from typing import Optional

import torch
from torch_scatter import scatter_sum, scatter_max

from .utils import broadcast


@torch.jit.script
def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None,
                      eps: float = 1e-12) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    if out is not None:
        dim_size = out.size(dim)
    else:
        if dim_size is None:
            dim_size = int(index.max().item() + 1)

    size = src.size()
    size[dim] = dim_size

    # Subtract the per-index maximum before exponentiating (log-sum-exp trick)
    # to avoid overflow; `scatter_max` fills `max_value_per_index` in-place.
    max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
                                     device=src.device)
    scatter_max(src, index, dim, max_value_per_index, dim_size)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = src - max_per_src_element

    if out is not None:
        out = out.sub_(max_per_src_element).exp_()

    sum_per_index = scatter_sum(recentered_scores.exp_(), index, dim, out,
                                dim_size)
    # Adding `eps` avoids taking the log of an exact zero.
    return sum_per_index.add_(eps).log_().add_(max_value_per_index)
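# Usage sketch (assuming the scatter_logsumexp defined above): each output
# position holds the logsumexp of the src entries scattered to it; positions
# that receive no entries stay at -inf.
import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, 0.5, 1.0])
index = torch.tensor([0, 0, 1])
out = scatter_logsumexp(src, index)
# out[0] ~ torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1) and out[1] ~ 1.0
# (both up to the `eps` term added for numerical stability).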
 import torch
-from torch_scatter import scatter_add, scatter_max
-from torch_scatter.utils.gen import broadcast
+from torch_scatter import scatter_sum, scatter_max
+from .utils import broadcast
-def scatter_softmax(src, index, dim=-1, eps=1e-12):
-    r"""
-    Softmax operation over all values in :attr:`src` tensor that share indices
-    specified in the :attr:`index` tensor along a given axis :attr:`dim`.
-    For one-dimensional tensors, the operation computes
-    .. math::
-        \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
-        \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
-    where :math:`\sum_j` is over :math:`j` such that
-    :math:`\mathrm{index}_j = i`.
-    Args:
-        src (Tensor): The source tensor.
-        index (LongTensor): The indices of elements to scatter.
-        dim (int, optional): The axis along which to index.
-            (default: :obj:`-1`)
-        eps (float, optional): Small value to ensure numerical stability.
-            (default: :obj:`1e-12`)
-    :rtype: :class:`Tensor`
-    """
+@torch.jit.script
+def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
+                    eps: float = 1e-12) -> torch.Tensor:
     if not torch.is_floating_point(src):
         raise ValueError('`scatter_softmax` can only be computed over tensors '
                          'with floating point data types.')
-    src, index = broadcast(src, index, dim)
-    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
+    index = broadcast(index, src, dim)
+    max_value_per_index = scatter_max(src, index, dim=dim)[0]
     max_per_src_element = max_value_per_index.gather(dim, index)
     recentered_scores = src - max_per_src_element
-    recentered_scores_exp = recentered_scores.exp()
-    sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
-    normalizing_constants = (sum_per_index + eps).gather(dim, index)
-    return recentered_scores_exp / normalizing_constants
+    recentered_scores_exp = recentered_scores.exp_()
+    sum_per_index = scatter_sum(recentered_scores_exp, index, dim)
+    normalizing_constants = sum_per_index.add_(eps).gather(dim, index)
+    return recentered_scores_exp.div_(normalizing_constants)
-def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
-    r"""
-    Log-softmax operation over all values in :attr:`src` tensor that share
-    indices specified in the :attr:`index` tensor along a given axis
-    :attr:`dim`.
-    For one-dimensional tensors, the operation computes
-    .. math::
-        \mathrm{out}_i = {\textrm{log_softmax}(\mathrm{src})}_i =
-        \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
-        \right)
-    where :math:`\sum_j` is over :math:`j` such that
-    :math:`\mathrm{index}_j = i`.
-    Args:
-        src (Tensor): The source tensor.
-        index (LongTensor): The indices of elements to scatter.
-        dim (int, optional): The axis along which to index.
-            (default: :obj:`-1`)
-        eps (float, optional): Small value to ensure numerical stability.
-            (default: :obj:`1e-12`)
-    :rtype: :class:`Tensor`
-    """
+@torch.jit.script
+def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
+                        eps: float = 1e-12) -> torch.Tensor:
     if not torch.is_floating_point(src):
         raise ValueError('`scatter_log_softmax` can only be computed over '
                          'tensors with floating point data types.')
-    src, index = broadcast(src, index, dim)
-    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
+    index = broadcast(index, src, dim)
+    max_value_per_index = scatter_max(src, index, dim=dim)[0]
     max_per_src_element = max_value_per_index.gather(dim, index)
     recentered_scores = src - max_per_src_element
-    sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
-                                dim=dim)
-    normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)
-    return recentered_scores - normalizing_constants
+    sum_per_index = scatter_sum(recentered_scores.exp(), index, dim)
+    normalizing_constants = sum_per_index.add_(eps).log_().gather(dim, index)
+    return recentered_scores.sub_(normalizing_constants)
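# Usage sketch (assuming the new scatter_softmax / scatter_log_softmax above):
# within every index group the softmax entries sum to one (up to `eps`), and
# exp(log_softmax) matches softmax.
import torch
from torch_scatter import scatter_softmax, scatter_log_softmax, scatter_sum

src = torch.randn(6)
index = torch.tensor([0, 0, 1, 1, 1, 2])
prob = scatter_softmax(src, index)
log_prob = scatter_log_softmax(src, index)
# scatter_sum(prob, index) is close to tensor([1., 1., 1.]) and
# prob is close to log_prob.exp().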
from typing import Optional

import torch
from torch_scatter import scatter_sum

from .utils import broadcast


@torch.jit.script
def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,
                unbiased: bool = True) -> torch.Tensor:
    if out is not None:
        dim_size = out.size(dim)

    if dim < 0:
        dim = src.dim() + dim

    count_dim = dim
    if index.dim() <= dim:
        count_dim = index.dim() - 1

    # Number of src elements scattered to each output position.
    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

    index = broadcast(index, src, dim)
    tmp = scatter_sum(src, index, dim, dim_size=dim_size)
    count = broadcast(count, tmp, dim).clamp_(1)
    mean = tmp.div_(count)

    # Sum of squared deviations from the per-index mean.
    var = (src - mean.gather(dim, index))
    var = var * var
    out = scatter_sum(var, index, dim, out, dim_size)

    if unbiased:
        count.sub_(1).clamp_(1)
    out.div_(count).sqrt_()

    return out
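# Usage sketch (assuming the scatter_std defined above), mirroring the new
# test: with every element of a row scattered to the same group, the result
# matches the per-row torch std.
import torch
from torch_scatter import scatter_std

src = torch.tensor([[2., 0., 1., 4., 3.], [0., 2., 1., 3., 4.]])
index = torch.tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]])
out = scatter_std(src, index, dim=-1, unbiased=True)
# out is close to [[src[0].std(), 0.], [0., src[1].std()]].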