grid.py 3.03 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from __future__ import division

rusty1s's avatar
rusty1s committed
3
4
import torch

rusty1s's avatar
rusty1s committed
5
from .utils import get_func, consecutive
rusty1s's avatar
rusty1s committed
6
7


rusty1s's avatar
rusty1s committed
8
9
10
def _preprocess(position, size, batch=None, start=None):
    size = size.type_as(position)

rusty1s's avatar
rusty1s committed
11
12
13
    # Allow one-dimensional positions.
    if position.dim() == 1:
        position = position.unsqueeze(-1)
rusty1s's avatar
rusty1s committed
14

rusty1s's avatar
rusty1s committed
15
16
17
    assert size.dim() == 1, 'Size tensor must be one-dimensional'
    assert position.size(-1) == size.size(-1), (
        'Last dimension of position tensor must have same size as size tensor')
rusty1s's avatar
rusty1s committed
18

rusty1s's avatar
rusty1s committed
19
20
21
22
23
24
25
26
    # Translate to minimal positive positions if no start was passed.
    if start is None:
        position = position - position.min(dim=-2, keepdim=True)[0]
    else:
        position = position - start
        assert position.min() >= 0, (
            'Passed origin resulting in unallowed negative positions')

rusty1s's avatar
rusty1s committed
27
28
29
30
31
32
33
34
    # If given, append batch to position tensor.
    if batch is not None:
        batch = batch.unsqueeze(-1).type_as(position)
        assert position.size()[:-1] == batch.size()[:-1], (
            'Position tensor must have same size as batch tensor apart from '
            'the last dimension')
        position = torch.cat([batch, position], dim=-1)
        size = torch.cat([size.new(1).fill_(1), size], dim=-1)
rusty1s's avatar
rusty1s committed
35

rusty1s's avatar
rusty1s committed
36
37
    return position, size

rusty1s's avatar
rusty1s committed
38

rusty1s's avatar
rusty1s committed
39
def _minimal_cluster_size(position, size):
40
41
42
    max = position.max(dim=0)[0]
    while max.dim() > 1:
        max = max.max(dim=0)[0]
rusty1s's avatar
rusty1s committed
43
44
    cluster_size = (max / size).long() + 1
    return cluster_size
rusty1s's avatar
rusty1s committed
45
46


rusty1s's avatar
rusty1s committed
47
48
49
50
def _fixed_cluster_size(position, size, batch=None, end=None):
    if end is None:
        return _minimal_cluster_size(position, size)

rusty1s's avatar
rusty1s committed
51
    eps = 0.000001  # Simulate [start, end) interval.
rusty1s's avatar
rusty1s committed
52
53
54
55
    if batch is None:
        cluster_size = ((end / size).float() - eps).long() + 1
    else:
        cluster_size = ((end / size[1:]).float() - eps).long() + 1
rusty1s's avatar
rusty1s committed
56
57
        max_batch = cluster_size.new(1).fill_(batch.max() + 1)
        cluster_size = torch.cat([max_batch, cluster_size], dim=0)
rusty1s's avatar
rusty1s committed
58
59
60
61
62
63
64
65
66

    return cluster_size


def _grid_cluster(position, size, cluster_size):
    """Assign every position a flat grid-cell id via the native kernel.

    Returns the per-position cluster ids and the total cell count ``C``
    (the product of the per-dimension cluster sizes).
    """
    C = cluster_size.prod()

    # Output buffer: one trailing-singleton entry per position row.
    cluster = cluster_size.new(position.size()[:-1]).unsqueeze(dim=-1)

    # Dispatch to the implementation matching ``position`` (CPU/CUDA).
    func = get_func('grid', position)
    func(C, cluster, position, size, cluster_size)

    return cluster.squeeze(dim=-1), C
rusty1s's avatar
rusty1s committed
72

rusty1s's avatar
rusty1s committed
73

rusty1s's avatar
rusty1s committed
74
75
76
77
def sparse_grid_cluster(position, size, batch=None, start=None):
    """Cluster positions into the minimal grid and relabel occupied cells.

    Positions are normalized by :func:`_preprocess`, binned into the
    smallest covering grid, and the resulting cell ids are remapped to
    consecutive integers via ``consecutive``.

    Returns ``cluster`` when ``batch`` is ``None``, else ``(cluster,
    batch)`` where ``batch`` gives the example index of each occupied cell.
    """
    position, size = _preprocess(position, size, batch, start)
    cluster_size = _minimal_cluster_size(position, size)
    cluster, C = _grid_cluster(position, size, cluster_size)

    # ``u`` holds the original flat ids of the occupied cells.
    cluster, u = consecutive(cluster)

    if batch is None:
        return cluster
    else:
        # Recover each occupied cell's batch index from its flat id.
        # Use floor division (`//`): `u` is an integer tensor, and true
        # division would yield floats under modern PyTorch semantics.
        # This also matches `dense_grid_cluster`, which already uses `//`.
        batch = u // (C // cluster_size[0])
        return cluster, batch


def dense_grid_cluster(position, size, batch=None, start=None, end=None):
    """Cluster positions into a fixed-extent grid without relabeling.

    Positions are normalized by :func:`_preprocess` and binned into a grid
    whose per-dimension cell counts come from :func:`_fixed_cluster_size`.

    Returns ``(cluster, C)`` where ``C`` is the total cell count — per
    example when ``batch`` is given (the batch dimension's factor is
    divided out), otherwise for the whole grid.
    """
    position, size = _preprocess(position, size, batch, start)
    cluster_size = _fixed_cluster_size(position, size, batch, end)
    cluster, C = _grid_cluster(position, size, cluster_size)

    if batch is not None:
        # Drop the batch dimension's factor to get the per-example count.
        C = C // cluster_size[0]
    return cluster, C