import torch

from torch_scatter.utils.gen import gen
from torch_scatter.add import scatter_add

if torch.cuda.is_available():
    import torch_scatter.segment_cuda


def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    """Add-reduce ``src`` into an output tensor according to ``index``.

    All arguments are first normalised by ``gen``, which returns updated
    ``src``, ``out``, ``index`` and ``dim``. CPU tensors are then delegated to
    ``scatter_add``. CUDA tensors are handed to the compiled
    ``segment_add_thrust`` kernel, which writes into ``out``.

    Args:
        src: Source tensor holding the values to reduce.
        index: Index tensor mapping entries of ``src`` to output positions.
        dim: Dimension to reduce along. Defaults to ``-1``.
        out: Optional pre-allocated output tensor.
        dim_size: Optional output size along ``dim``; forwarded to ``gen``
            and ``scatter_add``.
        fill_value: Forwarded to ``gen`` and ``scatter_add``. Presumably the
            initial value of a freshly created ``out``; confirm in ``gen``.

    Returns:
        On CPU, the result of ``scatter_add``. On CUDA, ``out``. If ``src``
        is empty along ``dim``, the ``out`` returned by ``gen``.
    """
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
    if src.size(dim) == 0:  # pragma: no cover
        # Nothing to reduce; hand back the output exactly as ``gen`` built it.
        return out

    if not src.is_cuda:
        return scatter_add(src, index, dim, out, dim_size, fill_value)

    # NOTE(review): ``dim`` is not passed to the CUDA kernel. Presumably the
    # kernel always reduces over the last dimension. Confirm this before
    # relying on ``dim`` other than -1 for CUDA tensors.
    torch_scatter.segment_cuda.segment_add_thrust(src, index, out)
    return out


def segment_add2(src, indptr):
    """Forward ``src`` and ``indptr`` to the ``segment_add_csr`` extension op.

    This is a thin wrapper: whatever
    ``torch_scatter.segment_cuda.segment_add_csr`` returns is returned
    unchanged.

    NOTE(review): ``torch_scatter.segment_cuda`` is only imported at module
    level when ``torch.cuda.is_available()``. Presumably ``indptr`` holds
    CSR-style segment offsets; confirm against the kernel.
    """
    csr_kernel = torch_scatter.segment_cuda.segment_add_csr
    return csr_kernel(src, indptr)