consecutive.py 1.1 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
import torch
from torch_unique import unique


rusty1s's avatar
rusty1s committed
5
6
def _get_type(max_value, cuda):
    """Return the narrowest integer tensor constructor that can hold a value.

    Args:
        max_value: The largest value the resulting tensor must be able to
            store.
        cuda: If true, return the ``torch.cuda`` variant of the type,
            otherwise the CPU one.

    Returns:
        One of ``ByteTensor`` (<= 255), ``ShortTensor`` (<= 32767),
        ``IntTensor`` (<= 2147483647) or ``LongTensor`` (anything larger),
        taken from ``torch.cuda`` or ``torch`` depending on ``cuda``.
    """
    backend = torch.cuda if cuda else torch

    # Inclusive upper bounds, ordered from the narrowest type to the widest.
    bounds = (
        (255, 'ByteTensor'),
        (32767, 'ShortTensor'),
        (2147483647, 'IntTensor'),
    )
    for limit, type_name in bounds:
        if max_value <= limit:
            return getattr(backend, type_name)

    # Nothing narrower fits: fall back to 64-bit integers.
    return backend.LongTensor  # pragma: no cover


rusty1s's avatar
rusty1s committed
16
def consecutive(x):
    """Relabel the values of ``x`` as consecutive integers ``0, 1, 2, ...``.

    Each distinct value of ``x`` is replaced by its position in the vector
    returned by ``unique(x.view(-1))``. If ``unique`` returns its values in
    ascending order (presumably; TODO confirm against ``torch_unique``), the
    relative order of the labels matches the order of the original values.

    Args:
        x: Integer tensor of non-negative values. The values are used
            directly as indices into a lookup table of size ``max(x) + 1``.

    Returns:
        A ``LongTensor`` with the same shape as ``x``. Empty input yields an
        empty ``LongTensor`` of the same shape.

    Raises:
        ValueError: If ``x`` contains negative values.
    """
    initial_size = x.size()

    # `view(-1)` requires contiguous memory, so make non-contiguous inputs
    # (e.g. transposed tensors) contiguous first.
    flat = x.contiguous().view(-1)

    # Nothing to relabel. This also avoids reducing an empty unique vector.
    if flat.numel() == 0:
        return x.long()

    # Compute unique vector.
    u = unique(flat)

    # Values are used as indices into `mask`, so negative values would index
    # out of range or wrap around and silently produce wrong labels.
    if int(u.min()) < 0:
        raise ValueError('consecutive() expects non-negative values')

    # Compute mask with mask[u[0]] = 0, mask[u[1]] = 1, ...
    # `mask` has max(x) + 1 entries and can get very big, so store it in the
    # narrowest integer type that can hold the labels 0 .. u.size(0) - 1
    # (`_get_type`) to keep its memory footprint small.
    # `int(...)` turns a possible 0-dim tensor into a Python int, so the
    # typed constructor below allocates `max_value` elements instead of
    # building a tensor from the scalar's data. Using `max()` instead of
    # `u[-1]` also stops relying on `unique` returning sorted output.
    max_value = int(u.max()) + 1
    mask = _get_type(u.size(0), x.is_cuda)(max_value)
    mask[u] = torch.arange(0, u.size(0), out=mask.new())

    # Select the values in `mask` based on `x` and reshape to initial size.
    out = mask[flat].view(initial_size)
    return out.long()