grid.py 1.9 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
from .utils import get_func, consecutive
rusty1s's avatar
rusty1s committed
4
5


rusty1s's avatar
rename  
rusty1s committed
6
def grid_cluster(position, size, batch=None, origin=None, fake_nodes=False):
    """Assign every point to a cell of a regular grid.

    Args:
        position (Tensor): Point coordinates with the coordinate dimension
            last (e.g. ``[N, D]``). A one-dimensional tensor is treated as
            ``[N, 1]``.
        size (Tensor): One-dimensional tensor holding the cell edge length
            for each of the ``D`` coordinate dimensions.
        batch (Tensor, optional): Batch index per point; must have the same
            shape as ``position`` without its last dimension. When given it
            is prepended as an extra grid dimension with cell size 1.
        origin (Tensor or number, optional): Offset added to the coordinates
            (not to the batch index). If omitted, positions are shifted so
            that their minimum along dim ``-2`` becomes zero.
        fake_nodes (bool, optional): If ``True``, return the raw cell indices
            produced by the ``grid`` kernel instead of consecutive ones.

    Returns:
        If ``fake_nodes`` is ``True``: ``(cluster, C // c_max[0])`` where
        ``cluster`` holds the raw cell index per point and the second value
        is a 0-dim tensor (``C`` = total cell count, ``c_max[0]`` = cell
        count of the first grid dimension).
        Otherwise: ``(cluster, cluster_batch)`` where ``cluster`` holds
        consecutive cluster indices and ``cluster_batch`` is ``None`` without
        ``batch``. With ``batch``, it is presumably the batch index of each
        cluster, derived from the unique ids returned by ``consecutive``;
        it is relative to the minimum batch value when no ``origin`` is
        given.

    Raises:
        AssertionError: On inconsistent input shapes, or when ``origin``
            produces negative coordinates.
    """
    # Allow one-dimensional positions.
    if position.dim() == 1:
        position = position.unsqueeze(-1)

    assert size.dim() == 1, 'Size tensor must be one-dimensional'
    assert position.size(-1) == size.size(-1), (
        'Last dimension of position tensor must have same size as size tensor')

    # Apply a user-supplied origin *before* the batch column is appended:
    # ``origin`` (like ``size``) describes the coordinate dimensions only, so
    # adding it afterwards would misalign it or shift the batch indices.
    if origin is not None:
        position = position + origin
        assert position.min() >= 0, (
            'Passed origin resulting in unallowed negative positions')

    # If given, append batch to position tensor.
    if batch is not None:
        batch = batch.unsqueeze(-1).type_as(position)
        assert position.size()[:-1] == batch.size()[:-1], (
            'Position tensor must have same size as batch tensor apart from '
            'the last dimension')
        position = torch.cat([batch, position], dim=-1)
        # The batch dimension uses cells of size 1, i.e. one cell per index.
        size = torch.cat([size.new(1).fill_(1), size], dim=-1)

    # Translate to minimal positive positions if no origin was passed.
    if origin is None:
        pos_min = position.min(dim=-2, keepdim=True)[0]
        position = position - pos_min

    # Compute cluster count for each dimension, maximising over all leading
    # dimensions of ``position``; every dimension gets at least one cell.
    pos_max = position.reshape(-1, position.size(-1)).max(dim=0)[0]
    c_max = torch.floor(pos_max.double() / size.double() + 1).long()
    c_max = torch.clamp(c_max, min=1)
    C = c_max.prod()

    # Allocate the (uninitialised) output with a trailing singleton dimension;
    # the ``grid`` kernel writes one cell index per point into it.
    s = list(position.size())
    s[-1] = 1
    cluster = c_max.new(torch.Size(s))

    # Fill cluster tensor and reshape.
    size = size.type_as(position)
    func = get_func('grid', position)
    func(C, cluster, position, size, c_max)
    cluster = cluster.squeeze(dim=-1)

    if fake_nodes:
        return cluster, C // c_max[0]

    cluster, u = consecutive(cluster)
    if batch is None:
        return cluster, None
    # Integer floor division: ``/`` on integer tensors is true division in
    # current PyTorch and would round-trip through floating point, silently
    # losing precision for large cell indices.
    return cluster, (u // (C // c_max[0])).long()