segment.py 684 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
from torch_scatter.add import scatter_add
rusty1s's avatar
rusty1s committed
4
5
6
7
8
9

if torch.cuda.is_available():
    import torch_scatter.segment_cuda


def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Sum ``src`` into the slots named by ``index`` along ``dim``.

    This is a thin alias. Every argument is handed, in the original order, to
    ``scatter_add``, and its result is returned unchanged. The semantics of
    ``dim``, ``out``, ``dim_size`` and ``fill_value`` are therefore exactly
    those of ``scatter_add``.
    """
    # Forward positionally so this call does not depend on scatter_add's
    # keyword names.
    trailing = (dim, out, dim_size, fill_value)
    return scatter_add(src, index, *trailing)


def segment_add_csr(src, indptr, out=None):
    """Segment-sum ``src`` using a compressed (CSR-style) pointer tensor.

    The ``segment_add_csr`` kernel of the ``torch_scatter.segment_cuda``
    extension does all the work. This wrapper only checks that the extension
    can have been loaded, then forwards ``src``, ``indptr`` and ``out``.

    Args:
        src: Values to reduce. Assumed to live on a CUDA device; TODO confirm
            against the kernel.
        indptr: Segment pointer tensor. Presumably ``indptr[i]:indptr[i+1]``
            delimits segment ``i``; verify against the kernel.
        out: Optional output tensor, passed straight through to the kernel.

    Returns:
        Whatever the CUDA kernel returns.

    Raises:
        RuntimeError: If CUDA is unavailable. In that case the module never
            imported ``torch_scatter.segment_cuda``.
    """
    # The extension is only imported under ``torch.cuda.is_available()`` at
    # module load. Without this check, the call below dies with an opaque
    # ``NameError: name 'torch_scatter' is not defined``.
    if not torch.cuda.is_available():
        raise RuntimeError(
            'segment_add_csr requires the CUDA extension '
            'torch_scatter.segment_cuda, which is only loaded when CUDA is '
            'available')
    return torch_scatter.segment_cuda.segment_add_csr(src, indptr, out)


def segment_add_coo(src, index, dim_size=None):
    """Segment-sum ``src`` into a freshly allocated zero tensor (COO indices).

    Args:
        src: Values to reduce.
        index: Segment index tensor. The reduction axis is its last dimension
            (``index.dim() - 1``). That is also the axis of ``src`` that is
            resized in the output.
        dim_size: Length of the output along that axis. Defaults to
            ``index.max() + 1``, or ``0`` when ``index`` is empty.

    Returns:
        The output tensor. It has the same shape as ``src``, except that axis
        ``index.dim() - 1`` has length ``dim_size``. It is zero-initialised
        and then handed to the CUDA kernel, which presumably accumulates into
        it in place; confirm against the kernel.

    Raises:
        RuntimeError: If CUDA is unavailable. In that case the module never
            imported ``torch_scatter.segment_cuda``.
    """
    # The extension is only imported under ``torch.cuda.is_available()`` at
    # module load. Without this check, the kernel call dies with an opaque
    # ``NameError: name 'torch_scatter' is not defined``.
    if not torch.cuda.is_available():
        raise RuntimeError(
            'segment_add_coo requires the CUDA extension '
            'torch_scatter.segment_cuda, which is only loaded when CUDA is '
            'available')

    if dim_size is None:
        # ``index.max()`` raises on an empty tensor. An empty index addresses
        # no segments, so the inferred size is 0.
        dim_size = index.max().item() + 1 if index.numel() > 0 else 0

    size = list(src.size())
    size[index.dim() - 1] = dim_size
    out = src.new_zeros(size)

    torch_scatter.segment_cuda.segment_add_coo(src, index, out)
    return out