import torch
from torch_geometric.utils import to_undirected

from sample_cuda import farthest_point_sampling, query_radius, query_knn


def batch_slices(batch, sizes=False, include_ends=True):
    """
    Calculates the size, start index and (inclusive) end index of each element
    in a sorted batch vector. With include_ends=True the start and end indices
    are returned interleaved in a single flat tensor; with sizes=True the
    per-element sizes are returned as well.
    """
    # Number of points per batch element.
    size = torch.zeros(batch.max().item() + 1, dtype=batch.dtype,
                       device=batch.device)
    size.scatter_add_(0, batch, torch.ones_like(batch))
    cumsum = torch.cumsum(size, dim=0)
    starts = cumsum - size
    ends = cumsum - 1

    slices = starts
    if include_ends:
        slices = torch.stack([starts, ends], dim=1).view(-1)

    if sizes:
        return slices, size
    return slices
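

# The _example_* helpers below are illustrative sketches, not part of the
# original module. batch_slices itself runs on CPU or CUDA tensors, so this
# first sketch needs no GPU; the later ones assume a CUDA device and the
# compiled sample_cuda extension.
def _example_batch_slices():
    batch = torch.tensor([0, 0, 0, 1, 1])
    slices, sizes = batch_slices(batch, sizes=True)
    # sizes  -> tensor([3, 2])        (points per batch element)
    # slices -> tensor([0, 2, 3, 4])  (interleaved start/end indices)
    return slices, sizes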


def sample_farthest(batch, pos, num_sampled, random_start=False, index=False):
    """Samples a specified number of points for each element in a batch using
    iterative farthest point sampling and returns a mask (or indices) for the
    sampled points. If a point cloud contains fewer than num_sampled points,
    all of its points are returned.
    """
    if not pos.is_cuda or not batch.is_cuda:
        raise NotImplementedError

    assert pos.is_contiguous() and batch.is_contiguous()

    slices, sizes = batch_slices(batch, sizes=True)
    batch_size = batch.max().item() + 1

    if random_start:
        random = torch.rand(batch_size, device=slices.device)
        start_points = (sizes.float() * random).long()
    else:
        start_points = torch.zeros_like(sizes)

    idx = farthest_point_sampling(batch_size, slices, pos, num_sampled,
                                  start_points)
    # Remove invalid indices
    idx = idx[idx != -1]

    if index:
        return idx
    mask = torch.zeros(pos.size(0), dtype=torch.uint8, device=pos.device)
    mask[idx] = 1
    return mask
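

# Illustrative sketch (assumes a CUDA device and the compiled sample_cuda
# extension): samples up to 4 points per cloud. The second cloud holds only
# 3 points, so, per the docstring, all of them end up in the mask.
def _example_sample_farthest():
    pos = torch.rand(8, 3, device='cuda')
    batch = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1], device='cuda')
    mask = sample_farthest(batch, pos, num_sampled=4)
    return mask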


def radius_query_edges(batch,
                       pos,
                       query_batch,
                       query_pos,
                       radius,
                       max_num_neighbors=128,
                       include_self=True,
                       undirected=False):
    if not pos.is_cuda:
        raise NotImplementedError
    assert pos.is_cuda and batch.is_cuda
    assert query_pos.is_cuda and query_batch.is_cuda
    assert pos.is_contiguous() and batch.is_contiguous()
    assert query_pos.is_contiguous() and query_batch.is_contiguous()

    slices, sizes = batch_slices(batch, sizes=True)
    batch_size = batch.max().item() + 1
    query_slices = batch_slices(query_batch)

    max_num_neighbors = min(max_num_neighbors, sizes.max().item())
    idx, cnt = query_radius(batch_size, slices, query_slices, pos, query_pos,
                            radius, max_num_neighbors, include_self)

    # Convert to edges
    view = idx.view(-1)
    row = torch.arange(query_pos.size(0), dtype=torch.long, device=pos.device)
    row = row.view(-1, 1).repeat(1, max_num_neighbors).view(-1)

    # Remove invalid indices
    row = row[view != -1]
    col = view[view != -1]
    if col.size(0) == 0:
        return col

    edge_index = torch.stack([row, col], dim=0)
    if undirected:
        return to_undirected(edge_index, query_pos.size(0))
    return edge_index
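

# Illustrative sketch (assumes a CUDA device and the compiled sample_cuda
# extension): connects farthest-point-sampled centroids to all points within
# radius 0.2. Row indices of edge_index refer to the query (centroid) set,
# column indices to the full point set.
def _example_radius_query_edges():
    pos = torch.rand(100, 3, device='cuda')
    batch = torch.zeros(100, dtype=torch.long, device='cuda')
    idx = sample_farthest(batch, pos, num_sampled=16, index=True)
    edge_index = radius_query_edges(batch, pos, batch[idx], pos[idx],
                                    radius=0.2, max_num_neighbors=32)
    return edge_index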


def radius_graph(batch,
                 pos,
                 radius,
                 max_num_neighbors=128,
                 include_self=False,
                 undirected=False):
    return radius_query_edges(batch, pos, batch, pos, radius,
                              max_num_neighbors, include_self, undirected)
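

# Illustrative sketch (assumes a CUDA device and the compiled sample_cuda
# extension): builds an undirected graph linking every point to its
# neighbors within radius 0.25.
def _example_radius_graph():
    pos = torch.rand(50, 3, device='cuda')
    batch = torch.zeros(50, dtype=torch.long, device='cuda')
    return radius_graph(batch, pos, radius=0.25, undirected=True)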


def knn_query_edges(batch,
                    pos,
                    query_batch,
                    query_pos,
                    num_neighbors,
                    include_self=True,
                    undirected=False):
    if not pos.is_cuda:
        raise NotImplementedError
    assert pos.is_cuda and batch.is_cuda
    assert query_pos.is_cuda and query_batch.is_cuda
    assert pos.is_contiguous() and batch.is_contiguous()
    assert query_pos.is_contiguous() and query_batch.is_contiguous()

    slices, sizes = batch_slices(batch, sizes=True)
    batch_size = batch.max().item() + 1
    query_slices = batch_slices(query_batch)

    assert (sizes < num_neighbors).sum().item() == 0, \
        'Each point cloud must contain at least num_neighbors points'

    idx, dists = query_knn(batch_size, slices, query_slices, pos, query_pos,
                           num_neighbors, include_self)

    # Convert to edges
    view = idx.view(-1)

    row = torch.arange(query_pos.size(0), dtype=torch.long, device=pos.device)
    row = row.view(-1, 1).repeat(1, num_neighbors).view(-1)

    # Remove invalid indices
    row = row[view != -1]
    col = view[view != -1]

    edge_index = torch.stack([row, col], dim=0)
    if undirected:
        return to_undirected(edge_index, query_pos.size(0))
    return edge_index
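

# Illustrative sketch (assumes a CUDA device and the compiled sample_cuda
# extension): mirrors the radius example above, but links each sampled
# centroid to a fixed number (8) of nearest points in the full cloud.
def _example_knn_query_edges():
    pos = torch.rand(64, 3, device='cuda')
    batch = torch.zeros(64, dtype=torch.long, device='cuda')
    idx = sample_farthest(batch, pos, num_sampled=16, index=True)
    return knn_query_edges(batch, pos, batch[idx], pos[idx], num_neighbors=8)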


def knn_graph(batch, pos, num_neighbors, include_self=False, undirected=False):
    return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self,
                           undirected)
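

# Illustrative sketch (assumes a CUDA device and the compiled sample_cuda
# extension): builds a 4-nearest-neighbor graph over two point clouds held
# in a single batch.
def _example_knn_graph():
    pos = torch.rand(12, 3, device='cuda')
    batch = torch.tensor([0] * 6 + [1] * 6, device='cuda')
    return knn_graph(batch, pos, num_neighbors=4)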