# unique.py (last committed by rusty1s)
import torch
import numpy as np

# The compiled CUDA extension is only importable on CUDA-enabled installs,
# so import it lazily behind this guard; ``unique`` only touches it for
# tensors that live on a CUDA device.
if torch.cuda.is_available():
    import torch_sparse.unique_cuda


def unique(src):
    """Return the unique values of ``src`` and an index for each of them.

    ``src`` is made contiguous and flattened to 1-D first, so the returned
    indices refer to positions in the flattened tensor.

    Args:
        src (Tensor): Input tensor of any shape.

    Returns:
        (Tensor, Tensor): ``(out, perm)``.
            On CPU these come from ``numpy.unique(..., return_index=True)``:
            ``out`` holds the unique values in ascending order and
            ``perm[i]`` is the index of the first occurrence of ``out[i]``.
            On CUDA they come from the ``torch_sparse.unique_cuda``
            extension (presumably the same contract -- TODO confirm).
    """
    src = src.contiguous().view(-1)

    if src.is_cuda:
        out, perm = torch_sparse.unique_cuda.unique(src)
    else:
        # ``Tensor.numpy()`` refuses tensors that require grad. Computing
        # unique values is not differentiable, so detach before converting.
        out, perm = np.unique(src.detach().numpy(), return_index=True)
        out, perm = torch.from_numpy(out), torch.from_numpy(perm)

    return out, perm