Unverified commit c01f9bae authored by Matthias Fey, committed by GitHub

Merge pull request #105 from rusty1s/traceable

[WIP] traceable functions
parents 2520670a 02a47c46
import torch
from torch_scatter import scatter_max, scatter_min
def test_max_fill_value():
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, _ = scatter_max(src, index)
v = torch.finfo(torch.float).min
assert out.tolist() == [[v, v, 4, 3, 2, 0], [2, 4, 3, v, v, v]]
def test_min_fill_value():
src = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, _ = scatter_min(src, index)
v = torch.finfo(torch.float).max
assert out.tolist() == [[v, v, -4, -3, -2, 0], [-2, -4, -3, v, v, v]]
from itertools import product
import pytest
import torch
from torch_scatter import scatter_max
import torch_scatter
from .utils import reductions, tensor, dtypes
tests = [
{
'src': [1, 2, 3, 4, 5, 6],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'dim': 0,
'sum': [3, 12, 0, 6],
'add': [3, 12, 0, 6],
'mean': [1.5, 4, 0, 6],
'min': [1, 3, 0, 6],
'max': [2, 5, 0, 6],
},
]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUs')
def test_multi_gpu():
@pytest.mark.parametrize('test,reduce,dtype', product(tests, reductions,
dtypes))
def test_forward(test, reduce, dtype):
device = torch.device('cuda:1')
src = torch.tensor([2.0, 3.0, 4.0, 5.0], device=device)
index = torch.tensor([0, 0, 1, 1], device=device)
assert scatter_max(src, index)[0].tolist() == [3, 5]
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
dim = test['dim']
expected = tensor(test[reduce], dtype, device)
out = torch_scatter.scatter(src, index, dim, reduce=reduce)
assert torch.all(out == expected)
out = torch_scatter.segment_coo(src, index, reduce=reduce)
assert torch.all(out == expected)
out = torch_scatter.segment_csr(src, indptr, reduce=reduce)
assert torch.all(out == expected)
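# Usage sketch: the three reduction interfaces agree when `indptr` is the CSR
# pointer of the sorted `index` vector (values taken from the test case above;
# assumes torch_scatter with the scatter/segment_coo/segment_csr interface).
import torch
import torch_scatter
src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 0, 1, 1, 1, 3])  # sorted COO indices
indptr = torch.tensor([0, 2, 5, 5, 6])    # CSR pointer: segment i spans src[indptr[i]:indptr[i + 1]]
out_scatter = torch_scatter.scatter(src, index, dim=0, reduce='sum')
out_coo = torch_scatter.segment_coo(src, index, reduce='sum')
out_csr = torch_scatter.segment_csr(src, indptr, reduce='sum')
assert out_scatter.tolist() == out_coo.tolist() == out_csr.tolist() == [3., 12., 0., 6.]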
from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
import torch_scatter
from .utils import reductions, tensor, dtypes, devices
tests = [
{
'src': [1, 3, 2, 4, 5, 6],
'index': [0, 1, 0, 1, 1, 3],
'dim': 0,
'sum': [3, 12, 0, 6],
'add': [3, 12, 0, 6],
'mean': [1.5, 4, 0, 6],
'min': [1, 3, 0, 6],
'arg_min': [0, 1, 6, 5],
'max': [2, 5, 0, 6],
'arg_max': [2, 4, 6, 5],
},
{
'src': [[1, 2], [5, 6], [3, 4], [7, 8], [9, 10], [11, 12]],
'index': [0, 1, 0, 1, 1, 3],
'dim': 0,
'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]],
'max': [[3, 4], [9, 10], [0, 0], [11, 12]],
'arg_max': [[2, 2], [4, 4], [6, 6], [5, 5]],
},
{
'src': [[1, 5, 3, 7, 9, 11], [2, 4, 8, 6, 10, 12]],
'index': [[0, 1, 0, 1, 1, 3], [0, 0, 1, 0, 1, 2]],
'dim': 1,
'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]],
'max': [[3, 9, 0, 11], [6, 10, 12, 0]],
'arg_max': [[2, 4, 6, 5], [3, 4, 5, 6]],
},
{
'src': [[[1, 2], [5, 6], [3, 4]], [[10, 11], [7, 9], [12, 13]]],
'index': [[0, 1, 0], [2, 0, 2]],
'dim': 1,
'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]],
'max': [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]],
'arg_max': [[[2, 2], [1, 1], [3, 3]], [[1, 1], [3, 3], [2, 2]]],
},
{
'src': [[1, 3], [2, 4]],
'index': [[0, 0], [0, 0]],
'dim': 1,
'sum': [[4], [6]],
'add': [[4], [6]],
'mean': [[2], [3]],
'min': [[1], [2]],
'arg_min': [[0], [0]],
'max': [[3], [4]],
'arg_max': [[1], [1]],
},
{
'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
'index': [[0, 0], [0, 0]],
'dim': 1,
'sum': [[[4, 4]], [[6, 6]]],
'add': [[[4, 4]], [[6, 6]]],
'mean': [[[2, 2]], [[3, 3]]],
'min': [[[1, 1]], [[2, 2]]],
'arg_min': [[[0, 0]], [[0, 0]]],
'max': [[[3, 3]], [[4, 4]]],
'arg_max': [[[1, 1]], [[1, 1]]],
},
]
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_forward(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
dim = test['dim']
expected = tensor(test[reduce], dtype, device)
out = getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
@pytest.mark.parametrize('test,reduce,device',
product(tests, reductions, devices))
def test_backward(test, reduce, device):
src = tensor(test['src'], torch.double, device)
src.requires_grad_()
index = tensor(test['index'], torch.long, device)
dim = test['dim']
assert gradcheck(torch_scatter.scatter,
(src, index, dim, None, None, reduce))
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_out(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
dim = test['dim']
expected = tensor(test[reduce], dtype, device)
out = torch.full_like(expected, -2)
getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim, out)
if reduce == 'sum' or reduce == 'add':
expected = expected - 2
elif reduce == 'mean':
expected = out  # We cannot really test this here.
elif reduce == 'min':
expected = expected.fill_(-2)
elif reduce == 'max':
expected[expected == 0] = -2
else:
raise ValueError
assert torch.all(out == expected)
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_non_contiguous(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
dim = test['dim']
expected = tensor(test[reduce], dtype, device)
if src.dim() > 1:
src = src.transpose(0, 1).contiguous().transpose(0, 1)
if index.dim() > 1:
index = index.transpose(0, 1).contiguous().transpose(0, 1)
out = getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
@@ -3,12 +3,9 @@ from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
from torch_scatter import segment_coo, segment_csr
import torch_scatter
from .utils import tensor, dtypes, devices
reductions = ['sum', 'mean', 'min', 'max']
grad_reductions = ['sum', 'mean']
from .utils import reductions, tensor, dtypes, devices
tests = [
{
@@ -16,6 +13,7 @@ tests = [
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'sum': [3, 12, 0, 6],
'add': [3, 12, 0, 6],
'mean': [1.5, 4, 0, 6],
'min': [1, 3, 0, 6],
'arg_min': [0, 2, 6, 5],
@@ -27,6 +25,7 @@ tests = [
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]],
@@ -38,6 +37,7 @@ tests = [
'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]],
@@ -49,6 +49,7 @@ tests = [
'index': [[0, 0, 1], [0, 2, 2]],
'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]],
@@ -60,6 +61,7 @@ tests = [
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'sum': [[4], [6]],
'add': [[4], [6]],
'mean': [[2], [3]],
'min': [[1], [2]],
'arg_min': [[0], [0]],
@@ -71,6 +73,7 @@ tests = [
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'sum': [[[4, 4]], [[6, 6]]],
'add': [[[4, 4]], [[6, 6]]],
'mean': [[[2, 2]], [[3, 3]]],
'min': [[[1, 1]], [[2, 2]]],
'arg_min': [[[0, 0]], [[0, 0]]],
@@ -88,14 +91,14 @@ def test_forward(test, reduce, dtype, device):
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test[reduce], dtype, device)
out = segment_coo(src, index, reduce=reduce)
out = getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
out = segment_csr(src, indptr, reduce=reduce)
out = getattr(torch_scatter, f'segment_{reduce}_coo')(src, index)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
@@ -104,37 +107,36 @@ def test_forward(test, reduce, dtype, device):
@pytest.mark.parametrize('test,reduce,device',
product(tests, grad_reductions, devices))
product(tests, reductions, devices))
def test_backward(test, reduce, device):
src = tensor(test['src'], torch.double, device)
src.requires_grad_()
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
assert gradcheck(segment_coo, (src, index, None, None, reduce))
assert gradcheck(segment_csr, (src, indptr, None, reduce))
assert gradcheck(torch_scatter.segment_csr, (src, indptr, None, reduce))
assert gradcheck(torch_scatter.segment_coo,
(src, index, None, None, reduce))
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_segment_out(test, reduce, dtype, device):
def test_out(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test[reduce], dtype, device)
size = list(src.size())
size[indptr.dim() - 1] = indptr.size(-1) - 1
out = src.new_full(size, -2)
out = torch.full_like(expected, -2)
segment_csr(src, indptr, out, reduce=reduce)
getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr, out)
assert torch.all(out == expected)
out.fill_(-2)
segment_coo(src, index, out, reduce=reduce)
getattr(torch_scatter, f'segment_{reduce}_coo')(src, index, out)
if reduce == 'sum':
if reduce == 'sum' or reduce == 'add':
expected = expected - 2
elif reduce == 'mean':
expected = out  # We cannot really test this here.
@@ -150,7 +152,7 @@ def test_segment_out(test, reduce, dtype, device):
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_non_contiguous_segment(test, reduce, dtype, device):
def test_non_contiguous(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
@@ -163,14 +165,14 @@ def test_non_contiguous_segment(test, reduce, dtype, device):
if indptr.dim() > 1:
indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)
out = segment_coo(src, index, reduce=reduce)
out = getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
out = segment_csr(src, indptr, reduce=reduce)
out = getattr(torch_scatter, f'segment_{reduce}_coo')(src, index)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
...
from itertools import product
import pytest
import torch
from torch_scatter import scatter_std
from .utils import grad_dtypes as dtypes, devices, tensor
biases = [True, False]
@pytest.mark.parametrize('dtype,device,bias', product(dtypes, devices, biases))
def test_std(dtype, device, bias):
src = tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype, device)
index = tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], torch.long, device)
out = scatter_std(src, index, dim=-1, unbiased=bias)
std = src.std(dim=-1, unbiased=bias)[0].item()
expected = tensor([[std, 0], [0, std]], dtype, device)
assert torch.allclose(out, expected)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_empty_std(dtype, device):
out = torch.zeros(1, 5, dtype=dtype, device=device)
src = tensor([], dtype, device).view(0, 5)
index = tensor([], torch.long, device).view(0, 5)
out = scatter_std(src, index, dim=0, out=out)
assert out.tolist() == [[0, 0, 0, 0, 0]]
import torch
reductions = ['sum', 'add', 'mean', 'min', 'max']
dtypes = [torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.float, torch.double]
...
import torch
from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min,
scatter_max, scatter)
from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
segment_min_csr, segment_max_csr, segment_csr,
gather_csr)
from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
segment_min_coo, segment_max_coo, segment_coo,
gather_coo)
from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
scatter_log_softmax)
from .add import scatter_add
from .sub import scatter_sub
from .mul import scatter_mul
from .div import scatter_div
from .mean import scatter_mean
from .std import scatter_std
from .max import scatter_max
from .min import scatter_min
from .logsumexp import scatter_logsumexp
from .segment import segment_coo, segment_csr
from .gather import gather_coo, gather_csr
import torch_scatter.composite
torch.ops.load_library('torch_scatter/scatter_cpu.so')
torch.ops.load_library('torch_scatter/segment_cpu.so')
torch.ops.load_library('torch_scatter/gather_cpu.so')
try:
torch.ops.load_library('torch_scatter/scatter_cuda.so')
torch.ops.load_library('torch_scatter/segment_cuda.so')
torch.ops.load_library('torch_scatter/gather_cuda.so')
except OSError as e:
if torch.cuda.is_available():
raise e
__version__ = '1.4.0'
__version__ = '2.0.0'
__all__ = [
'scatter_sum',
'scatter_add',
'scatter_sub',
'scatter_mul',
'scatter_div',
'scatter_mean',
'scatter_std',
'scatter_max',
'scatter_min',
'scatter_logsumexp',
'segment_coo',
'scatter_max',
'scatter',
'segment_sum_csr',
'segment_add_csr',
'segment_mean_csr',
'segment_min_csr',
'segment_max_csr',
'segment_csr',
'gather_coo',
'gather_csr',
'segment_sum_coo',
'segment_add_coo',
'segment_mean_coo',
'segment_min_coo',
'segment_max_coo',
'segment_coo',
'gather_coo',
'scatter_std',
'scatter_logsumexp',
'scatter_softmax',
'scatter_log_softmax',
'torch_scatter',
'__version__',
]
from torch_scatter.utils.gen import gen
def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/add.svg?sanitize=true
:align: center
:width: 400px
|
Sums all values from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`src` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`. If
multiple indices reference the same location, their **contributions add**.
Formally, if :attr:`src` and :attr:`index` are n-dimensional tensors with
size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
:attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions
do not match.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_add
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out = scatter_add(src, index, out=out)
print(out)
.. testoutput::
tensor([[0., 0., 4., 3., 3., 0.],
[2., 4., 4., 0., 0., 0.]])
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
return out.scatter_add_(dim, index, src)
from .std import scatter_std
from .logsumexp import scatter_logsumexp
from .softmax import scatter_log_softmax, scatter_softmax
__all__ = [
'scatter_std',
'scatter_logsumexp',
'scatter_softmax',
'scatter_log_softmax',
]
from typing import Optional
import torch
from torch_scatter import scatter_sum, scatter_max
from .utils import broadcast
@torch.jit.script
def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
eps: float = 1e-12) -> torch.Tensor:
if not torch.is_floating_point(src):
raise ValueError('`scatter_logsumexp` can only be computed over '
'tensors with floating point data types.')
index = broadcast(index, src, dim)
if out is not None:
dim_size = out.size(dim)
else:
if dim_size is None:
dim_size = int(index.max().item() + 1)
size = src.size()
size[dim] = dim_size
max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
device=src.device)
scatter_max(src, index, dim, max_value_per_index, dim_size)[0]
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
if out is not None:
out = out.sub_(max_per_src_element).exp_()
sum_per_index = scatter_sum(recentered_scores.exp_(), index, dim, out,
dim_size)
return sum_per_index.add_(eps).log_().add_(max_value_per_index)
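# Usage sketch (illustrative values): each output entry should match
# torch.logsumexp over the elements sharing its index, up to the `eps` term.
import torch
from torch_scatter import scatter_logsumexp
src = torch.tensor([0.5, -1.0, 2.0, 0.0])
index = torch.tensor([0, 0, 1, 1])
out = scatter_logsumexp(src, index)
expected = torch.stack([src[:2].logsumexp(dim=0), src[2:].logsumexp(dim=0)])
assert torch.allclose(out, expected, atol=1e-6)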
import torch
from torch_scatter import scatter_add, scatter_max
from torch_scatter.utils.gen import broadcast
from torch_scatter import scatter_sum, scatter_max
from .utils import broadcast
def scatter_softmax(src, index, dim=-1, eps=1e-12):
r"""
Softmax operation over all values in :attr:`src` tensor that share indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
\frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
@torch.jit.script
def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
eps: float = 1e-12) -> torch.Tensor:
if not torch.is_floating_point(src):
raise ValueError('`scatter_softmax` can only be computed over tensors '
'with floating point data types.')
src, index = broadcast(src, index, dim)
max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
index = broadcast(index, src, dim)
max_value_per_index = scatter_max(src, index, dim=dim)[0]
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
recentered_scores_exp = recentered_scores.exp()
sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
normalizing_constants = (sum_per_index + eps).gather(dim, index)
recentered_scores_exp = recentered_scores.exp_()
return recentered_scores_exp / normalizing_constants
sum_per_index = scatter_sum(recentered_scores_exp, index, dim)
normalizing_constants = sum_per_index.add_(eps).gather(dim, index)
return recentered_scores_exp.div_(normalizing_constants)
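# Sanity sketch (illustrative values): entries sharing an index are normalized
# together, so the result matches a per-group torch.softmax.
import torch
from torch_scatter import scatter_softmax
src = torch.tensor([0.2, 1.0, -0.5, 0.3])
index = torch.tensor([0, 0, 1, 1])
out = scatter_softmax(src, index, dim=-1)
assert torch.allclose(out[:2], torch.softmax(src[:2], dim=0), atol=1e-6)
assert torch.allclose(out[2:], torch.softmax(src[2:], dim=0), atol=1e-6)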
def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
r"""
Log-softmax operation over all values in :attr:`src` tensor that share
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = {\textrm{log_softmax}(\mathrm{src})}_i =
\log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
\right)
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
@torch.jit.script
def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
eps: float = 1e-12) -> torch.Tensor:
if not torch.is_floating_point(src):
raise ValueError('`scatter_log_softmax` can only be computed over '
'tensors with floating point data types.')
src, index = broadcast(src, index, dim)
max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
index = broadcast(index, src, dim)
max_value_per_index = scatter_max(src, index, dim=dim)[0]
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
dim=dim)
normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)
sum_per_index = scatter_sum(recentered_scores.exp(), index, dim)
normalizing_constants = sum_per_index.add_(eps).log_().gather(dim, index)
return recentered_scores - normalizing_constants
return recentered_scores.sub_(normalizing_constants)
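# Sanity sketch (illustrative values): the grouped log-softmax should agree
# with torch.log_softmax applied to each index group separately.
import torch
from torch_scatter import scatter_log_softmax
src = torch.tensor([0.2, 1.0, -0.5, 0.3])
index = torch.tensor([0, 0, 1, 1])
out = scatter_log_softmax(src, index, dim=-1)
assert torch.allclose(out[:2], torch.log_softmax(src[:2], dim=0), atol=1e-6)
assert torch.allclose(out[2:], torch.log_softmax(src[2:], dim=0), atol=1e-6)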
from typing import Optional
import torch
from torch_scatter import scatter_sum
from .utils import broadcast
@torch.jit.script
def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
unbiased: bool = True) -> torch.Tensor:
if out is not None:
dim_size = out.size(dim)
if dim < 0:
dim = src.dim() + dim
count_dim = dim
if index.dim() <= dim:
count_dim = index.dim() - 1
ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
index = broadcast(index, src, dim)
tmp = scatter_sum(src, index, dim, dim_size=dim_size)
count = broadcast(count, tmp, dim).clamp_(1)
mean = tmp.div_(count)
var = (src - mean.gather(dim, index))
var = var * var
out = scatter_sum(var, index, dim, out, dim_size)
if unbiased:
count.sub_(1).clamp_(1)
out.div_(count).sqrt_()
return out
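# Consistency sketch (illustrative values): each output entry is the standard
# deviation of the source values that share the corresponding index.
import torch
from torch_scatter import scatter_std
src = torch.tensor([2., 0., 1., 4., 3., 7.])
index = torch.tensor([0, 0, 0, 1, 1, 1])
out = scatter_std(src, index, dim=-1, unbiased=True)
assert torch.allclose(out[0], src[:3].std(unbiased=True))
assert torch.allclose(out[1], src[3:].std(unbiased=True))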
import torch
@torch.jit.script
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
if dim < 0:
dim = other.dim() + dim
if src.dim() == 1:
for _ in range(0, dim):
src = src.unsqueeze(0)
for _ in range(src.dim(), other.dim()):
src = src.unsqueeze(-1)
src = src.expand_as(other)
return src
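# Illustration (hypothetical shapes): for a 1-D index and a 2-D `other`,
# broadcast(index, other, dim=1) unsqueezes and expands, equivalent to the
# manual expansion below.
import torch
other = torch.zeros(2, 3)
index = torch.tensor([0, 0, 1])
expanded = index.unsqueeze(0).expand_as(other)  # same result as broadcast(index, other, 1)
print(expanded)  # tensor([[0, 0, 1], [0, 0, 1]])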
import torch
from torch_scatter.utils.gen import gen
class ScatterDiv(torch.autograd.Function):
@staticmethod
def forward(ctx, out, src, index, dim):
if src.is_cuda:
torch.ops.torch_scatter_cuda.scatter_div(src, index, out, dim)
else:
torch.ops.torch_scatter_cpu.scatter_div(src, index, out, dim)
ctx.mark_dirty(out)
ctx.save_for_backward(out, src, index)
ctx.dim = dim
return out
@staticmethod
def backward(ctx, grad_out):
out, src, index = ctx.saved_tensors
grad_src = None
if ctx.needs_input_grad[1]:
grad_src = -(out * grad_out).gather(ctx.dim, index) / src
return None, grad_src, None, None
def scatter_div(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/div.svg?sanitize=true
:align: center
:width: 400px
|
Divides all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions divide** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \mathrm{out}_i \cdot \prod_j
\frac{1}{\mathrm{src}_j}
where :math:`\prod_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`1`)
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_div
src = torch.Tensor([[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]]).float()
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_ones((2, 6))
out = scatter_div(src, index, out=out)
print(out)
.. testoutput::
tensor([[1.0000, 1.0000, 0.2500, 0.5000, 0.5000, 1.0000],
[0.5000, 0.2500, 0.5000, 1.0000, 1.0000, 1.0000]])
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
return ScatterDiv.apply(out, src, index, dim)
import torch
class GatherCOO(torch.autograd.Function):
@staticmethod
def forward(ctx, src, index, out):
if out is not None:
ctx.mark_dirty(out)
ctx.src_size = list(src.size())
ctx.save_for_backward(index)
if src.is_cuda:
return torch.ops.torch_scatter_cuda.gather_coo(src, index, out)
else:
return torch.ops.torch_scatter_cpu.gather_coo(src, index, out)
@staticmethod
def backward(ctx, grad_out):
(index, ), src_size = ctx.saved_tensors, ctx.src_size
grad_src = None
if ctx.needs_input_grad[0]:
if grad_out.is_cuda:
grad_src, _ = torch.ops.torch_scatter_cuda.segment_coo(
grad_out, index, grad_out.new_zeros(src_size), 'sum')
else:
grad_src, _ = torch.ops.torch_scatter_cpu.segment_coo(
grad_out, index, grad_out.new_zeros(src_size), 'sum')
return grad_src, None, None
class GatherCSR(torch.autograd.Function):
@staticmethod
def forward(ctx, src, indptr, out):
if out is not None:
ctx.mark_dirty(out)
ctx.src_size = list(src.size())
ctx.save_for_backward(indptr)
if src.is_cuda:
return torch.ops.torch_scatter_cuda.gather_csr(src, indptr, out)
else:
return torch.ops.torch_scatter_cpu.gather_csr(src, indptr, out)
@staticmethod
def backward(ctx, grad_out):
(indptr, ), src_size = ctx.saved_tensors, ctx.src_size
grad_src = None
if ctx.needs_input_grad[0]:
if grad_out.is_cuda:
grad_src, _ = torch.ops.torch_scatter_cuda.segment_csr(
grad_out, indptr, grad_out.new_empty(src_size), 'sum')
else:
grad_src, _ = torch.ops.torch_scatter_cpu.segment_csr(
grad_out, indptr, grad_out.new_empty(src_size), 'sum')
return grad_src, None, None
def gather_coo(src, index, out=None):
return GatherCOO.apply(src, index, out)
def gather_csr(src, indptr, out=None):
return GatherCSR.apply(src, indptr, out)
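# Usage sketch (assumes the compiled extension ops are loadable): gather_csr
# is roughly the transpose of segment_csr, copying each segment value back to
# every position inside that segment.
import torch
from torch_scatter import gather_csr, segment_sum_csr
src = torch.tensor([1., 2., 3., 4., 5., 6.])
indptr = torch.tensor([0, 2, 5, 5, 6])
seg = segment_sum_csr(src, indptr)  # [3., 12., 0., 6.]
back = gather_csr(seg, indptr)      # each segment value repeated over its span
assert back.tolist() == [3., 3., 12., 12., 12., 6.]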
import torch
def min_value(dtype): # pragma: no cover
try:
return torch.finfo(dtype).min
except TypeError:
return torch.iinfo(dtype).min
def max_value(dtype): # pragma: no cover
try:
return torch.finfo(dtype).max
except TypeError:
return torch.iinfo(dtype).max
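# For reference: these helpers return the dtype-dependent sentinel used to
# pre-fill reduction outputs, e.g. min_value(torch.float) yields
# torch.finfo(torch.float).min, while min_value(torch.long) falls back to
# torch.iinfo(torch.long).min because torch.finfo raises a TypeError for
# integer dtypes.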
import torch
from . import scatter_add, scatter_max
def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
fill_value=None, eps=1e-12):
r"""Fills :attr:`out` with the log of summed exponentials of all values
from the :attr:`src` tensor at the indices specified in the :attr:`index`
tensor along a given axis :attr:`dim`.
If multiple indices reference the same location, their
**exponential contributions add**
(`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j
\exp(\mathrm{src}_j) \right)
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
eps (float, optional): Small value to ensure numerical stability.
(default: :obj:`1e-12`)
:rtype: :class:`Tensor`
"""
if not torch.is_floating_point(src):
raise ValueError('`scatter_logsumexp` can only be computed over '
'tensors with floating point data types.')
max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size,
fill_value)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_scores = src - max_per_src_element
out = (out - max_per_src_element).exp() if out is not None else None
sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out,
dim_size, fill_value=0)
return torch.log(sum_per_index + eps) + max_value_per_index
import torch
from torch_scatter.utils.gen import gen
class ScatterMax(torch.autograd.Function):
@staticmethod
def forward(ctx, out, src, index, dim):
arg = index.new_full(out.size(), -1)
if src.is_cuda:
torch.ops.torch_scatter_cuda.scatter_max(src, index, out, arg, dim)
else:
torch.ops.torch_scatter_cpu.scatter_max(src, index, out, arg, dim)
ctx.mark_dirty(out)
ctx.dim = dim
ctx.save_for_backward(index, arg)
return out, arg
@staticmethod
def backward(ctx, grad_out, grad_arg):
index, arg = ctx.saved_tensors
grad_src = None
if ctx.needs_input_grad[1]:
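# Gradients flow only to the argmax positions: pad the src dimension by one
# slot, scatter grad_out at `arg + 1` (entries with arg == -1, i.e. output
# positions that no src element mapped to, land in the padding slot), then
# crop the padding away.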
size = list(index.size())
size[ctx.dim] += 1
grad_src = grad_out.new_zeros(size)
grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))
return None, grad_src, None, None
def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/max.svg?sanitize=true
:align: center
:width: 400px
|
Maximizes all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions maximize** (`cf.` :meth:`~torch_scatter.scatter_add`).
The second return tensor contains the index location in :attr:`src` of each
maximum value (known as argmax).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \max(\mathrm{out}_i, \max_j(\mathrm{src}_j))
where :math:`\max_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
:rtype: (:class:`Tensor`, :class:`LongTensor`)
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_max
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out, argmax = scatter_max(src, index, out=out)
print(out)
print(argmax)
.. testoutput::
tensor([[0., 0., 4., 3., 2., 0.],
[2., 4., 3., 0., 0., 0.]])
tensor([[-1, -1, 3, 4, 0, 1],
[ 1, 4, 3, -1, -1, -1]])
"""
if fill_value is None:
op = torch.finfo if torch.is_floating_point(src) else torch.iinfo
fill_value = op(src.dtype).min
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out, index.new_full(out.size(), -1)
return ScatterMax.apply(out, src, index, dim)
import torch
from torch_scatter import scatter_add
def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/mean.svg?sanitize=true
:align: center
:width: 400px
|
Averages all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \mathrm{out}_i + \frac{1}{N_i} \cdot
\sum_j \mathrm{src}_j
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`. :math:`N_i` indicates the number of indices
referencing :math:`i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_mean
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out = scatter_mean(src, index, out=out)
print(out)
.. testoutput::
tensor([[0.0000, 0.0000, 4.0000, 3.0000, 1.5000, 0.0000],
[1.0000, 4.0000, 2.0000, 0.0000, 0.0000, 0.0000]])
"""
out = scatter_add(src, index, dim, out, dim_size, fill_value)
count = scatter_add(torch.ones_like(src), index, dim, None, out.size(dim))
return out / count.clamp(min=1)
import torch
from torch_scatter.utils.gen import gen
class ScatterMin(torch.autograd.Function):
@staticmethod
def forward(ctx, out, src, index, dim):
arg = index.new_full(out.size(), -1)
if src.is_cuda:
torch.ops.torch_scatter_cuda.scatter_min(src, index, out, arg, dim)
else:
torch.ops.torch_scatter_cpu.scatter_min(src, index, out, arg, dim)
ctx.mark_dirty(out)
ctx.dim = dim
ctx.save_for_backward(index, arg)
return out, arg
@staticmethod
def backward(ctx, grad_out, grad_arg):
index, arg = ctx.saved_tensors
grad_src = None
if ctx.needs_input_grad[1]:
size = list(index.size())
size[ctx.dim] += 1
grad_src = grad_out.new_zeros(size)
grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))
return None, grad_src, None, None
def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/min.svg?sanitize=true
:align: center
:width: 400px
|
Minimizes all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`. If multiple indices reference the same location, their
**contributions minimize** (`cf.` :meth:`~torch_scatter.scatter_add`).
The second return tensor contains the index location in :attr:`src` of each
minimum value (known as argmin).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \min(\mathrm{out}_i, \min_j(\mathrm{src}_j))
where :math:`\min_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the greatest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
:rtype: (:class:`Tensor`, :class:`LongTensor`)
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_min
src = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
index = torch.tensor([[ 4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out, argmin = scatter_min(src, index, out=out)
print(out)
print(argmin)
.. testoutput::
tensor([[ 0., 0., -4., -3., -2., 0.],
[-2., -4., -3., 0., 0., 0.]])
tensor([[-1, -1, 3, 4, 0, 1],
[ 1, 4, 3, -1, -1, -1]])
"""
if fill_value is None:
op = torch.finfo if torch.is_floating_point(src) else torch.iinfo
fill_value = op(src.dtype).max
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out, index.new_full(out.size(), -1)
return ScatterMin.apply(out, src, index, dim)