Unverified Commit c01f9bae authored by Matthias Fey, committed by GitHub

Merge pull request #105 from rusty1s/traceable

[WIP] traceable functions
parents 2520670a 02a47c46
import torch
from torch_scatter.utils.gen import gen
class ScatterMul(torch.autograd.Function):
@staticmethod
def forward(ctx, out, src, index, dim):
if src.is_cuda:
torch.ops.torch_scatter_cuda.scatter_mul(src, index, out, dim)
else:
torch.ops.torch_scatter_cpu.scatter_mul(src, index, out, dim)
ctx.mark_dirty(out)
ctx.save_for_backward(out, src, index)
ctx.dim = dim
return out
@staticmethod
def backward(ctx, grad_out):
out, src, index = ctx.saved_tensors
grad_src = None
if ctx.needs_input_grad[1]:
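            # Product rule: d/d(src_j) of prod_k src_k equals out / src_j, so
            # gather the upstream gradient back to the source positions and
            # divide. Note this assumes `src` contains no zeros.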
grad_src = (grad_out * out).gather(ctx.dim, index) / src
return None, grad_src, None, None
def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/mul.svg?sanitize=true
:align: center
:width: 400px
|
Multiplies all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. If multiple indices reference the same location, their
**contributions multiply** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \mathrm{out}_i \cdot \prod_j \mathrm{src}_j
where :math:`\prod_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`1`)
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_mul
src = torch.Tensor([[2, 0, 3, 4, 3], [2, 3, 4, 2, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_ones((2, 6))
out = scatter_mul(src, index, out=out)
print(out)
.. testoutput::
tensor([[1., 1., 4., 3., 6., 0.],
[6., 4., 8., 1., 1., 1.]])
"""
src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
if src.size(dim) == 0: # pragma: no cover
return out
return ScatterMul.apply(out, src, index, dim)
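

# A minimal numerical check of the backward pass above (a sketch, not part of
# the library): `gradcheck` compares the analytic product-rule gradient with
# finite differences. `src` is kept away from zero, where that gradient is
# undefined.
def _scatter_mul_gradcheck():
    src = (torch.rand(2, 5, dtype=torch.double) + 0.5).requires_grad_()
    index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
    return torch.autograd.gradcheck(
        lambda x: scatter_mul(x, index, dim=-1), (src,))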
import os.path as osp
from typing import Optional, Tuple
import torch
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_scatter.so'))
@torch.jit.script
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size)
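
# `scatter_add` below is an alias for `scatter_sum`; both dispatch to the same
# underlying kernel (the same applies to `segment_add_coo` and
# `segment_add_csr` further below).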
@torch.jit.script
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size)
@torch.jit.script
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.scatter_mean(src, index, dim, out, dim_size)
@torch.jit.script
def scatter_min(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)
@torch.jit.script
def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)
def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
reduce: str = "sum") -> torch.Tensor:
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/add.svg?sanitize=true
:align: center
:width: 400px
|
Reduces all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
:attr:`dim`.
For each value in :attr:`src`, its output index is specified by its index
in :attr:`src` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`.
The applied reduction is defined via the :attr:`reduce` argument.
Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1`, although no specific ordering of indices is required.
The :attr:`index` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes
.. math::
\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
.. note::
This operation is implemented via atomic operations on the GPU and is
therefore **non-deterministic** since the order of parallel operations
to the same value is undetermined.
For floating-point variables, this results in a source of variance in
the result.
:param src: The source tensor.
:param index: The indices of elements to scatter.
:param dim: The axis along which to index. (default: :obj:`-1`)
:param out: The destination tensor.
:param dim_size: If :attr:`out` is not given, automatically create output
with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor
according to :obj:`index.max() + 1` is returned.
:param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
:obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
:rtype: :class:`Tensor`
.. code-block:: python
from torch_scatter import scatter
src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])
# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")
print(out.size())
.. code-block::
torch.Size([10, 3, 64])
"""
if reduce == 'sum' or reduce == 'add':
return scatter_sum(src, index, dim, out, dim_size)
elif reduce == 'mean':
return scatter_mean(src, index, dim, out, dim_size)
elif reduce == 'min':
return scatter_min(src, index, dim, out, dim_size)[0]
elif reduce == 'max':
return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError("Unknown reduce operation '{}'; expected 'sum', "
                         "'mean', 'min' or 'max'".format(reduce))
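

# Illustrative only (a sketch, not part of the library): "mean" is the "sum"
# reduction divided by per-index counts, which the snippet below reproduces
# with two "sum" calls.
def _scatter_mean_identity():
    src = torch.randn(6, 4)
    index = torch.tensor([0, 1, 0, 1, 2, 1])
    count = scatter(torch.ones_like(src), index, dim=0, reduce="sum")
    return torch.allclose(scatter(src, index, dim=0, reduce="mean"),
                          scatter(src, index, dim=0, reduce="sum") / count)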
import torch
from torch_scatter.helpers import min_value, max_value
class SegmentCOO(torch.autograd.Function):
@staticmethod
def forward(ctx, src, index, out, dim_size, reduce):
assert reduce in ['sum', 'add', 'mean', 'min', 'max']
if out is not None:
ctx.mark_dirty(out)
ctx.reduce = reduce
ctx.src_size = list(src.size())
fill_value = 0
if out is None:
dim_size = index.max().item() + 1 if dim_size is None else dim_size
size = list(src.size())
size[index.dim() - 1] = dim_size
if reduce == 'min':
fill_value = max_value(src.dtype)
elif reduce == 'max':
fill_value = min_value(src.dtype)
out = src.new_full(size, fill_value)
if src.is_cuda:
out, arg_out = torch.ops.torch_scatter_cuda.segment_coo(
src, index, out, reduce)
else:
out, arg_out = torch.ops.torch_scatter_cpu.segment_coo(
src, index, out, reduce)
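        # Untouched entries still hold the min/max sentinel; reset them to
        # zero so that empty segments read as 0 in the output.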
if fill_value != 0:
out.masked_fill_(out == fill_value, 0)
ctx.save_for_backward(index, arg_out)
if reduce == 'min' or reduce == 'max':
return out, arg_out
else:
return out
@staticmethod
def backward(ctx, grad_out, *args):
(index, arg_out), src_size = ctx.saved_tensors, ctx.src_size
grad_src = None
if ctx.needs_input_grad[0]:
if ctx.reduce == 'sum' or ctx.reduce == 'add':
if grad_out.is_cuda:
grad_src = torch.ops.torch_scatter_cuda.gather_coo(
grad_out, index, grad_out.new_empty(src_size))
else:
grad_src = torch.ops.torch_scatter_cpu.gather_coo(
grad_out, index, grad_out.new_empty(src_size))
elif ctx.reduce == 'mean':
if grad_out.is_cuda:
grad_src = torch.ops.torch_scatter_cuda.gather_coo(
grad_out, index, grad_out.new_empty(src_size))
else:
grad_src = torch.ops.torch_scatter_cpu.gather_coo(
grad_out, index, grad_out.new_empty(src_size))
count = arg_out # Gets pre-computed on GPU but not on CPU.
if count is None:
size = list(index.size())
size[-1] = grad_out.size(index.dim() - 1)
count = torch.ops.torch_scatter_cpu.segment_coo(
torch.ones_like(index, dtype=grad_out.dtype), index,
grad_out.new_zeros(size), 'sum')[0].clamp_(min=1)
if grad_out.is_cuda:
count = torch.ops.torch_scatter_cuda.gather_coo(
count, index, count.new_empty(src_size[:index.dim()]))
else:
count = torch.ops.torch_scatter_cpu.gather_coo(
count, index, count.new_empty(src_size[:index.dim()]))
for _ in range(grad_out.dim() - index.dim()):
count = count.unsqueeze(-1)
grad_src.div_(count)
elif ctx.reduce == 'min' or ctx.reduce == 'max':
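                # arg_out holds the winning source position per output entry;
                # empty segments point one past the end of `src`. Scatter the
                # gradient into a source tensor widened by one slot so those
                # entries land in a dummy column, then narrow it away again.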
src_size[index.dim() - 1] += 1
grad_src = grad_out.new_zeros(src_size).scatter_(
index.dim() - 1, arg_out, grad_out)
grad_src = grad_src.narrow(index.dim() - 1, 0,
src_size[index.dim() - 1] - 1)
return grad_src, None, None, None, None
class SegmentCSR(torch.autograd.Function):
@staticmethod
def forward(ctx, src, indptr, out, reduce):
assert reduce in ['sum', 'add', 'mean', 'min', 'max']
if out is not None:
ctx.mark_dirty(out)
ctx.reduce = reduce
ctx.src_size = list(src.size())
if src.is_cuda:
out, arg_out = torch.ops.torch_scatter_cuda.segment_csr(
src, indptr, out, reduce)
else:
out, arg_out = torch.ops.torch_scatter_cpu.segment_csr(
src, indptr, out, reduce)
ctx.save_for_backward(indptr, arg_out)
return out if arg_out is None else (out, arg_out)
@staticmethod
def backward(ctx, grad_out, *args):
(indptr, arg_out), src_size = ctx.saved_tensors, ctx.src_size
grad_src = None
if ctx.needs_input_grad[0]:
if ctx.reduce == 'sum' or ctx.reduce == 'add':
if grad_out.is_cuda:
grad_src = torch.ops.torch_scatter_cuda.gather_csr(
grad_out, indptr, grad_out.new_empty(src_size))
else:
grad_src = torch.ops.torch_scatter_cpu.gather_csr(
grad_out, indptr, grad_out.new_empty(src_size))
elif ctx.reduce == 'mean':
if grad_out.is_cuda:
grad_src = torch.ops.torch_scatter_cuda.gather_csr(
grad_out, indptr, grad_out.new_empty(src_size))
else:
grad_src = torch.ops.torch_scatter_cpu.gather_csr(
grad_out, indptr, grad_out.new_empty(src_size))
indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
count = (indptr2 - indptr1).to(grad_src.dtype)
if grad_out.is_cuda:
count = torch.ops.torch_scatter_cuda.gather_csr(
count, indptr,
count.new_empty(src_size[:indptr.dim()]))
else:
count = torch.ops.torch_scatter_cpu.gather_csr(
count, indptr,
count.new_empty(src_size[:indptr.dim()]))
for _ in range(grad_out.dim() - indptr.dim()):
count = count.unsqueeze(-1)
grad_src.div_(count)
elif ctx.reduce == 'min' or ctx.reduce == 'max':
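                # Same dummy-slot trick as in SegmentCOO.backward above.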
src_size[indptr.dim() - 1] += 1
grad_src = grad_out.new_zeros(src_size).scatter_(
indptr.dim() - 1, arg_out, grad_out)
grad_src = grad_src.narrow(indptr.dim() - 1, 0,
src_size[indptr.dim() - 1] - 1)
return grad_src, None, None, None
def segment_coo(src, index, out=None, dim_size=None, reduce="sum"):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/segment_coo.svg?sanitize=true
:align: center
:width: 400px
|
Reduces all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along the last dimension of
:attr:`index`.
For each value in :attr:`src`, its output index is specified by its index
in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the
corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`.
The applied reduction is defined via the :attr:`reduce` argument.
Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and
:math:`m`-dimensional tensors with
size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
:math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an
:math:`n`-dimensional tensor with size
:math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`.
Moreover, the values of :attr:`index` must be between :math:`0` and
:math:`y - 1` in ascending order.
The :attr:`index` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes
.. math::
\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
In contrast to :meth:`scatter`, this method expects values in :attr:`index`
**to be sorted** along dimension :obj:`index.dim() - 1`.
Due to the use of sorted indices, :meth:`segment_coo` is usually faster
than the more general :meth:`scatter` operation.
For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
second tensor representing the :obj:`argmin` and :obj:`argmax`,
respectively.
.. note::
This operation is implemented via atomic operations on the GPU and is
therefore **non-deterministic** since the order of parallel operations
to the same value is undetermined.
For floating-point variables, this results in a source of variance in
the result.
Args:
src (Tensor): The source tensor.
index (LongTensor): The sorted indices of elements to segment.
            The number of dimensions of :attr:`index` needs to be less than
            or equal to that of :attr:`src`.
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension
:obj:`index.dim() - 1`.
If :attr:`dim_size` is not given, a minimal sized output tensor
according to :obj:`index.max() + 1` is returned.
(default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"sum"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"sum"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
.. code-block:: python
from torch_scatter import segment_coo
src = torch.randn(10, 6, 64)
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1) # Broadcasting in the first and last dim.
out = segment_coo(src, index, reduce="sum")
print(out.size())
.. code-block::
torch.Size([10, 3, 64])
"""
return SegmentCOO.apply(src, index, out, dim_size, reduce)
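

# Illustrative only (a sketch, not part of the library): for reduce="min" and
# reduce="max", a second tensor with the winning source positions is returned.
def _segment_coo_argmax_demo():
    src = torch.tensor([1., 4., 3., 2.])
    index = torch.tensor([0, 0, 1, 1])
    out, argmax = segment_coo(src, index, reduce="max")
    return out, argmax  # tensor([4., 3.]), tensor([1, 2])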
def segment_csr(src, indptr, out=None, reduce="sum"):
r"""
Reduces all values from the :attr:`src` tensor into :attr:`out` within the
ranges specified in the :attr:`indptr` tensor along the last dimension of
:attr:`indptr`.
For each value in :attr:`src`, its output index is specified by its index
in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the
corresponding range index in :attr:`indptr` for dimension
:obj:`indptr.dim() - 1`.
The applied reduction is defined via the :attr:`reduce` argument.
Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
:math:`m`-dimensional tensors with
size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
:math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an
:math:`n`-dimensional tensor with size
:math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`.
Moreover, the values of :attr:`indptr` must be between :math:`0` and
:math:`x_m` in ascending order.
The :attr:`indptr` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes
.. math::
        \mathrm{out}_i =
        \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.
Due to the use of index pointers, :meth:`segment_csr` is the fastest
method to apply for grouped reductions.
For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
second tensor representing the :obj:`argmin` and :obj:`argmax`,
respectively.
.. note::
In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
operation is **fully-deterministic**.
Args:
src (Tensor): The source tensor.
indptr (LongTensor): The index pointers between elements to segment.
            The number of dimensions of :attr:`indptr` needs to be less than
            or equal to that of :attr:`src`.
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"sum"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"sum"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
.. code-block:: python
from torch_scatter import segment_csr
src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1) # Broadcasting in the first and last dim.
out = segment_csr(src, indptr, reduce="sum")
print(out.size())
.. code-block::
torch.Size([10, 3, 64])
"""
return SegmentCSR.apply(src, indptr, out, reduce)
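

# Illustrative only (a sketch, not part of the library): `indptr` delimits
# consecutive ranges of `src`, here [0, 2) and [2, 4).
def _segment_csr_demo():
    src = torch.tensor([1., 2., 3., 4.])
    indptr = torch.tensor([0, 2, 4])
    return segment_csr(src, indptr, reduce="sum")  # tensor([3., 7.])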
import os.path as osp
from typing import Optional, Tuple
import torch
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_segment_coo.so'))
@torch.jit.script
def segment_sum_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size)
@torch.jit.script
def segment_add_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size)
@torch.jit.script
def segment_mean_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_mean_coo(src, index, out, dim_size)
@torch.jit.script
def segment_min_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size)
@torch.jit.script
def segment_max_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size)
def segment_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
reduce: str = "sum") -> torch.Tensor:
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/segment_coo.svg?sanitize=true
:align: center
:width: 400px
|
Reduces all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along the last dimension of
:attr:`index`.
For each value in :attr:`src`, its output index is specified by its index
in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the
corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`.
The applied reduction is defined via the :attr:`reduce` argument.
Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and
:math:`m`-dimensional tensors with
size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
:math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an
:math:`n`-dimensional tensor with size
:math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`.
Moreover, the values of :attr:`index` must be between :math:`0` and
:math:`y - 1` in ascending order.
The :attr:`index` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes
.. math::
\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
In contrast to :meth:`scatter`, this method expects values in :attr:`index`
**to be sorted** along dimension :obj:`index.dim() - 1`.
Due to the use of sorted indices, :meth:`segment_coo` is usually faster
than the more general :meth:`scatter` operation.
.. note::
This operation is implemented via atomic operations on the GPU and is
therefore **non-deterministic** since the order of parallel operations
to the same value is undetermined.
For floating-point variables, this results in a source of variance in
the result.
:param src: The source tensor.
:param index: The sorted indices of elements to segment.
        The number of dimensions of :attr:`index` needs to be less than or
        equal to that of :attr:`src`.
:param out: The destination tensor.
:param dim_size: If :attr:`out` is not given, automatically create output
with size :attr:`dim_size` at dimension :obj:`index.dim() - 1`.
If :attr:`dim_size` is not given, a minimal sized output tensor
according to :obj:`index.max() + 1` is returned.
:param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
:obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
:rtype: :class:`Tensor`
.. code-block:: python
from torch_scatter import segment_coo
src = torch.randn(10, 6, 64)
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1) # Broadcasting in the first and last dim.
out = segment_coo(src, index, reduce="sum")
print(out.size())
.. code-block::
torch.Size([10, 3, 64])
"""
if reduce == 'sum' or reduce == 'add':
return segment_sum_coo(src, index, out, dim_size)
elif reduce == 'mean':
return segment_mean_coo(src, index, out, dim_size)
elif reduce == 'min':
return segment_min_coo(src, index, out, dim_size)[0]
elif reduce == 'max':
return segment_max_coo(src, index, out, dim_size)[0]
    else:
        raise ValueError("Unknown reduce operation '{}'; expected 'sum', "
                         "'mean', 'min' or 'max'".format(reduce))
@torch.jit.script
def gather_coo(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.gather_coo(src, index, out)
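

# `gather_coo` is the transpose of `segment_coo`: it copies each segment's
# value back to every position of that segment. A sketch, not part of the
# library:
def _gather_coo_demo():
    src = torch.tensor([10., 20., 30.])
    index = torch.tensor([0, 0, 1, 1, 1, 2])
    return gather_coo(src, index)  # tensor([10., 10., 20., 20., 20., 30.])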
import os.path as osp
from typing import Optional, Tuple
import torch
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_segment_csr.so'))
@torch.jit.script
def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out)
@torch.jit.script
def segment_add_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out)
@torch.jit.script
def segment_mean_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_mean_csr(src, indptr, out)
@torch.jit.script
def segment_min_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.segment_min_csr(src, indptr, out)
@torch.jit.script
def segment_max_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.segment_max_csr(src, indptr, out)
def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None,
reduce: str = "sum") -> torch.Tensor:
r"""
Reduces all values from the :attr:`src` tensor into :attr:`out` within the
ranges specified in the :attr:`indptr` tensor along the last dimension of
:attr:`indptr`.
For each value in :attr:`src`, its output index is specified by its index
in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the
corresponding range index in :attr:`indptr` for dimension
:obj:`indptr.dim() - 1`.
The applied reduction is defined via the :attr:`reduce` argument.
Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
:math:`m`-dimensional tensors with
size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
:math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an
:math:`n`-dimensional tensor with size
:math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`.
Moreover, the values of :attr:`indptr` must be between :math:`0` and
:math:`x_m` in ascending order.
The :attr:`indptr` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="sum"`, the operation
computes
.. math::
        \mathrm{out}_i =
        \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.
Due to the use of index pointers, :meth:`segment_csr` is the fastest
method to apply for grouped reductions.
.. note::
In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
operation is **fully-deterministic**.
:param src: The source tensor.
:param indptr: The index pointers between elements to segment.
        The number of dimensions of :attr:`indptr` needs to be less than or
        equal to that of :attr:`src`.
:param out: The destination tensor.
:param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
:obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)
:rtype: :class:`Tensor`
.. code-block:: python
from torch_scatter import segment_csr
src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1) # Broadcasting in the first and last dim.
out = segment_csr(src, indptr, reduce="sum")
print(out.size())
.. code-block::
torch.Size([10, 3, 64])
"""
if reduce == 'sum' or reduce == 'add':
return segment_sum_csr(src, indptr, out)
elif reduce == 'mean':
return segment_mean_csr(src, indptr, out)
elif reduce == 'min':
return segment_min_csr(src, indptr, out)[0]
elif reduce == 'max':
return segment_max_csr(src, indptr, out)[0]
    else:
        raise ValueError("Unknown reduce operation '{}'; expected 'sum', "
                         "'mean', 'min' or 'max'".format(reduce))
@torch.jit.script
def gather_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.gather_csr(src, indptr, out)
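

# `gather_csr` mirrors `gather_coo`, but consumes index pointers instead of
# sorted indices. A sketch, not part of the library:
def _gather_csr_demo():
    src = torch.tensor([10., 20., 30.])
    indptr = torch.tensor([0, 2, 5, 6])
    return gather_csr(src, indptr)  # tensor([10., 10., 20., 20., 20., 30.])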
import torch
from torch_scatter import scatter_add
from torch_scatter.utils.gen import gen
def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/std.svg?sanitize=true
:align: center
:width: 400px
|
Computes the standard-deviation from all values from the :attr:`src` tensor
into :attr:`out` at the indices specified in the :attr:`index` tensor along
a given axis :attr:`dim` (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \sqrt{\frac{\sum_j {\left( x_j - \overline{x}_i
\right)}^2}{N_i - 1}}
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`. :math:`N_i` and :math:`\overline{x}_i`
indicate the number of indices referencing :math:`i` and their mean value,
respectively.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
unbiased (bool, optional): If set to :obj:`False`, then the standard-
deviation will be calculated via the biased estimator.
(default: :obj:`True`)
:rtype: :class:`Tensor`
"""
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)

    # Per-index sums and element counts, used to compute the per-index mean.
    tmp = None if out is None else out.clone().fill_(0)
    tmp = scatter_add(src, index, dim, tmp, dim_size)
    count = None if out is None else out.clone().fill_(0)
    count = scatter_add(torch.ones_like(src), index, dim, count, dim_size)
    mean = tmp / count.clamp(min=1)

    # Sum of squared deviations from the (gathered) per-index mean.
    var = (src - mean.gather(dim, index))
    var = var * var
    out = scatter_add(var, index, dim, out, dim_size)

    # Apply Bessel's correction for the unbiased estimator.
    out = out / (count - 1 if unbiased else count).clamp(min=1)
    out = torch.sqrt(out)
    return out
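

# A quick cross-check of `scatter_std` against per-group `Tensor.std` (a
# sketch, not part of the library).
def _scatter_std_demo():
    src = torch.tensor([2., 4., 3., 5., 7.])
    index = torch.tensor([0, 0, 1, 1, 1])
    out = scatter_std(src, index)
    assert torch.allclose(out[0], src[:2].std())  # sqrt(2)
    assert torch.allclose(out[1], src[2:].std())  # 2.0
    return out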
from torch_scatter import scatter_add
def scatter_sub(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/sub.svg?sanitize=true
:align: center
:width: 400px
|
Subtracts all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. If multiple indices reference the same location, their
**negated contributions add** (`cf.` :meth:`~torch_scatter.scatter_add`).
For one-dimensional tensors, the operation computes
.. math::
\mathrm{out}_i = \mathrm{out}_i - \sum_j \mathrm{src}_j
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
Args:
src (Tensor): The source tensor.
index (LongTensor): The indices of elements to scatter.
dim (int, optional): The axis along which to index.
(default: :obj:`-1`)
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension :attr:`dim`.
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
:rtype: :class:`Tensor`
.. testsetup::
import torch
.. testcode::
from torch_scatter import scatter_sub
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out = scatter_sub(src, index, out=out)
print(out)
.. testoutput::
tensor([[ 0., 0., -4., -3., -3., 0.],
[-2., -4., -4., 0., 0., 0.]])
"""
return scatter_add(src.neg(), index, dim, out, dim_size, fill_value)
from __future__ import division
from itertools import repeat
import torch
def maybe_dim_size(index, dim_size=None):
if dim_size is not None:
return dim_size
dim = index.max().item() + 1 if index.numel() > 0 else 0
return int(dim)
def broadcast(src, index, dim):
dim = range(src.dim())[dim] # Get real dim value.
if index.dim() == 1:
index_size = list(repeat(1, src.dim()))
index_size[dim] = src.size(dim)
if index.numel() > 0:
index = index.view(index_size).expand_as(src)
else: # pragma: no cover
# PyTorch has a bug when view is used on zero-element tensors.
index = src.new_empty(index_size, dtype=torch.long)
    # Broadcasting capabilities: Expand dimensions to match.
if src.dim() != index.dim():
raise ValueError(
('Number of dimensions of src and index tensor do not match, '
'got {} and {}').format(src.dim(), index.dim()))
expand_size = []
for s, i in zip(src.size(), index.size()):
expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
src = src.expand(expand_size)
index = index.expand_as(src)
return src, index
def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
src, index = broadcast(src, index, dim)
dim = range(src.dim())[dim] # Get real dim value.
# Generate output tensor if not given.
if out is None:
out_size = list(src.size())
dim_size = maybe_dim_size(index, dim_size)
out_size[dim] = dim_size
out = src.new_full(out_size, fill_value)
return src, out, index, dim
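

# A small illustration of what `gen` produces (a sketch, not part of the
# library): `index` is broadcast to `src`'s shape and an all-zero output of
# the inferred size is allocated.
def _gen_demo():
    src = torch.tensor([[1., 2.], [3., 4.]])
    index = torch.tensor([1, 0])
    src, out, index, dim = gen(src, index, dim=-1)
    # index -> tensor([[1, 0], [1, 0]]), out -> zeros of shape (2, 2), dim -> 1
    return src, out, index, dim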