Commit 6a2a503e authored by quyuanhao123

Initial commit
# torch_scatter/__init__.py
import importlib
import os
import os.path as osp

import torch

__version__ = '2.0.9'
for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
    hip_spec = importlib.machinery.PathFinder().find_spec(
        f'{library}_hip', [osp.dirname(__file__)])
    cpu_spec = importlib.machinery.PathFinder().find_spec(
        f'{library}_cpu', [osp.dirname(__file__)])
    spec = hip_spec or cpu_spec
    if spec is not None:
        torch.ops.load_library(spec.origin)
    elif os.getenv('BUILD_DOCS', '0') != '1':  # pragma: no cover
        raise ImportError(f"Could not find module '{library}_cpu' in "
                          f"{osp.dirname(__file__)}")
    else:  # pragma: no cover
        from .placeholder import cuda_version_placeholder
        torch.ops.torch_scatter.cuda_version = cuda_version_placeholder

        from .placeholder import scatter_placeholder
        torch.ops.torch_scatter.scatter_mul = scatter_placeholder

        from .placeholder import scatter_arg_placeholder
        torch.ops.torch_scatter.scatter_min = scatter_arg_placeholder
        torch.ops.torch_scatter.scatter_max = scatter_arg_placeholder

        from .placeholder import segment_csr_placeholder
        from .placeholder import segment_csr_arg_placeholder
        from .placeholder import gather_csr_placeholder
        torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder
        torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder
        torch.ops.torch_scatter.segment_min_csr = segment_csr_arg_placeholder
        torch.ops.torch_scatter.segment_max_csr = segment_csr_arg_placeholder
        torch.ops.torch_scatter.gather_csr = gather_csr_placeholder

        from .placeholder import segment_coo_placeholder
        from .placeholder import segment_coo_arg_placeholder
        from .placeholder import gather_coo_placeholder
        torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder
        torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder
        torch.ops.torch_scatter.segment_min_coo = segment_coo_arg_placeholder
        torch.ops.torch_scatter.segment_max_coo = segment_coo_arg_placeholder
        torch.ops.torch_scatter.gather_coo = gather_coo_placeholder
cuda_version = torch.ops.torch_scatter.cuda_version()
if torch.cuda.is_available() and cuda_version != -1:  # pragma: no cover
    if cuda_version < 10000:
        major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
    else:
        major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3])
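# Worked example of the decoding above (illustrative, not in the original
# source): the extension reports the toolkit version as a single integer,
# e.g. 9020 for 9.2 and 10020 for 10.2, so the two branches recover
# (major, minor) as follows:
#
#   int(str(9020)[0]), int(str(9020)[2])        # -> (9, 2)
#   int(str(10020)[0:2]), int(str(10020)[3])    # -> (10, 2)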
from .scatter import scatter_sum, scatter_add, scatter_mul # noqa
from .scatter import scatter_mean, scatter_min, scatter_max, scatter # noqa
from .segment_csr import segment_sum_csr, segment_add_csr # noqa
from .segment_csr import segment_mean_csr, segment_min_csr # noqa
from .segment_csr import segment_max_csr, segment_csr, gather_csr # noqa
from .segment_coo import segment_sum_coo, segment_add_coo # noqa
from .segment_coo import segment_mean_coo, segment_min_coo # noqa
from .segment_coo import segment_max_coo, segment_coo, gather_coo # noqa
from .composite import scatter_std, scatter_logsumexp # noqa
from .composite import scatter_softmax, scatter_log_softmax # noqa
__all__ = [
    'scatter_sum',
    'scatter_add',
    'scatter_mul',
    'scatter_mean',
    'scatter_min',
    'scatter_max',
    'scatter',
    'segment_sum_csr',
    'segment_add_csr',
    'segment_mean_csr',
    'segment_min_csr',
    'segment_max_csr',
    'segment_csr',
    'gather_csr',
    'segment_sum_coo',
    'segment_add_coo',
    'segment_mean_coo',
    'segment_min_coo',
    'segment_max_coo',
    'segment_coo',
    'gather_coo',
    'scatter_std',
    'scatter_logsumexp',
    'scatter_softmax',
    'scatter_log_softmax',
    'torch_scatter',
    '__version__',
]

# torch_scatter/composite/__init__.py
from .std import scatter_std
from .logsumexp import scatter_logsumexp
from .softmax import scatter_log_softmax, scatter_softmax

__all__ = [
    'scatter_std',
    'scatter_logsumexp',
    'scatter_softmax',
    'scatter_log_softmax',
]

# torch_scatter/composite/logsumexp.py
from typing import Optional

import torch
from torch_scatter import scatter_sum, scatter_max
from torch_scatter.utils import broadcast


def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None,
                      eps: float = 1e-12) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    if out is not None:
        dim_size = out.size(dim)
    else:
        if dim_size is None:
            dim_size = int(index.max()) + 1

    size = list(src.size())
    size[dim] = dim_size
    max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
                                     device=src.device)
    # Fills `max_value_per_index` in-place via the `out` argument; the
    # returned value is intentionally discarded.
    scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_score = src - max_per_src_element
    recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))

    if out is not None:
        out = out.sub_(max_value_per_index).exp_()

    sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
                                dim_size)

    return sum_per_index.add_(eps).log_().add_(max_value_per_index)
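
# --- Usage sketch (illustrative, not part of the commit) ------------------
# Group-wise logsumexp over a 1-D tensor; each output entry matches
# torch.logsumexp over its group, up to the `eps` added before the final log.
import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, -1.0, 2.0, 0.0])
index = torch.tensor([0, 0, 1, 1])
out = scatter_logsumexp(src, index)  # shape [2]
assert torch.allclose(out[0], torch.logsumexp(src[0:2], dim=0))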

# torch_scatter/composite/softmax.py
from typing import Optional

import torch
from torch_scatter import scatter_sum, scatter_max
from torch_scatter.utils import broadcast


def scatter_softmax(src: torch.Tensor, index: torch.Tensor,
                    dim: int = -1,
                    dim_size: Optional[int] = None) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    index = broadcast(index, src, dim)

    max_value_per_index = scatter_max(
        src, index, dim=dim, dim_size=dim_size)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp_()

    sum_per_index = scatter_sum(
        recentered_scores_exp, index, dim, dim_size=dim_size)
    normalizing_constants = sum_per_index.gather(dim, index)

    return recentered_scores_exp.div(normalizing_constants)


def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                        eps: float = 1e-12,
                        dim_size: Optional[int] = None) -> torch.Tensor:
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    max_value_per_index = scatter_max(
        src, index, dim=dim, dim_size=dim_size)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element

    sum_per_index = scatter_sum(
        recentered_scores.exp(), index, dim, dim_size=dim_size)
    normalizing_constants = sum_per_index.add_(eps).log_().gather(dim, index)

    return recentered_scores.sub_(normalizing_constants)
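
# --- Usage sketch (illustrative, not part of the commit) ------------------
# Softmax is normalized independently within each index group.
import torch
from torch_scatter import scatter_softmax, scatter_log_softmax

src = torch.tensor([0.0, 1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 1, 1])
prob = scatter_softmax(src, index)
assert torch.allclose(prob[:2].sum(), torch.tensor(1.0))
assert torch.allclose(scatter_log_softmax(src, index), prob.log(), atol=1e-5)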

# torch_scatter/composite/std.py
from typing import Optional

import torch
from torch_scatter import scatter_sum
from torch_scatter.utils import broadcast


def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,
                unbiased: bool = True) -> torch.Tensor:
    if out is not None:
        dim_size = out.size(dim)

    if dim < 0:
        dim = src.dim() + dim

    count_dim = dim
    if index.dim() <= dim:
        count_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

    index = broadcast(index, src, dim)
    tmp = scatter_sum(src, index, dim, dim_size=dim_size)
    count = broadcast(count, tmp, dim).clamp(1)
    mean = tmp.div(count)

    var = (src - mean.gather(dim, index))
    var = var * var
    out = scatter_sum(var, index, dim, out, dim_size)

    if unbiased:
        count = count.sub(1).clamp_(1)
    out = out.div(count + 1e-6).sqrt()

    return out
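
# --- Usage sketch (illustrative, not part of the commit) ------------------
# Per-group standard deviation; `unbiased=True` divides by (count - 1).
import torch
from torch_scatter import scatter_std

src = torch.tensor([1.0, 3.0, 2.0, 2.0])
index = torch.tensor([0, 0, 1, 1])
scatter_std(src, index)                  # ~tensor([1.4142, 0.0000])
scatter_std(src, index, unbiased=False)  # ~tensor([1.0000, 0.0000])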

# torch_scatter/placeholder.py
from typing import Optional, Tuple

import torch


def cuda_version_placeholder() -> int:
    return -1


# The `return` statements below are unreachable (each stub raises first);
# they only keep the annotated return types consistent for the op bindings.
def scatter_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
                        out: Optional[torch.Tensor],
                        dim_size: Optional[int]) -> torch.Tensor:
    raise ImportError
    return src


def scatter_arg_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
                            out: Optional[torch.Tensor],
                            dim_size: Optional[int]
                            ) -> Tuple[torch.Tensor, torch.Tensor]:
    raise ImportError
    return src, index


def segment_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
                            out: Optional[torch.Tensor]) -> torch.Tensor:
    raise ImportError
    return src


def segment_csr_arg_placeholder(src: torch.Tensor, indptr: torch.Tensor,
                                out: Optional[torch.Tensor]
                                ) -> Tuple[torch.Tensor, torch.Tensor]:
    raise ImportError
    return src, indptr


def gather_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
                           out: Optional[torch.Tensor]) -> torch.Tensor:
    raise ImportError
    return src


def segment_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
                            out: Optional[torch.Tensor],
                            dim_size: Optional[int]) -> torch.Tensor:
    raise ImportError
    return src


def segment_coo_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
                                out: Optional[torch.Tensor],
                                dim_size: Optional[int]
                                ) -> Tuple[torch.Tensor, torch.Tensor]:
    raise ImportError
    return src, index


def gather_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
                           out: Optional[torch.Tensor]) -> torch.Tensor:
    raise ImportError
    return src
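
# --- Behavior sketch (illustrative, not part of the commit) ---------------
# With BUILD_DOCS=1 and no compiled extension on the path, __init__.py binds
# the ops to the stubs above, so any accidental call fails loudly, e.g.:
#
#   torch.ops.torch_scatter.scatter_mul = scatter_placeholder
#   torch.ops.torch_scatter.scatter_mul(src, index, -1, None, None)
#   # -> raises ImportError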

# torch_scatter/scatter.py
from typing import Optional, Tuple

import torch

from .utils import broadcast


def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)
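
# Example (illustrative, not in the original source): with `dim_size`,
# indices that never occur keep their zero initialization:
#
#   scatter_sum(torch.ones(4), torch.tensor([0, 0, 2, 2]), dim_size=4)
#   # -> tensor([2., 0., 2., 0.])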

def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    return scatter_sum(src, index, dim, out, dim_size)


def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)


def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    count[count < 1] = 1
    count = broadcast(count, out, dim)
    if out.is_floating_point():
        out.true_divide_(count)
    else:
        out.div_(count, rounding_mode='floor')
    return out
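
# Example (illustrative, not in the original source): integer inputs use
# floor division for the mean:
#
#   scatter_mean(torch.tensor([1, 2, 4]), torch.tensor([0, 0, 0]))
#   # -> tensor([2])  (floor(7 / 3))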

def scatter_min(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)


def scatter_max(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)

def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/add.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension :attr:`dim`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
    tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
    and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
    tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1`, although no specific ordering of indices is required.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point variables, this results in a source of variance in
        the result.

    :param src: The source tensor.
    :param index: The indices of elements to scatter.
    :param dim: The axis along which to index. (default: :obj:`-1`)
    :param out: The destination tensor.
    :param dim_size: If :attr:`out` is not given, automatically create output
        with size :attr:`dim_size` at dimension :attr:`dim`.
        If :attr:`dim_size` is not given, a minimal sized output tensor
        according to :obj:`index.max() + 1` is returned.
    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`,
        :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`

    .. code-block:: python

        from torch_scatter import scatter

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 1, 0, 1, 2, 1])

        # Broadcasting in the first and last dim.
        out = scatter(src, index, dim=1, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    elif reduce == 'mul':
        return scatter_mul(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError(f"Invalid `reduce` argument: '{reduce}'")

# torch_scatter/segment_coo.py
from typing import Optional, Tuple

import torch


def segment_sum_coo(src: torch.Tensor, index: torch.Tensor,
                    out: Optional[torch.Tensor] = None,
                    dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size)


def segment_add_coo(src: torch.Tensor, index: torch.Tensor,
                    out: Optional[torch.Tensor] = None,
                    dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size)


def segment_mean_coo(src: torch.Tensor, index: torch.Tensor,
                     out: Optional[torch.Tensor] = None,
                     dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.segment_mean_coo(src, index, out, dim_size)


def segment_min_coo(
        src: torch.Tensor, index: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size)


def segment_max_coo(
        src: torch.Tensor, index: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size)

def segment_coo(src: torch.Tensor, index: torch.Tensor,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,
                reduce: str = "sum") -> torch.Tensor:
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/segment_coo.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along the last dimension of
    :attr:`index`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the
    corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and
    :math:`m`-dimensional tensors with
    size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
    :math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an
    :math:`n`-dimensional tensor with size
    :math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1` in ascending order.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    In contrast to :meth:`scatter`, this method expects values in :attr:`index`
    **to be sorted** along dimension :obj:`index.dim() - 1`.
    Due to the use of sorted indices, :meth:`segment_coo` is usually faster
    than the more general :meth:`scatter` operation.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point variables, this results in a source of variance in
        the result.

    :param src: The source tensor.
    :param index: The sorted indices of elements to segment.
        The number of dimensions of :attr:`index` needs to be less than or
        equal to that of :attr:`src`.
    :param out: The destination tensor.
    :param dim_size: If :attr:`out` is not given, automatically create output
        with size :attr:`dim_size` at dimension :obj:`index.dim() - 1`.
        If :attr:`dim_size` is not given, a minimal sized output tensor
        according to :obj:`index.max() + 1` is returned.
    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
        :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`

    .. code-block:: python

        from torch_scatter import segment_coo

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 0, 1, 1, 1, 2])
        index = index.view(1, -1)  # Broadcasting in the first and last dim.

        out = segment_coo(src, index, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    if reduce == 'sum' or reduce == 'add':
        return segment_sum_coo(src, index, out, dim_size)
    elif reduce == 'mean':
        return segment_mean_coo(src, index, out, dim_size)
    elif reduce == 'min':
        return segment_min_coo(src, index, out, dim_size)[0]
    elif reduce == 'max':
        return segment_max_coo(src, index, out, dim_size)[0]
    else:
        raise ValueError(f"Invalid `reduce` argument: '{reduce}'")


def gather_coo(src: torch.Tensor, index: torch.Tensor,
               out: Optional[torch.Tensor] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.gather_coo(src, index, out)
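
# --- Usage sketch (illustrative, not part of the commit) ------------------
# gather_coo is the expansion counterpart of segment_coo: it copies one
# value per segment back to one value per (sorted) index entry.
import torch
from torch_scatter import gather_coo

src = torch.tensor([10., 20., 30.])
index = torch.tensor([0, 0, 1, 1, 1, 2])
out = gather_coo(src, index)  # expected: tensor([10., 10., 20., 20., 20., 30.])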

# torch_scatter/segment_csr.py
from typing import Optional, Tuple

import torch


def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor,
                    out: Optional[torch.Tensor] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out)


def segment_add_csr(src: torch.Tensor, indptr: torch.Tensor,
                    out: Optional[torch.Tensor] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out)


def segment_mean_csr(src: torch.Tensor, indptr: torch.Tensor,
                     out: Optional[torch.Tensor] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.segment_mean_csr(src, indptr, out)


def segment_min_csr(
        src: torch.Tensor, indptr: torch.Tensor,
        out: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.segment_min_csr(src, indptr, out)


def segment_max_csr(
        src: torch.Tensor, indptr: torch.Tensor,
        out: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.segment_max_csr(src, indptr, out)

def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
                out: Optional[torch.Tensor] = None,
                reduce: str = "sum") -> torch.Tensor:
    r"""
    Reduces all values from the :attr:`src` tensor into :attr:`out` within the
    ranges specified in the :attr:`indptr` tensor along the last dimension of
    :attr:`indptr`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the
    corresponding range index in :attr:`indptr` for dimension
    :obj:`indptr.dim() - 1`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
    :math:`m`-dimensional tensors with
    size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
    :math:`(x_0, ..., x_{m-2}, y)`, respectively, then :attr:`out` must be an
    :math:`n`-dimensional tensor with size
    :math:`(x_0, ..., x_{m-2}, y - 1, x_{m}, ..., x_{n-1})`.
    Moreover, the values of :attr:`indptr` must be between :math:`0` and
    :math:`x_m` in ascending order.
    The :attr:`indptr` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i =
        \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.

    Due to the use of index pointers, :meth:`segment_csr` is the fastest
    method to apply for grouped reductions.

    .. note::

        In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
        operation is **fully-deterministic**.

    :param src: The source tensor.
    :param indptr: The index pointers between elements to segment.
        The number of dimensions of :attr:`indptr` needs to be less than or
        equal to that of :attr:`src`.
    :param out: The destination tensor.
    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
        :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`

    .. code-block:: python

        from torch_scatter import segment_csr

        src = torch.randn(10, 6, 64)
        indptr = torch.tensor([0, 2, 5, 6])
        indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

        out = segment_csr(src, indptr, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    if reduce == 'sum' or reduce == 'add':
        return segment_sum_csr(src, indptr, out)
    elif reduce == 'mean':
        return segment_mean_csr(src, indptr, out)
    elif reduce == 'min':
        return segment_min_csr(src, indptr, out)[0]
    elif reduce == 'max':
        return segment_max_csr(src, indptr, out)[0]
    else:
        raise ValueError(f"Invalid `reduce` argument: '{reduce}'")


def gather_csr(src: torch.Tensor, indptr: torch.Tensor,
               out: Optional[torch.Tensor] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.gather_csr(src, indptr, out)
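
# --- Usage sketch (illustrative, not part of the commit) ------------------
# gather_csr repeats src[i] exactly (indptr[i+1] - indptr[i]) times,
# mirroring the CSR segments consumed by segment_csr.
import torch
from torch_scatter import gather_csr

src = torch.tensor([10., 20., 30.])
indptr = torch.tensor([0, 2, 5, 6])
out = gather_csr(src, indptr)  # expected: tensor([10., 10., 20., 20., 20., 30.])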

# torch_scatter/utils.py
import torch


def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand_as(other)
    return src
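
# --- Usage sketch (illustrative, not part of the commit) ------------------
# broadcast() aligns a 1-D index tensor with `other` by inserting leading
# dimensions up to `dim` and trailing dimensions after it, then expanding:
#
#   src = torch.tensor([0, 1, 0])        # shape [3]
#   other = torch.randn(2, 3, 4)
#   broadcast(src, other, dim=1).shape   # -> torch.Size([2, 3, 4])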