"docs/EN/source/vscode:/vscode.git/clone" did not exist on "67eeafcf71d8abe7997229fc784288dc6c3802f5"
grid.py 1.66 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
from .utils import get_func, consecutive
rusty1s's avatar
rusty1s committed
4
5
6


def grid_cluster(position, size, batch=None):
    """Assign every point to a cell of a regular grid with cell size ``size``.

    Args:
        position (Tensor): Point coordinates. A 1-D tensor is treated as a
            set of one-dimensional points (a trailing coordinate axis of size
            1 is added). Otherwise the last axis holds the coordinates;
            presumably shape ``(N, D)`` -- TODO confirm how higher ranks are
            meant to be used.
        size (Tensor): 1-D tensor of cell sizes, one entry per coordinate
            (its length must match ``position.size(-1)``).
        batch (Tensor, optional): Per-point batch value, with the same shape
            as ``position`` minus its last axis. When given, it is prepended
            as an extra coordinate with cell size 1, so points whose batch
            values differ by at least 1 never fall into the same cell.

    Returns:
        Tensor: Cluster id per point, as relabeled by ``consecutive``.
        If ``batch`` is given, returns ``(cluster, batch)``. The second
        tensor is derived from the ids ``u`` returned by ``consecutive``
        (presumably one batch value per cluster -- confirm against
        ``consecutive``).
    """
    # Allow one-dimensional positions.
    if position.dim() == 1:
        position = position.unsqueeze(-1)

    assert size.dim() == 1, 'Size tensor must be one-dimensional'
    assert position.size(-1) == size.size(-1), (
        'Last dimension of position tensor must have same size as size tensor')

    # If given, append batch to position tensor.
    if batch is not None:
        batch = batch.unsqueeze(-1).type_as(position)
        assert position.size()[:-1] == batch.size()[:-1], (
            'Position tensor must have same size as batch tensor apart from '
            'the last dimension')
        # Batch becomes the leading coordinate; cell size 1 keeps distinct
        # integer batch values in distinct cells.
        position = torch.cat([batch, position], dim=-1)
        size = torch.cat([size.new(1).fill_(1), size], dim=-1)

    # Translate to minimal positive positions.
    pos_min = position.min(dim=-2, keepdim=True)[0]
    position = position - pos_min

    # Compute cluster count for each dimension: reduce the shifted positions
    # down to a single per-coordinate maximum.
    pos_max = position.max(dim=0)[0]
    while pos_max.dim() > 1:
        pos_max = pos_max.max(dim=0)[0]
    # Number of cells needed along each coordinate (at least one); computed
    # in double precision to avoid float32 rounding at cell boundaries.
    c_max = torch.floor(pos_max.double() / size.double() + 1).long()
    c_max = torch.clamp(c_max, min=1)
    C = c_max.prod()

    # Generate cluster tensor: same shape as position, with the coordinate
    # axis collapsed to 1. It is left uninitialized because the kernel below
    # is expected to fill it.
    s = list(position.size())
    s[-1] = 1
    cluster = c_max.new(torch.Size(s))

    # Fill cluster tensor and reshape.
    size = size.type_as(position)
    # NOTE(review): `func` is a backend kernel selected by `get_func`;
    # presumably it writes a flattened cell index per point into `cluster`.
    func = get_func('grid', position)
    func(C, cluster, position, size, c_max)
    cluster = cluster.squeeze(dim=-1)
    # Presumably relabels cell ids to 0..K-1 and returns the original
    # (unique) cell ids in `u` -- confirm against `consecutive`.
    cluster, u = consecutive(cluster)

    if batch is None:
        return cluster
    else:
        # Recover the leading (batch) coordinate from each cell id. Integer
        # floor division is exact; true division (`/`) on int64 tensors
        # yields float32 in modern PyTorch and loses precision for large ids.
        batch = (u // c_max[1:].prod()).long()
        return cluster, batch