# grid.py
from __future__ import division

import torch

from .utils import get_dynamic_func, consecutive


def _preprocess(position, size, batch=None, start=None):
    """Normalize the inputs: cast `size`, make `position` two-dimensional,
    shift positions by `start` (or by their per-dimension minimum if `start`
    is None) and, if given, prepend `batch` as an extra leading coordinate
    with cell size one so that different examples never share a grid cell."""
    size = size.type_as(position)

    # Allow one-dimensional positions.
    if position.dim() == 1:
        position = position.unsqueeze(-1)

    assert size.dim() == 1, 'Size tensor must be one-dimensional'
    assert position.size(-1) == size.size(-1), (
        'Last dimension of position tensor must have same size as size tensor')

    # Translate positions to be non-negative if no start was passed.
    if start is None:
        mins = []
        for i in range(position.size(-1)):
            mins.append(position[:, i].min())
        start = position.new(mins)
        position = position - position.new(mins)
    else:
        assert start.numel() == size.numel(), (
            'Start tensor must have same size as size tensor')
        position = position - start.type_as(position)

    # If given, prepend batch as the leading coordinate of the position
    # tensor.
    if batch is not None:
        batch = batch.unsqueeze(-1).type_as(position)
        assert position.size()[:-1] == batch.size()[:-1], (
            'Position tensor must have same size as batch tensor apart from '
            'the last dimension')
        position = torch.cat([batch, position], dim=-1)
        size = torch.cat([size.new(1).fill_(1), size], dim=-1)

    return position, size, start

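# Worked example (not part of the original module) of what `_preprocess`
# returns for two 1-D positions and a batch vector, with start=None:
#
#   position = [[1.5], [2.5]], size = [1.0], batch = [0, 1]
#   -> position = [[0.0, 0.0], [1.0, 1.0]]  (shifted by the minimum, batch
#                                            prepended as leading coordinate)
#      size     = [1.0, 1.0]                (leading 1 for the batch axis)
#      start    = [1.5]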

def _minimal_cluster_size(position, size):
    # Number of cells per dimension needed to cover the maximum position.
    maxs = []
    for i in range(position.size(-1)):
        maxs.append(position[:, i].max())
    cluster_size = (size.new(maxs) / size).long() + 1
    return cluster_size


def _fixed_cluster_size(position, size, start, batch=None, end=None):
    """Number of grid cells per dimension for a fixed [start, end) extent,
    falling back to the minimal covering grid if `end` is None."""
    if end is None:
        return _minimal_cluster_size(position, size)

    end = end.type_as(size) - start.type_as(size)
    # Simulate a half-open [start, end) interval: without the epsilon, an end
    # value that falls exactly on a cell boundary (e.g. end - start == 2 with
    # cell size 1) would add a superfluous extra cell.
    eps = 0.000001
    if batch is None:
        cluster_size = ((end / size).float() - eps).long() + 1
    else:
        cluster_size = ((end / size[1:]).float() - eps).long() + 1
        # Prepend the number of examples as the cell count of the batch axis.
        max_batch = cluster_size.new(1).fill_(batch.max() + 1)
        cluster_size = torch.cat([max_batch, cluster_size], dim=0)

    return cluster_size

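# Worked example (not part of the original module): with a batch whose highest
# index is 1, end - start = [3.8, 3.9] and cell size [1, 1], the fixed grid
# gets cluster_size = [2, 4, 4] (number of examples first), i.e. C = 32 cells.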

def _grid_cluster(position, size, cluster_size):
    # Total number of grid cells.
    C = cluster_size.prod()

    # Output tensor holding the (flattened) cell index of each position.
    cluster = cluster_size.new(torch.Size(list(position.size())[:-1]))
    cluster = cluster.unsqueeze(dim=-1)

    # Fetch the type-specific low-level `grid` kernel and let it fill
    # `cluster` in place.
    func = get_dynamic_func('grid', position)
    func(C, cluster, position, size, cluster_size)

    cluster = cluster.squeeze(dim=-1)
    return cluster, C

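# The following reference implementation is NOT part of the original module.
# It is a pure-PyTorch sketch of what the dispatched `grid` kernel presumably
# computes, inferred from how `sparse_grid_cluster` below recovers the batch
# index (`u / (C // cluster_size[0])`), which implies that the first
# coordinate carries the largest stride (row-major flattening).
def _grid_cluster_reference(position, size, cluster_size):
    coord = (position / size).long()  # per-dimension cell coordinate
    cluster = coord[:, 0]
    for d in range(1, cluster_size.size(0)):
        # Horner-style flattening: idx = ((c0 * S1 + c1) * S2 + c2) * ...
        cluster = cluster * cluster_size[d] + coord[:, d]
    return cluster, cluster_size.prod()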

def sparse_grid_cluster(position, size, batch=None, start=None):
    """Cluster positions into a sparse voxel grid: cluster indices are
    remapped to consecutive integers, so only occupied cells are counted.
    If `batch` is given, a per-cluster batch vector is returned as well."""
    position, size, start = _preprocess(position, size, batch, start)
    cluster_size = _minimal_cluster_size(position, size)
    cluster, C = _grid_cluster(position, size, cluster_size)
    cluster, u = consecutive(cluster)

    if batch is None:
        return cluster
    else:
        # The batch index of each cluster is its leading grid coordinate,
        # i.e. the original cluster id in `u` divided by the number of grid
        # cells per example.
        batch = u / (C // cluster_size[0])
        return cluster, batch


def dense_grid_cluster(position, size, batch=None, start=None, end=None):
    """Cluster positions into a fixed (dense) voxel grid spanning
    [start, end) (or the minimal covering grid if `end` is None) and return
    the raw cluster index of each position together with the total number
    of grid cells C."""
    position, size, start = _preprocess(position, size, batch, start)
    cluster_size = _fixed_cluster_size(position, size, start, batch, end)
    cluster, C = _grid_cluster(position, size, cluster_size)
    return cluster, C
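

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module). It assumes the
    # compiled `grid` kernel behind `get_dynamic_func` is available and that
    # the module is executed inside its package (relative imports above).
    pos = torch.FloatTensor([[0.2, 0.4], [1.1, 2.3], [0.3, 0.5], [3.0, 0.1]])
    cell = torch.FloatTensor([1, 1])

    # Sparse variant: per-point cluster ids, remapped to consecutive values.
    print(sparse_grid_cluster(pos, cell))

    # Dense variant with a batch vector and a fixed extent: returns the raw
    # per-point cluster ids together with the total cell count C.
    batch = torch.LongTensor([0, 0, 1, 1])
    end = torch.FloatTensor([4, 4])
    print(dense_grid_cluster(pos, cell, batch=batch, end=end))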