sample.py 4.85 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
from torch_scatter import scatter_add, scatter_max
from torch_geometric.utils import to_undirected
from torch_geometric.data import Batch
from torch_sparse import coalesce

from sample_cuda import farthest_point_sampling, query_radius, query_knn


def batch_slices(batch, sizes=False, include_ends=True):
    """
    Calculates size, start and end indices for each element in a batch.

    Args:
        batch: 1-D tensor of non-negative integer batch ids, one per point. The start/end arithmetic only describes
            real ranges if points of the same element are stored next to each other in ascending id order
            (assumption, confirm with callers).
        sizes: If set, the per-element point counts are returned as well.
        include_ends: If set, the result interleaves start and (inclusive) end index per element as
            ``[start_0, end_0, start_1, end_1, ...]``; otherwise only the start indices are returned.

    Returns:
        The slice tensor, or a tuple ``(slices, size)`` if ``sizes`` is set. An element without any points gets an
        end index of ``start - 1``.
    """
    # Number of points per batch id. bincount yields max(batch) + 1 entries and an empty tensor for an empty batch,
    # so no helper tensor of ones has to be allocated and scattered.
    size = torch.bincount(batch)
    cumsum = torch.cumsum(size, dim=0)
    starts = cumsum - size
    # End indices are inclusive.
    ends = cumsum - 1

    slices = starts
    if include_ends:
        slices = torch.stack([starts, ends], dim=1).view(-1)

    if sizes:
        return slices, size
    return slices


def sample_farthest(batch, pos, num_sampled, random_start=False, index=False):
    """
    Runs iterative farthest point sampling separately on every element of a batch and reports which points were
    picked, either as a uint8 mask over all points (default) or as an index tensor (``index=True``).
    Point clouds with fewer than num_sampled points contribute all of their points.

    With ``random_start`` the first point of every element is drawn at random, otherwise point 0 is used.
    Only CUDA tensors are supported; anything else raises NotImplementedError.
    """
    if not (pos.is_cuda and batch.is_cuda):
        raise NotImplementedError

    assert pos.is_contiguous() and batch.is_contiguous()

    cloud_slices, cloud_sizes = batch_slices(batch, sizes=True)
    num_clouds = batch.max().item() + 1

    if not random_start:
        first_points = torch.zeros_like(cloud_sizes)
    else:
        # Scale a uniform sample in [0, 1) by the cloud size to pick an offset inside each cloud.
        fraction = torch.rand(num_clouds, device=cloud_slices.device)
        first_points = (cloud_sizes.float() * fraction).long()

    sampled = farthest_point_sampling(num_clouds, cloud_slices, pos,
                                      num_sampled, first_points)
    # Entries equal to -1 are placeholders for samples that do not exist; drop them.
    keep = sampled != -1
    sampled = sampled[keep]

    if not index:
        selected = torch.zeros(pos.size(0),
                               dtype=torch.uint8,
                               device=pos.device)
        selected[sampled] = 1
        return selected
    return sampled


def radius_query_edges(batch,
                       pos,
                       query_batch,
                       query_pos,
                       radius,
                       max_num_neighbors=128,
                       include_self=True,
                       undirected=False):
    """
    Runs the CUDA ``query_radius`` kernel for every query point and converts its result into an edge list.

    Args:
        batch: Batch assignment of each point in ``pos``.
        pos: Positions of the points that are searched.
        query_batch: Batch assignment of each point in ``query_pos``.
        query_pos: Positions of the query points.
        radius: Search radius, forwarded to the kernel.
        max_num_neighbors: Upper bound on the neighbors kept per query point. It is clamped to the size of the
            largest batch element.
        include_self: Forwarded to the kernel.
        undirected: If set, the edge list is symmetrized with ``to_undirected``.

    Returns:
        A ``[2, num_edges]`` tensor. The first row holds the index of the query point, the second row the neighbor
        index reported by the kernel (presumably an index into ``pos``, confirm against the kernel).
        If not a single neighbor is found, an empty 1-D tensor is returned instead.

    Raises:
        NotImplementedError: If ``pos`` is not a CUDA tensor.
        AssertionError: If one of the other tensors is not on CUDA or any tensor is not contiguous.
    """
    if not pos.is_cuda:
        raise NotImplementedError
    assert batch.is_cuda and query_pos.is_cuda and query_batch.is_cuda
    assert pos.is_contiguous() and batch.is_contiguous(
    ) and query_pos.is_contiguous() and query_batch.is_contiguous()

    slices, sizes = batch_slices(batch, sizes=True)
    batch_size = batch.max().item() + 1
    query_slices = batch_slices(query_batch)

    # A query point cannot have more neighbors than the largest batch element has points.
    max_num_neighbors = min(max_num_neighbors, sizes.max().item())
    idx, _ = query_radius(batch_size, slices, query_slices, pos, query_pos,
                          radius, max_num_neighbors, include_self)

    # Convert to edges: one candidate slot per (query point, neighbor) pair.
    view = idx.view(-1)
    row = torch.arange(query_pos.size(0), dtype=torch.long, device=pos.device)
    row = row.view(-1, 1).repeat(1, max_num_neighbors).view(-1)

    # Remove invalid indices (unused slots are marked with -1).
    valid = view != -1
    row = row[valid]
    col = view[valid]
    if col.size(0) == 0:
        # Kept for backward compatibility: callers receive an empty 1-D tensor here.
        return col

    edge_index = torch.stack([row, col], dim=0)
    if undirected:
        # The node count has to bound both rows. Using only the number of query points lets neighbor indices
        # exceed it whenever pos holds more points than query_pos.
        num_nodes = max(pos.size(0), query_pos.size(0))
        return to_undirected(edge_index, num_nodes)
    return edge_index


def radius_graph(batch,
                 pos,
                 radius,
                 max_num_neighbors=128,
                 include_self=False,
                 undirected=False):
    """
    Builds radius edges within a single point set by using ``pos`` both as the searched points and as the queries.
    Unlike ``radius_query_edges``, ``include_self`` defaults to False here.
    """
    return radius_query_edges(batch=batch,
                              pos=pos,
                              query_batch=batch,
                              query_pos=pos,
                              radius=radius,
                              max_num_neighbors=max_num_neighbors,
                              include_self=include_self,
                              undirected=undirected)


def knn_query_edges(batch,
                    pos,
                    query_batch,
                    query_pos,
                    num_neighbors,
                    include_self=True,
                    undirected=False):
    """
    Runs the CUDA ``query_knn`` kernel for every query point and converts its result into an edge list.

    Args:
        batch: Batch assignment of each point in ``pos``.
        pos: Positions of the points that are searched.
        query_batch: Batch assignment of each point in ``query_pos``.
        query_pos: Positions of the query points.
        num_neighbors: Number of neighbors requested per query point. Every batch element must contain at least
            this many points.
        include_self: Forwarded to the kernel.
        undirected: If set, the edge list is symmetrized with ``to_undirected``.

    Returns:
        A ``[2, num_edges]`` tensor. The first row holds the index of the query point, the second row the neighbor
        index reported by the kernel (presumably an index into ``pos``, confirm against the kernel).

    Raises:
        NotImplementedError: If ``pos`` is not a CUDA tensor.
        AssertionError: If one of the other tensors is not on CUDA, any tensor is not contiguous, or a batch
            element has fewer than ``num_neighbors`` points.
    """
    if not pos.is_cuda:
        raise NotImplementedError
    assert batch.is_cuda and query_pos.is_cuda and query_batch.is_cuda
    assert pos.is_contiguous() and batch.is_contiguous(
    ) and query_pos.is_contiguous() and query_batch.is_contiguous()

    slices, sizes = batch_slices(batch, sizes=True)
    batch_size = batch.max().item() + 1
    query_slices = batch_slices(query_batch)

    # No batch element may be smaller than the requested neighborhood.
    assert (sizes < num_neighbors).sum().item() == 0

    idx, _ = query_knn(batch_size, slices, query_slices, pos, query_pos,
                       num_neighbors, include_self)

    # Convert to edges: one candidate slot per (query point, neighbor) pair.
    view = idx.view(-1)

    row = torch.arange(query_pos.size(0), dtype=torch.long, device=pos.device)
    row = row.view(-1, 1).repeat(1, num_neighbors).view(-1)

    # Remove invalid indices (unused slots are marked with -1).
    valid = view != -1
    row = row[valid]
    col = view[valid]

    edge_index = torch.stack([row, col], dim=0)
    if undirected:
        # The node count has to bound both rows. Using only the number of query points lets neighbor indices
        # exceed it whenever pos holds more points than query_pos.
        num_nodes = max(pos.size(0), query_pos.size(0))
        return to_undirected(edge_index, num_nodes)
    return edge_index


def knn_graph(batch, pos, num_neighbors, include_self=False, undirected=False):
    """
    Builds k-nearest-neighbor edges within a single point set by using ``pos`` both as the searched points and as
    the queries. Unlike ``knn_query_edges``, ``include_self`` defaults to False here.
    """
    return knn_query_edges(batch=batch,
                           pos=pos,
                           query_batch=batch,
                           query_pos=pos,
                           num_neighbors=num_neighbors,
                           include_self=include_self,
                           undirected=undirected)