"""k-NN graph construction and the Transition Down block for point clouds.

Parts of the code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.geometry import farthest_point_sampler


def square_distance(src, dst):
    """Return pairwise squared Euclidean distances between two point sets.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Args:
        src: tensor of shape (B, N, C).
        dst: tensor of shape (B, M, C).

    Returns:
        Tensor of shape (B, N, M) whose entry [b, n, m] is
        ||src[b, n] - dst[b, m]||^2.
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    # Expansion ||s - d||^2 = -2<s, d> + ||s||^2 + ||d||^2. It is cheap but can
    # yield tiny negative values from float round-off.
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src**2, -1).view(B, N, 1)
    dist += torch.sum(dst**2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """Gather points per batch element.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Args:
        points: tensor of shape (B, N, C).
        idx: long tensor of shape (B, S) or (B, S, K) with indices into N.

    Returns:
        Tensor of shape ``idx.shape + (C,)`` where the result at ``[b, ...]``
        is ``points[b, idx[b, ...]]``.
    """
    B = points.shape[0]
    # Batch index broadcast to idx's shape, so advanced indexing pairs each
    # index with its own batch element.
    view_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_indices = (
        torch.arange(B, dtype=torch.long, device=points.device)
        .view(view_shape)
        .expand_as(idx)
    )
    return points[batch_indices, idx, :]


class KNearNeighbors(nn.Module):
    """Find the k nearest neighbors of selected centre points."""

    def __init__(self, n_neighbor):
        super().__init__()
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        """Return neighbour indices for each centroid.

        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

        Args:
            pos: point coordinates, shape (B, N, C).
            centroids: indices of centre points, shape (B, S).

        Returns:
            Long tensor of shape (B, S, min(n_neighbor, N)), nearest first.
        """
        center_pos = index_points(pos, centroids)
        sqrdists = square_distance(center_pos, pos)
        # Slicing (rather than topk) silently clamps k to N.
        group_idx = sqrdists.argsort(dim=-1)[:, :, : self.n_neighbor]
        return group_idx


class KNNGraphBuilder(nn.Module):
    """Build one k-NN graph per sample and batch them together."""

    def __init__(self, n_neighbor):
        super().__init__()
        self.n_neighbor = n_neighbor
        self.knn = KNearNeighbors(n_neighbor)

    def forward(self, pos, centroids, feat=None):
        """Connect each centroid to its k nearest points.

        Args:
            pos: point coordinates, shape (B, N, C).
            centroids: indices of centre points, shape (B, S).
            feat: optional per-point features, shape (B, N, F).

        Returns:
            A batched DGL graph. Per sample, it contains only the points
            touched by some k-NN edge, relabelled in ascending original-index
            order. Edges go neighbour -> centroid. Node data:
            ``"pos"``, ``"center"`` (1.0 for centroids, else 0.0), and
            ``"feat"`` when ``feat`` is given.
        """
        dev = pos.device
        group_idx = self.knn(pos, centroids)  # (B, S, k)
        B, N, _ = pos.shape
        # Effective neighbour count; equals min(n_neighbor, N).
        k = group_idx.shape[-1]
        glist = []
        for i in range(B):
            center = torch.zeros(N, device=dev)
            center[centroids[i]] = 1
            src = group_idx[i].reshape(-1)
            # Each centroid is the destination of its k neighbour edges.
            dst = centroids[i].repeat_interleave(k)
            # Compact the global point indices used by this sample into
            # 0..U-1 so the per-sample graph has no isolated padding nodes.
            unified = torch.cat([src, dst])
            uniq, inv_idx = torch.unique(unified, return_inverse=True)
            src_idx = inv_idx[: src.shape[0]]
            dst_idx = inv_idx[src.shape[0] :]
            g = dgl.graph((src_idx, dst_idx), num_nodes=uniq.shape[0])
            g.ndata["pos"] = pos[i][uniq]
            g.ndata["center"] = center[uniq]
            if feat is not None:
                g.ndata["feat"] = feat[i][uniq]
            glist.append(g)
        bg = dgl.batch(glist)
        return bg


class KNNMessage(nn.Module):
    """Message function: build each edge's input feature from its endpoints."""

    def __init__(self, n_neighbor):
        super().__init__()
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        """Return the neighbour's offset from the centre, plus its features.

        The relative position ``pos_src - pos_dst`` is always available.
        The neighbour's ``"feat"`` is concatenated only when the graph
        carries features.
        """
        # Must use "pos": "feat" is optional (see the branch below), and
        # reading it unconditionally raised KeyError for feature-less graphs.
        norm = edges.src["pos"] - edges.dst["pos"]
        if "feat" in edges.src:
            res = torch.cat([norm, edges.src["feat"]], 1)
        else:
            res = norm
        return {"agg_feat": res}


class KNNConv(nn.Module):
    """Reduce function: shared point-wise MLP followed by max-pooling."""

    def __init__(self, sizes):
        """
        Args:
            sizes: channel sizes ``[in, hidden..., out]``. Each consecutive
                pair becomes a 1x1 Conv2d + BatchNorm2d + ReLU layer.
        """
        super().__init__()
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for i in range(1, len(sizes)):
            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))

    def forward(self, nodes):
        """Aggregate each node's mailbox into a single feature vector.

        DGL delivers ``nodes.mailbox["agg_feat"]`` as (n, K, D). Here n is
        the number of nodes in this degree bucket and K is their in-degree.

        Returns:
            ``{"new_feat": tensor of shape (n, sizes[-1])}``.
        """
        msgs = nodes.mailbox["agg_feat"]  # (n, K, D)
        # Lay the K messages out as a K x 1 "image" with D channels. The 1x1
        # convolutions then act as an MLP shared across messages.
        h = msgs.permute(0, 2, 1).unsqueeze(-1)  # (n, D, K, 1)
        for conv, bn in zip(self.conv, self.bn):
            h = F.relu(bn(conv(h)))
        # Order-invariant max-pool over the neighbour axis.
        h = torch.max(h, 2)[0]  # (n, C, 1)
        return {"new_feat": h.squeeze(-1)}


class TransitionDown(nn.Module):
    """The Transition Down module: FPS downsampling plus k-NN aggregation."""

    def __init__(self, in_channels, out_channels, n_neighbor=64):
        """
        Args:
            in_channels: width of each edge message. This is the position
                dimension plus the feature dimension when features are used.
            out_channels: width of the aggregated output features.
            n_neighbor: number of neighbours gathered per sampled point.
        """
        super().__init__()
        self.frnn_graph = KNNGraphBuilder(n_neighbor)
        self.message = KNNMessage(n_neighbor)
        self.conv = KNNConv([in_channels, out_channels, out_channels])

    def forward(self, pos, feat, n_point):
        """Downsample to ``n_point`` points and aggregate their neighbourhoods.

        Args:
            pos: point coordinates, shape (B, N, C).
            feat: per-point features, shape (B, N, F), or None.
            n_point: number of points kept per sample by farthest point
                sampling.

        Returns:
            ``(pos_res, feat_res)`` with shapes (B, n_point, C) and
            (B, n_point, out_channels). Within each sample, points are ordered
            by ascending original index, not by FPS selection order.
        """
        batch_size = pos.shape[0]
        centroids = farthest_point_sampler(pos, n_point)
        g = self.frnn_graph(pos, centroids, feat)
        g.update_all(self.message, self.conv)

        mask = g.ndata["center"] == 1
        pos_dim = g.ndata["pos"].shape[-1]
        feat_dim = g.ndata["new_feat"].shape[-1]
        pos_res = g.ndata["pos"][mask].view(batch_size, -1, pos_dim)
        feat_res = g.ndata["new_feat"][mask].view(batch_size, -1, feat_dim)
        return pos_res, feat_res