Commit 7fd9091c authored by rusty1s

update code and tests

parent 5be6d63a
import torch


@torch.jit.script
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    # Resolve negative dimensions.
    if dim < 0:
        dim = other.dim() + dim
    # Prepend singleton dimensions until `src` lines up with `dim` ...
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    # ... and append singleton dimensions until the ranks match.
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand_as(other)
    return src
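
A quick illustration of the helper above — a minimal sketch in plain PyTorch with made-up shapes, assuming the `broadcast` helper just defined is in scope:

import torch

src = torch.tensor([1.0, 2.0, 3.0])  # shape [3]
other = torch.randn(4, 3, 5)         # target shape [4, 3, 5]

# `src` is unsqueezed to [1, 3], then [1, 3, 1], and expanded to [4, 3, 5].
out = broadcast(src, other, dim=1)
print(out.size())  # torch.Size([4, 3, 5])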
import torch

from torch_scatter.utils.gen import gen


class ScatterDiv(torch.autograd.Function):
    @staticmethod
    def forward(ctx, out, src, index, dim):
        if src.is_cuda:
            torch.ops.torch_scatter_cuda.scatter_div(src, index, out, dim)
        else:
            torch.ops.torch_scatter_cpu.scatter_div(src, index, out, dim)
        ctx.mark_dirty(out)
        ctx.save_for_backward(out, src, index)
        ctx.dim = dim
        return out

    @staticmethod
    def backward(ctx, grad_out):
        out, src, index = ctx.saved_tensors

        grad_src = None
        if ctx.needs_input_grad[1]:
            # d out_i / d src_j = -out_i / src_j for all j with index_j = i.
            grad_src = -(out * grad_out).gather(ctx.dim, index) / src

        return None, grad_src, None, None
def scatter_div(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/div.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Divides all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. If multiple indices reference the same location, their
    **contributions divide** (`cf.` :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i \cdot \prod_j
        \frac{1}{\mathrm{src}_j}

    where :math:`\prod_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. (default: :obj:`1`)

    :rtype: :class:`Tensor`

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_div

        src = torch.Tensor([[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]]).float()
        index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out = src.new_ones((2, 6))

        out = scatter_div(src, index, out=out)

        print(out)

    .. testoutput::

       tensor([[1.0000, 1.0000, 0.2500, 0.5000, 0.5000, 1.0000],
               [0.5000, 0.2500, 0.5000, 1.0000, 1.0000, 1.0000]])
    """
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
    if src.size(dim) == 0:  # pragma: no cover
        return out
    return ScatterDiv.apply(out, src, index, dim)
import torch


class GatherCOO(torch.autograd.Function):
    @staticmethod
    def forward(ctx, src, index, out):
        if out is not None:
            ctx.mark_dirty(out)
        ctx.src_size = list(src.size())
        ctx.save_for_backward(index)

        if src.is_cuda:
            return torch.ops.torch_scatter_cuda.gather_coo(src, index, out)
        else:
            return torch.ops.torch_scatter_cpu.gather_coo(src, index, out)

    @staticmethod
    def backward(ctx, grad_out):
        (index, ), src_size = ctx.saved_tensors, ctx.src_size

        grad_src = None
        if ctx.needs_input_grad[0]:
            # The gradient of a COO gather is the corresponding segment sum.
            if grad_out.is_cuda:
                grad_src, _ = torch.ops.torch_scatter_cuda.segment_coo(
                    grad_out, index, grad_out.new_zeros(src_size), 'sum')
            else:
                grad_src, _ = torch.ops.torch_scatter_cpu.segment_coo(
                    grad_out, index, grad_out.new_zeros(src_size), 'sum')

        return grad_src, None, None


class GatherCSR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, src, indptr, out):
        if out is not None:
            ctx.mark_dirty(out)
        ctx.src_size = list(src.size())
        ctx.save_for_backward(indptr)

        if src.is_cuda:
            return torch.ops.torch_scatter_cuda.gather_csr(src, indptr, out)
        else:
            return torch.ops.torch_scatter_cpu.gather_csr(src, indptr, out)

    @staticmethod
    def backward(ctx, grad_out):
        (indptr, ), src_size = ctx.saved_tensors, ctx.src_size

        grad_src = None
        if ctx.needs_input_grad[0]:
            # The gradient of a CSR gather is the corresponding segment sum.
            if grad_out.is_cuda:
                grad_src, _ = torch.ops.torch_scatter_cuda.segment_csr(
                    grad_out, indptr, grad_out.new_empty(src_size), 'sum')
            else:
                grad_src, _ = torch.ops.torch_scatter_cpu.segment_csr(
                    grad_out, indptr, grad_out.new_empty(src_size), 'sum')

        return grad_src, None, None


def gather_coo(src, index, out=None):
    return GatherCOO.apply(src, index, out)


def gather_csr(src, indptr, out=None):
    return GatherCSR.apply(src, indptr, out)
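
The gather functions above ship without docstrings. As a rough orientation, `gather_coo` replicates each reduced value back to all positions of its (sorted) segment, acting as the converse of `segment_coo`. A hedged sketch of the one-dimensional semantics in plain PyTorch (values are illustrative; the real op runs through the compiled extensions):

import torch

src = torch.tensor([10., 20., 30.])       # one value per segment
index = torch.tensor([0, 0, 1, 1, 1, 2])  # sorted segment ids

# For the 1-D case, gather_coo(src, index) is equivalent to src[index].
out = src[index]
print(out)  # tensor([10., 10., 20., 20., 20., 30.])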
import torch


def min_value(dtype):  # pragma: no cover
    try:
        return torch.finfo(dtype).min
    except TypeError:
        return torch.iinfo(dtype).min


def max_value(dtype):  # pragma: no cover
    try:
        return torch.finfo(dtype).max
    except TypeError:
        return torch.iinfo(dtype).max
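
A minimal usage sketch for the two helpers: floating-point dtypes resolve through `torch.finfo`, while the `TypeError` fallback handles integral dtypes via `torch.iinfo`:

import torch

print(max_value(torch.float32))  # 3.4028234663852886e+38
print(min_value(torch.int32))    # -2147483648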
import torch

from . import scatter_add, scatter_max


def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
                      fill_value=None, eps=1e-12):
    r"""Fills :attr:`out` with the log of summed exponentials of all values
    from the :attr:`src` tensor at the indices specified in the :attr:`index`
    tensor along a given axis :attr:`dim`.
    If multiple indices reference the same location, their
    **exponential contributions add**
    (`cf.` :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j
        \exp(\mathrm{src}_j) \right)

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    # Shift all values by their per-index maximum for numerical stability.
    max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size,
                                         fill_value)
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = src - max_per_src_element

    out = (out - max_per_src_element).exp() if out is not None else None
    sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out,
                                dim_size, fill_value=0)
    return torch.log(sum_per_index + eps) + max_value_per_index
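
Since this composite carries no doctest, here is a hedged usage sketch (assuming the package is importable) that compares each group against `torch.logsumexp`:

import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, -1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 1, 1])

out = scatter_logsumexp(src, index)
# Up to the stabilizing `eps`, each entry matches a per-group logsumexp:
print(out)
print(torch.logsumexp(src[:2], dim=0), torch.logsumexp(src[2:], dim=0))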
import torch

from torch_scatter.utils.gen import gen


class ScatterMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, out, src, index, dim):
        # `arg` tracks the position in `src` of each maximum (-1 if unset).
        arg = index.new_full(out.size(), -1)
        if src.is_cuda:
            torch.ops.torch_scatter_cuda.scatter_max(src, index, out, arg, dim)
        else:
            torch.ops.torch_scatter_cpu.scatter_max(src, index, out, arg, dim)
        ctx.mark_dirty(out)
        ctx.dim = dim
        ctx.save_for_backward(index, arg)
        return out, arg

    @staticmethod
    def backward(ctx, grad_out, grad_arg):
        index, arg = ctx.saved_tensors

        grad_src = None
        if ctx.needs_input_grad[1]:
            # Route gradients to the argmax positions. Unset entries
            # (arg == -1) land in an extra leading slot, which is cut off
            # again by the final `narrow`.
            size = list(index.size())
            size[ctx.dim] += 1
            grad_src = grad_out.new_zeros(size)
            grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
            grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))

        return None, grad_src, None, None
def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/max.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Maximizes all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. If multiple indices reference the same location, their
    **contributions maximize** (`cf.` :meth:`~torch_scatter.scatter_add`).
    The second return tensor contains the index location in :attr:`src` of
    each maximum value (known as *argmax*).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \max(\mathrm{out}_i, \max_j(\mathrm{src}_j))

    where :math:`\max_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
            the output tensor is filled with the smallest possible value of
            :obj:`src.dtype`. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`LongTensor`)

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_max

        src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out = src.new_zeros((2, 6))

        out, argmax = scatter_max(src, index, out=out)

        print(out)
        print(argmax)

    .. testoutput::

       tensor([[0., 0., 4., 3., 2., 0.],
               [2., 4., 3., 0., 0., 0.]])
       tensor([[-1, -1,  3,  4,  0,  1],
               [ 1,  4,  3, -1, -1, -1]])
    """
    if fill_value is None:
        op = torch.finfo if torch.is_floating_point(src) else torch.iinfo
        fill_value = op(src.dtype).min
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
    if src.size(dim) == 0:  # pragma: no cover
        return out, index.new_full(out.size(), -1)
    return ScatterMax.apply(out, src, index, dim)
import torch

from torch_scatter import scatter_add


def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/mean.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Averages all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. If multiple indices reference the same location, their
    **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \frac{1}{N_i} \cdot
        \sum_j \mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`. :math:`N_i` indicates the number of indices
    referencing :math:`i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. (default: :obj:`0`)

    :rtype: :class:`Tensor`

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_mean

        src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out = src.new_zeros((2, 6))

        out = scatter_mean(src, index, out=out)

        print(out)

    .. testoutput::

       tensor([[0.0000, 0.0000, 4.0000, 3.0000, 1.5000, 0.0000],
               [1.0000, 4.0000, 2.0000, 0.0000, 0.0000, 0.0000]])
    """
    out = scatter_add(src, index, dim, out, dim_size, fill_value)
    count = scatter_add(torch.ones_like(src), index, dim, None, out.size(dim))
    # Clamping avoids division by zero for entries without contributions.
    return out / count.clamp(min=1)
import torch

from torch_scatter.utils.gen import gen


class ScatterMin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, out, src, index, dim):
        # `arg` tracks the position in `src` of each minimum (-1 if unset).
        arg = index.new_full(out.size(), -1)
        if src.is_cuda:
            torch.ops.torch_scatter_cuda.scatter_min(src, index, out, arg, dim)
        else:
            torch.ops.torch_scatter_cpu.scatter_min(src, index, out, arg, dim)
        ctx.mark_dirty(out)
        ctx.dim = dim
        ctx.save_for_backward(index, arg)
        return out, arg

    @staticmethod
    def backward(ctx, grad_out, grad_arg):
        index, arg = ctx.saved_tensors

        grad_src = None
        if ctx.needs_input_grad[1]:
            # Route gradients to the argmin positions. Unset entries
            # (arg == -1) land in an extra leading slot, which is cut off
            # again by the final `narrow`.
            size = list(index.size())
            size[ctx.dim] += 1
            grad_src = grad_out.new_zeros(size)
            grad_src.scatter_(ctx.dim, arg.detach() + 1, grad_out)
            grad_src = grad_src.narrow(ctx.dim, 1, index.size(ctx.dim))

        return None, grad_src, None, None
def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/min.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Minimizes all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. If multiple indices reference the same location, their
    **contributions minimize** (`cf.` :meth:`~torch_scatter.scatter_add`).
    The second return tensor contains the index location in :attr:`src` of
    each minimum value (known as *argmin*).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \min(\mathrm{out}_i, \min_j(\mathrm{src}_j))

    where :math:`\min_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
            the output tensor is filled with the greatest possible value of
            :obj:`src.dtype`. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`LongTensor`)

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_min

        src = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
        index = torch.tensor([[ 4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out = src.new_zeros((2, 6))

        out, argmin = scatter_min(src, index, out=out)

        print(out)
        print(argmin)

    .. testoutput::

       tensor([[ 0.,  0., -4., -3., -2.,  0.],
               [-2., -4., -3.,  0.,  0.,  0.]])
       tensor([[-1, -1,  3,  4,  0,  1],
               [ 1,  4,  3, -1, -1, -1]])
    """
    if fill_value is None:
        op = torch.finfo if torch.is_floating_point(src) else torch.iinfo
        fill_value = op(src.dtype).max
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
    if src.size(dim) == 0:  # pragma: no cover
        return out, index.new_full(out.size(), -1)
    return ScatterMin.apply(out, src, index, dim)
import torch

from torch_scatter.utils.gen import gen


class ScatterMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, out, src, index, dim):
        if src.is_cuda:
            torch.ops.torch_scatter_cuda.scatter_mul(src, index, out, dim)
        else:
            torch.ops.torch_scatter_cpu.scatter_mul(src, index, out, dim)
        ctx.mark_dirty(out)
        ctx.save_for_backward(out, src, index)
        ctx.dim = dim
        return out

    @staticmethod
    def backward(ctx, grad_out):
        out, src, index = ctx.saved_tensors

        grad_src = None
        if ctx.needs_input_grad[1]:
            # d out_i / d src_j = out_i / src_j for all j with index_j = i.
            grad_src = (grad_out * out).gather(ctx.dim, index) / src

        return None, grad_src, None, None


def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/mul.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Multiplies all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. If multiple indices reference the same location, their
    **contributions multiply** (`cf.` :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i \cdot \prod_j \mathrm{src}_j

    where :math:`\prod_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. (default: :obj:`1`)

    :rtype: :class:`Tensor`

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_mul

        src = torch.Tensor([[2, 0, 3, 4, 3], [2, 3, 4, 2, 4]])
        index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out = src.new_ones((2, 6))

        out = scatter_mul(src, index, out=out)

        print(out)

    .. testoutput::

       tensor([[1., 1., 4., 3., 6., 0.],
               [6., 4., 8., 1., 1., 1.]])
    """
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
    if src.size(dim) == 0:  # pragma: no cover
        return out
    return ScatterMul.apply(out, src, index, dim)
@@ -48,6 +48,75 @@ def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/add.svg?sanitize=true
:align: center
:width: 400px
|
Sums all values from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`src` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`. If
multiple indices reference the same location, their **contributions add**.
Formally, if :attr:`src` and :attr:`index` are n-dimensional tensors with
size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` and
:attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions
do not match.
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \mathrm{out}_i + \sum_j \mathrm{src}_j
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_add
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out = scatter_add(src, index, out=out)
print(out)
.. testoutput::
tensor([[0., 0., 4., 3., 3., 0.],
[2., 4., 4., 0., 0., 0.]])
"""
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    elif reduce == 'mean':
    ...
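
The tail of this hunk is truncated above. As a hedged sketch of intended usage (not part of the diff itself), the new unified entry point selects the reduction via the `reduce` string instead of per-operation functions:

import torch
from torch_scatter import scatter

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Dispatches to scatter_sum / scatter_mean / ... based on `reduce`.
out = scatter(src, index, dim=1, reduce="mean")
print(out.size())  # torch.Size([10, 3, 64])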
import torch

from torch_scatter.helpers import min_value, max_value


class SegmentCOO(torch.autograd.Function):
    @staticmethod
    def forward(ctx, src, index, out, dim_size, reduce):
        assert reduce in ['sum', 'add', 'mean', 'min', 'max']

        if out is not None:
            ctx.mark_dirty(out)
        ctx.reduce = reduce
        ctx.src_size = list(src.size())

        fill_value = 0
        if out is None:
            dim_size = index.max().item() + 1 if dim_size is None else dim_size
            size = list(src.size())
            size[index.dim() - 1] = dim_size
            # 'min'/'max' start from the identity element of the reduction.
            if reduce == 'min':
                fill_value = max_value(src.dtype)
            elif reduce == 'max':
                fill_value = min_value(src.dtype)
            out = src.new_full(size, fill_value)

        if src.is_cuda:
            out, arg_out = torch.ops.torch_scatter_cuda.segment_coo(
                src, index, out, reduce)
        else:
            out, arg_out = torch.ops.torch_scatter_cpu.segment_coo(
                src, index, out, reduce)

        # Entries without any contribution still hold the sentinel fill
        # value and get reset to zero.
        if fill_value != 0:
            out.masked_fill_(out == fill_value, 0)

        ctx.save_for_backward(index, arg_out)

        if reduce == 'min' or reduce == 'max':
            return out, arg_out
        else:
            return out

    @staticmethod
    def backward(ctx, grad_out, *args):
        (index, arg_out), src_size = ctx.saved_tensors, ctx.src_size

        grad_src = None
        if ctx.needs_input_grad[0]:
            if ctx.reduce == 'sum' or ctx.reduce == 'add':
                if grad_out.is_cuda:
                    grad_src = torch.ops.torch_scatter_cuda.gather_coo(
                        grad_out, index, grad_out.new_empty(src_size))
                else:
                    grad_src = torch.ops.torch_scatter_cpu.gather_coo(
                        grad_out, index, grad_out.new_empty(src_size))
            elif ctx.reduce == 'mean':
                if grad_out.is_cuda:
                    grad_src = torch.ops.torch_scatter_cuda.gather_coo(
                        grad_out, index, grad_out.new_empty(src_size))
                else:
                    grad_src = torch.ops.torch_scatter_cpu.gather_coo(
                        grad_out, index, grad_out.new_empty(src_size))
                count = arg_out  # Gets pre-computed on GPU but not on CPU.
                if count is None:
                    size = list(index.size())
                    size[-1] = grad_out.size(index.dim() - 1)
                    count = torch.ops.torch_scatter_cpu.segment_coo(
                        torch.ones_like(index, dtype=grad_out.dtype), index,
                        grad_out.new_zeros(size), 'sum')[0].clamp_(min=1)
                if grad_out.is_cuda:
                    count = torch.ops.torch_scatter_cuda.gather_coo(
                        count, index, count.new_empty(src_size[:index.dim()]))
                else:
                    count = torch.ops.torch_scatter_cpu.gather_coo(
                        count, index, count.new_empty(src_size[:index.dim()]))
                for _ in range(grad_out.dim() - index.dim()):
                    count = count.unsqueeze(-1)
                grad_src.div_(count)
            elif ctx.reduce == 'min' or ctx.reduce == 'max':
                # Route gradients to the winning elements via `arg_out`;
                # the extra padded slot absorbs invalid arguments and is cut
                # off again by the final `narrow`.
                src_size[index.dim() - 1] += 1
                grad_src = grad_out.new_zeros(src_size).scatter_(
                    index.dim() - 1, arg_out, grad_out)
                grad_src = grad_src.narrow(index.dim() - 1, 0,
                                           src_size[index.dim() - 1] - 1)

        return grad_src, None, None, None, None
class SegmentCSR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, src, indptr, out, reduce):
        assert reduce in ['sum', 'add', 'mean', 'min', 'max']

        if out is not None:
            ctx.mark_dirty(out)
        ctx.reduce = reduce
        ctx.src_size = list(src.size())

        if src.is_cuda:
            out, arg_out = torch.ops.torch_scatter_cuda.segment_csr(
                src, indptr, out, reduce)
        else:
            out, arg_out = torch.ops.torch_scatter_cpu.segment_csr(
                src, indptr, out, reduce)

        ctx.save_for_backward(indptr, arg_out)

        return out if arg_out is None else (out, arg_out)

    @staticmethod
    def backward(ctx, grad_out, *args):
        (indptr, arg_out), src_size = ctx.saved_tensors, ctx.src_size

        grad_src = None
        if ctx.needs_input_grad[0]:
            if ctx.reduce == 'sum' or ctx.reduce == 'add':
                if grad_out.is_cuda:
                    grad_src = torch.ops.torch_scatter_cuda.gather_csr(
                        grad_out, indptr, grad_out.new_empty(src_size))
                else:
                    grad_src = torch.ops.torch_scatter_cpu.gather_csr(
                        grad_out, indptr, grad_out.new_empty(src_size))
            elif ctx.reduce == 'mean':
                if grad_out.is_cuda:
                    grad_src = torch.ops.torch_scatter_cuda.gather_csr(
                        grad_out, indptr, grad_out.new_empty(src_size))
                else:
                    grad_src = torch.ops.torch_scatter_cpu.gather_csr(
                        grad_out, indptr, grad_out.new_empty(src_size))
                indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
                indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
                count = (indptr2 - indptr1).to(grad_src.dtype)
                if grad_out.is_cuda:
                    count = torch.ops.torch_scatter_cuda.gather_csr(
                        count, indptr,
                        count.new_empty(src_size[:indptr.dim()]))
                else:
                    count = torch.ops.torch_scatter_cpu.gather_csr(
                        count, indptr,
                        count.new_empty(src_size[:indptr.dim()]))
                for _ in range(grad_out.dim() - indptr.dim()):
                    count = count.unsqueeze(-1)
                grad_src.div_(count)
            elif ctx.reduce == 'min' or ctx.reduce == 'max':
                src_size[indptr.dim() - 1] += 1
                grad_src = grad_out.new_zeros(src_size).scatter_(
                    indptr.dim() - 1, arg_out, grad_out)
                grad_src = grad_src.narrow(indptr.dim() - 1, 0,
                                           src_size[indptr.dim() - 1] - 1)

        return grad_src, None, None, None
def segment_coo(src, index, out=None, dim_size=None, reduce="sum"):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/segment_coo.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along the last dimension of
    :attr:`index`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the
    corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and
    :math:`m`-dimensional tensors with
    size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
    :math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be
    an :math:`n`-dimensional tensor with size
    :math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1` in ascending order.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    In contrast to :meth:`scatter`, this method expects values in
    :attr:`index` **to be sorted** along dimension :obj:`index.dim() - 1`.
    Due to the use of sorted indices, :meth:`segment_coo` is usually faster
    than the more general :meth:`scatter` operation.

    For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
    second tensor representing the :obj:`argmin` and :obj:`argmax`,
    respectively.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point variables, this results in a source of variance in
        the result.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The sorted indices of elements to segment.
            The number of dimensions of :attr:`index` needs to be less than
            or equal to that of :attr:`src`.
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension
            :obj:`index.dim() - 1`.
            If :attr:`dim_size` is not given, a minimal sized output tensor
            according to :obj:`index.max() + 1` is returned.
            (default: :obj:`None`)
        reduce (string, optional): The reduce operation (:obj:`"sum"`,
            :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
            (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*

    .. code-block:: python

        from torch_scatter import segment_coo

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 0, 1, 1, 1, 2])
        index = index.view(1, -1)  # Broadcasting in the first and last dim.

        out = segment_coo(src, index, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    return SegmentCOO.apply(src, index, out, dim_size, reduce)
def segment_csr(src, indptr, out=None, reduce="sum"):
    r"""
    Reduces all values from the :attr:`src` tensor into :attr:`out` within the
    ranges specified in the :attr:`indptr` tensor along the last dimension of
    :attr:`indptr`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by
    the corresponding range index in :attr:`indptr` for dimension
    :obj:`indptr.dim() - 1`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
    :math:`m`-dimensional tensors with
    size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
    :math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an
    :math:`n`-dimensional tensor with size
    :math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`indptr` must be between :math:`0` and
    :math:`x_m` in ascending order.
    The :attr:`indptr` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i =
        \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.

    Due to the use of index pointers, :meth:`segment_csr` is the fastest
    method to apply for grouped reductions.

    For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
    second tensor representing the :obj:`argmin` and :obj:`argmax`,
    respectively.

    .. note::

        In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
        operation is **fully-deterministic**.

    Args:
        src (Tensor): The source tensor.
        indptr (LongTensor): The index pointers between elements to segment.
            The number of dimensions of :attr:`indptr` needs to be less than
            or equal to that of :attr:`src`.
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        reduce (string, optional): The reduce operation (:obj:`"sum"`,
            :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
            (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*

    .. code-block:: python

        from torch_scatter import segment_csr

        src = torch.randn(10, 6, 64)
        indptr = torch.tensor([0, 2, 5, 6])
        indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

        out = segment_csr(src, indptr, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    return SegmentCSR.apply(src, indptr, out, reduce)
@@ -49,6 +49,94 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor,
def segment_coo(src: torch.Tensor, index: torch.Tensor,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,
                reduce: str = "sum") -> torch.Tensor:
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/segment_coo.svg?sanitize=true
:align: center
:width: 400px
|
Reduces all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along the last dimension of
:attr:`index`.
For each value in :attr:`src`, its output index is specified by its index
in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the
corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`.
The applied reduction is defined via the :attr:`reduce` argument.
Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and
:math:`m`-dimensional tensors with
size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
:math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an
:math:`n`-dimensional tensor with size
:math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`.
Moreover, the values of :attr:`index` must be between :math:`0` and
:math:`y - 1` in ascending order.
The :attr:`index` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes
.. math::
\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
In contrast to :meth:`scatter`, this method expects values in :attr:`index`
**to be sorted** along dimension :obj:`index.dim() - 1`.
Due to the use of sorted indices, :meth:`segment_coo` is usually faster
than the more general :meth:`scatter` operation.
For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
second tensor representing the :obj:`argmin` and :obj:`argmax`,
respectively.
.. note::
This operation is implemented via atomic operations on the GPU and is
therefore **non-deterministic** since the order of parallel operations
to the same value is undetermined.
For floating-point variables, this results in a source of variance in
the result.
Args:
src (Tensor): The source tensor.
index (LongTensor): The sorted indices of elements to segment.
The number of dimensions of :attr:`index` needs to be less than or
equal to :attr:`src`.
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension
:obj:`index.dim() - 1`.
If :attr:`dim_size` is not given, a minimal sized output tensor
according to :obj:`index.max() + 1` is returned.
(default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"sum"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"sum"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
.. code-block:: python
from torch_scatter import segment_coo
src = torch.randn(10, 6, 64)
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1) # Broadcasting in the first and last dim.
out = segment_coo(src, index, reduce="sum")
print(out.size())
.. code-block::
torch.Size([10, 3, 64])
"""
    if reduce == 'sum' or reduce == 'add':
        return segment_sum_coo(src, index, out, dim_size)
    elif reduce == 'mean':
    ...
@@ -43,6 +43,73 @@ def segment_max_csr(src: torch.Tensor, indptr: torch.Tensor,
def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
                out: Optional[torch.Tensor] = None,
                reduce: str = "sum") -> torch.Tensor:
r"""
Reduces all values from the :attr:`src` tensor into :attr:`out` within the
ranges specified in the :attr:`indptr` tensor along the last dimension of
:attr:`indptr`.
For each value in :attr:`src`, its output index is specified by its index
in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the
corresponding range index in :attr:`indptr` for dimension
:obj:`indptr.dim() - 1`.
The applied reduction is defined via the :attr:`reduce` argument.
Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
:math:`m`-dimensional tensors with
size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
:math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an
:math:`n`-dimensional tensor with size
:math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`.
Moreover, the values of :attr:`indptr` must be between :math:`0` and
:math:`x_m` in ascending order.
The :attr:`indptr` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes
.. math::
\mathrm{out}_i =
\sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+i]}~\mathrm{src}_j.
Due to the use of index pointers, :meth:`segment_csr` is the fastest
method to apply for grouped reductions.
For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
second tensor representing the :obj:`argmin` and :obj:`argmax`,
respectively.
.. note::
In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
operation is **fully-deterministic**.
Args:
src (Tensor): The source tensor.
indptr (LongTensor): The index pointers between elements to segment.
The number of dimensions of :attr:`index` needs to be less than or
equal to :attr:`src`.
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"sum"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"sum"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
.. code-block:: python
from torch_scatter import segment_csr
src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1) # Broadcasting in the first and last dim.
out = segment_csr(src, indptr, reduce="sum")
print(out.size())
.. code-block::
torch.Size([10, 3, 64])
"""
    if reduce == 'sum' or reduce == 'add':
        return segment_sum_csr(src, indptr, out)
    elif reduce == 'mean':
    ...
import torch

from torch_scatter import scatter_add
from torch_scatter.utils.gen import gen


def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/std.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Computes the standard deviation of all values from the :attr:`src` tensor
    into :attr:`out` at the indices specified in the :attr:`index` tensor
    along a given axis :attr:`dim` (`cf.` :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \sqrt{\frac{\sum_j {\left( x_j - \overline{x}_i
        \right)}^2}{N_i - 1}}

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`. :math:`N_i` and :math:`\overline{x}_i`
    indicate the number of indices referencing :math:`i` and their mean value,
    respectively.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        unbiased (bool, optional): If set to :obj:`False`, then the standard
            deviation will be calculated via the biased estimator.
            (default: :obj:`True`)

    :rtype: :class:`Tensor`
    """
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)

    # Compute per-index sums and counts to obtain the per-index mean ...
    tmp = None if out is None else out.clone().fill_(0)
    tmp = scatter_add(src, index, dim, tmp, dim_size)

    count = None if out is None else out.clone().fill_(0)
    count = scatter_add(torch.ones_like(src), index, dim, count, dim_size)

    mean = tmp / count.clamp(min=1)

    # ... then reduce the squared deviations from that mean.
    var = (src - mean.gather(dim, index))
    var = var * var
    out = scatter_add(var, index, dim, out, dim_size)

    out = out / (count - 1 if unbiased else count).clamp(min=1)
    out = torch.sqrt(out)

    return out
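
A hedged usage sketch for `scatter_std` (assuming the package is importable; values chosen by hand so the result is easy to verify):

import torch
from torch_scatter import scatter_std

src = torch.tensor([[2.0, 4.0, 6.0, 8.0]])
index = torch.tensor([[0, 0, 1, 1]])

print(scatter_std(src, index))                  # unbiased: tensor([[1.4142, 1.4142]])
print(scatter_std(src, index, unbiased=False))  # biased:   tensor([[1.0000, 1.0000]])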
from torch_scatter import scatter_add


def scatter_sub(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/sub.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Subtracts all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. If multiple indices reference the same location, their
    **negated contributions add** (`cf.` :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i - \sum_j \mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. (default: :obj:`0`)

    :rtype: :class:`Tensor`

    .. testsetup::

        import torch

    .. testcode::

        from torch_scatter import scatter_sub

        src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out = src.new_zeros((2, 6))

        out = scatter_sub(src, index, out=out)

        print(out)

    .. testoutput::

       tensor([[ 0.,  0., -4., -3., -3.,  0.],
               [-2., -4., -4.,  0.,  0.,  0.]])
    """
    return scatter_add(src.neg(), index, dim, out, dim_size, fill_value)
from __future__ import division

from itertools import repeat

import torch


def maybe_dim_size(index, dim_size=None):
    if dim_size is not None:
        return dim_size
    dim = index.max().item() + 1 if index.numel() > 0 else 0
    return int(dim)


def broadcast(src, index, dim):
    dim = range(src.dim())[dim]  # Get real dim value.

    if index.dim() == 1:
        index_size = list(repeat(1, src.dim()))
        index_size[dim] = src.size(dim)
        if index.numel() > 0:
            index = index.view(index_size).expand_as(src)
        else:  # pragma: no cover
            # PyTorch has a bug when view is used on zero-element tensors.
            index = src.new_empty(index_size, dtype=torch.long)

    # Broadcasting capabilities: Expand dimensions to match.
    if src.dim() != index.dim():
        raise ValueError(
            ('Number of dimensions of src and index tensor do not match, '
             'got {} and {}').format(src.dim(), index.dim()))

    expand_size = []
    for s, i in zip(src.size(), index.size()):
        expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
    src = src.expand(expand_size)
    index = index.expand_as(src)

    return src, index


def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    src, index = broadcast(src, index, dim)
    dim = range(src.dim())[dim]  # Get real dim value.

    # Generate output tensor if not given.
    if out is None:
        out_size = list(src.size())
        dim_size = maybe_dim_size(index, dim_size)
        out_size[dim] = dim_size
        out = src.new_full(out_size, fill_value)

    return src, out, index, dim
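
To make the contract of `gen` concrete, a small sketch (shapes are illustrative; assumes the functions above are in scope): a one-dimensional `index` is broadcast to `src`'s shape, and a fill-value-initialized output is allocated with its size at `dim` inferred from `index.max() + 1`:

import torch

src = torch.randn(2, 5)
index = torch.tensor([0, 1, 0, 1, 2])  # 1-D, broadcast along dim=-1

src, out, index, dim = gen(src, index, dim=-1)
print(index.size(), out.size(), dim)
# torch.Size([2, 5]) torch.Size([2, 3]) 1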