unique.py 366 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import numpy as np

if torch.cuda.is_available():
    import unique_cuda


def unique(src):
    """Return the sorted unique values of ``src`` and their first indices.

    ``src`` is flattened to 1-D first, so indices refer to positions in the
    flattened (row-major, contiguous) tensor.

    Args:
        src: Input tensor of any shape.

    Returns:
        A tuple ``(out, perm)``:
        * ``out``: the sorted unique values.
        * ``perm``: for each entry of ``out``, the index in the flattened
          ``src`` where it occurs. On the CPU path this is the first
          occurrence, as given by ``np.unique(..., return_index=True)``.

    The CUDA path delegates to the compiled ``unique_cuda`` extension.
    NOTE(review): presumably it returns the same ``(values, first_index)``
    pair as the CPU path; confirm against the extension source.
    """
    src = src.contiguous().view(-1)

    if src.is_cuda:
        out, perm = unique_cuda.unique(src)
    else:
        # ``detach`` lets tensors that require grad be converted to NumPy.
        # Unique is non-differentiable, so nothing is lost by detaching.
        values, first_idx = np.unique(src.detach().numpy(), return_index=True)
        out, perm = torch.from_numpy(values), torch.from_numpy(first_idx)

    return out, perm