grid.py 3.21 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from __future__ import division

rusty1s's avatar
rusty1s committed
3
4
import torch

rusty1s's avatar
rename  
rusty1s committed
5
from .utils.ffi import get_typed_func
rusty1s's avatar
rusty1s committed
6
from .utils.consecutive import consecutive
rusty1s's avatar
rusty1s committed
7
8


rusty1s's avatar
rusty1s committed
9
10
11
def _preprocess(position, size, batch=None, start=None):
    size = size.type_as(position)

rusty1s's avatar
rusty1s committed
12
13
14
    # Allow one-dimensional positions.
    if position.dim() == 1:
        position = position.unsqueeze(-1)
rusty1s's avatar
rusty1s committed
15

rusty1s's avatar
rusty1s committed
16
17
18
    assert size.dim() == 1, 'Size tensor must be one-dimensional'
    assert position.size(-1) == size.size(-1), (
        'Last dimension of position tensor must have same size as size tensor')
rusty1s's avatar
rusty1s committed
19

rusty1s's avatar
rusty1s committed
20
21
    # Translate to minimal positive positions if no start was passed.
    if start is None:
rusty1s's avatar
rusty1s committed
22
23
24
        min = []
        for i in range(position.size(-1)):
            min.append(position[:, i].min())
rusty1s's avatar
rusty1s committed
25
        start = position.new(min)
rusty1s's avatar
rusty1s committed
26
        position = position - position.new(min)
rusty1s's avatar
rusty1s committed
27
    else:
rusty1s's avatar
rusty1s committed
28
29
        assert start.numel() == size.numel(), (
            'Start tensor must have same size as size tensor')
rusty1s's avatar
rusty1s committed
30
        position = position - start.type_as(position)
rusty1s's avatar
rusty1s committed
31

rusty1s's avatar
rusty1s committed
32
33
34
35
36
37
38
39
    # If given, append batch to position tensor.
    if batch is not None:
        batch = batch.unsqueeze(-1).type_as(position)
        assert position.size()[:-1] == batch.size()[:-1], (
            'Position tensor must have same size as batch tensor apart from '
            'the last dimension')
        position = torch.cat([batch, position], dim=-1)
        size = torch.cat([size.new(1).fill_(1), size], dim=-1)
rusty1s's avatar
rusty1s committed
40

rusty1s's avatar
rusty1s committed
41
    return position, size, start
rusty1s's avatar
rusty1s committed
42

rusty1s's avatar
rusty1s committed
43

rusty1s's avatar
rusty1s committed
44
def _minimal_cluster_size(position, size):
rusty1s's avatar
rusty1s committed
45
46
47
48
    max = []
    for i in range(position.size(-1)):
        max.append(position[:, i].max())
    cluster_size = (size.new(max) / size).long() + 1
rusty1s's avatar
rusty1s committed
49
    return cluster_size
rusty1s's avatar
rusty1s committed
50
51


rusty1s's avatar
rusty1s committed
52
def _fixed_cluster_size(position, size, start, batch=None, end=None):
rusty1s's avatar
rusty1s committed
53
54
55
    if end is None:
        return _minimal_cluster_size(position, size)

rusty1s's avatar
rusty1s committed
56
    end = end.type_as(size) - start.type_as(size)
rusty1s's avatar
rusty1s committed
57
    eps = 0.000001  # Simulate [start, end) interval.
rusty1s's avatar
rusty1s committed
58
59
60
61
    if batch is None:
        cluster_size = ((end / size).float() - eps).long() + 1
    else:
        cluster_size = ((end / size[1:]).float() - eps).long() + 1
rusty1s's avatar
rusty1s committed
62
63
        max_batch = cluster_size.new(1).fill_(batch.max() + 1)
        cluster_size = torch.cat([max_batch, cluster_size], dim=0)
rusty1s's avatar
rusty1s committed
64
65
66
67
68
69
70
71
72

    return cluster_size


def _grid_cluster(position, size, cluster_size):
    """Run the typed ``'grid'`` kernel and return ``(cluster, C)``.

    Args:
        position: Translated coordinates, with coordinates in the last
            dimension.
        size: 1-D tensor of cell widths.
        cluster_size: Per-coordinate cell counts.

    Returns:
        ``cluster``, shaped like ``position`` without its last dimension, and
        ``C = cluster_size.prod()``, the total number of grid cells.
    """
    num_cells = cluster_size.prod()

    # `.new(shape)` allocates without initializing. NOTE(review): this
    # presumably relies on the 'grid' kernel writing every entry — confirm.
    out_shape = torch.Size(list(position.size())[:-1])
    cluster = cluster_size.new(out_shape).unsqueeze(dim=-1)

    grid_kernel = get_typed_func('grid', position)
    grid_kernel(num_cells, cluster, position, size, cluster_size)

    return cluster.squeeze(dim=-1), num_cells
rusty1s's avatar
rusty1s committed
78

rusty1s's avatar
rusty1s committed
79

rusty1s's avatar
rusty1s committed
80
def sparse_grid_cluster(position, size, batch=None, start=None):
    """Cluster points into the smallest grid covering them.

    Cell indices are relabelled with ``consecutive``.

    Args:
        position: Point coordinates (see ``_preprocess``).
        size: 1-D tensor of cell widths.
        batch: Optional per-point batch index.
        start: Optional per-coordinate grid origin.

    Returns:
        ``cluster`` when ``batch`` is None, otherwise ``(cluster, batch)``.
    """
    position, size, start = _preprocess(position, size, batch, start)
    cluster_size = _minimal_cluster_size(position, size)
    cluster, _ = _grid_cluster(position, size, cluster_size)
    cluster = consecutive(cluster)

    if batch is not None:
        # NOTE(review): `batch` is returned exactly as passed in (per point).
        # An earlier commented-out draft derived it from the cluster ids via
        # `C // cluster_size[0]`; confirm which form callers expect.
        return cluster, batch
    return cluster


def dense_grid_cluster(position, size, batch=None, start=None, end=None):
    """Cluster points into a grid spanning ``[start, end)`` when ``end`` is given.

    Args:
        position: Point coordinates (see ``_preprocess``).
        size: 1-D tensor of cell widths.
        batch: Optional per-point batch index.
        start: Optional per-coordinate grid origin.
        end: Optional per-coordinate grid end. If None, the grid just
            covers the points.

    Returns:
        ``(cluster, C)`` as returned by ``_grid_cluster``, where ``C`` is the
        total number of grid cells.
    """
    position, size, start = _preprocess(position, size, batch, start)
    return _grid_cluster(
        position,
        size,
        _fixed_cluster_size(position, size, start, batch, end),
    )