consecutive.py 862 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
import torch
from torch_unique import unique


rusty1s's avatar
tests  
rusty1s committed
5
def _get_type(max, cuda):
6
    if max <= 255:
rusty1s's avatar
rusty1s committed
7
        return torch.cuda.ByteTensor if cuda else torch.ByteTensor
8
    elif max <= 32767:  # pragma: no cover
rusty1s's avatar
rusty1s committed
9
        return torch.cuda.ShortTensor if cuda else torch.ShortTensor
10
    elif max <= 2147483647:  # pragma: no cover
rusty1s's avatar
rusty1s committed
11
        return torch.cuda.IntTensor if cuda else torch.IntTensor
rusty1s's avatar
rusty1s committed
12
    else:  # pragma: no cover
rusty1s's avatar
rusty1s committed
13
14
15
        return torch.cuda.LongTensor if cuda else torch.LongTensor


rusty1s's avatar
tests  
rusty1s committed
16
def consecutive(tensor, return_unique=False):
    """Relabel the values of a tensor to consecutive ids ``0..k-1``.

    Every element is replaced by the rank of its value among the sorted
    distinct values of ``tensor``: the smallest value maps to 0, the next
    distinct value to 1, and so on. The result has the same shape as
    ``tensor``, dtype ``torch.long``, and lives on the same device.

    Args:
        tensor: Tensor of any shape (typically integer labels).
        return_unique: If ``True``, also return the sorted distinct values.

    Returns:
        The relabelled tensor, or ``(relabelled, unique_values)`` when
        ``return_unique`` is set.
    """
    # ``return_inverse`` yields, for every element, its index into the
    # sorted unique values, which is exactly the consecutive id. Unlike a
    # lookup table of size ``max_value + 1``, this needs no memory
    # proportional to the largest value. It also handles negative values,
    # empty input and non-contiguous tensors.
    u, inverse = torch.unique(tensor, sorted=True, return_inverse=True)
    out = inverse.view(tensor.size()).long()
    return (out, u) if return_unique else out