Unverified Commit be8763fa authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Black auto fix. (#4679)


Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl import dgl
from dgl.geometry import farthest_point_sampler from dgl.geometry import farthest_point_sampler
''' """
Part of the code are adapted from Part of the code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch https://github.com/yanx27/Pointnet_Pointnet2_pytorch
''' """
def square_distance(src, dst): def square_distance(src, dst):
''' """
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
''' """
B, N, _ = src.shape B, N, _ = src.shape
_, M, _ = dst.shape _, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1) dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M) dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist return dist
def index_points(points, idx): def index_points(points, idx):
''' """
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
''' """
device = points.device device = points.device
B = points.shape[0] B = points.shape[0]
view_shape = list(idx.shape) view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1) view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape) repeat_shape = list(idx.shape)
repeat_shape[0] = 1 repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to( batch_indices = (
device).view(view_shape).repeat(repeat_shape) torch.arange(B, dtype=torch.long)
.to(device)
.view(view_shape)
.repeat(repeat_shape)
)
new_points = points[batch_indices, idx, :] new_points = points[batch_indices, idx, :]
return new_points return new_points
class KNearNeighbors(nn.Module): class KNearNeighbors(nn.Module):
''' """
Find the k nearest neighbors Find the k nearest neighbors
''' """
def __init__(self, n_neighbor): def __init__(self, n_neighbor):
super(KNearNeighbors, self).__init__() super(KNearNeighbors, self).__init__()
self.n_neighbor = n_neighbor self.n_neighbor = n_neighbor
def forward(self, pos, centroids): def forward(self, pos, centroids):
''' """
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
''' """
center_pos = index_points(pos, centroids) center_pos = index_points(pos, centroids)
sqrdists = square_distance(center_pos, pos) sqrdists = square_distance(center_pos, pos)
group_idx = sqrdists.argsort(dim=-1)[:, :, :self.n_neighbor] group_idx = sqrdists.argsort(dim=-1)[:, :, : self.n_neighbor]
return group_idx return group_idx
class KNNGraphBuilder(nn.Module): class KNNGraphBuilder(nn.Module):
''' """
Build NN graph Build NN graph
''' """
def __init__(self, n_neighbor): def __init__(self, n_neighbor):
super(KNNGraphBuilder, self).__init__() super(KNNGraphBuilder, self).__init__()
...@@ -76,46 +81,52 @@ class KNNGraphBuilder(nn.Module): ...@@ -76,46 +81,52 @@ class KNNGraphBuilder(nn.Module):
center = torch.zeros((N)).to(dev) center = torch.zeros((N)).to(dev)
center[centroids[i]] = 1 center[centroids[i]] = 1
src = group_idx[i].contiguous().view(-1) src = group_idx[i].contiguous().view(-1)
dst = centroids[i].view(-1, 1).repeat(1, min(self.n_neighbor, dst = (
src.shape[0] // centroids.shape[1])).view(-1) centroids[i]
.view(-1, 1)
.repeat(
1, min(self.n_neighbor, src.shape[0] // centroids.shape[1])
)
.view(-1)
)
unified = torch.cat([src, dst]) unified = torch.cat([src, dst])
uniq, inv_idx = torch.unique(unified, return_inverse=True) uniq, inv_idx = torch.unique(unified, return_inverse=True)
src_idx = inv_idx[:src.shape[0]] src_idx = inv_idx[: src.shape[0]]
dst_idx = inv_idx[src.shape[0]:] dst_idx = inv_idx[src.shape[0] :]
g = dgl.graph((src_idx, dst_idx)) g = dgl.graph((src_idx, dst_idx))
g.ndata['pos'] = pos[i][uniq] g.ndata["pos"] = pos[i][uniq]
g.ndata['center'] = center[uniq] g.ndata["center"] = center[uniq]
if feat is not None: if feat is not None:
g.ndata['feat'] = feat[i][uniq] g.ndata["feat"] = feat[i][uniq]
glist.append(g) glist.append(g)
bg = dgl.batch(glist) bg = dgl.batch(glist)
return bg return bg
class RelativePositionMessage(nn.Module): class RelativePositionMessage(nn.Module):
''' """
Compute the input feature from neighbors Compute the input feature from neighbors
''' """
def __init__(self, n_neighbor): def __init__(self, n_neighbor):
super(RelativePositionMessage, self).__init__() super(RelativePositionMessage, self).__init__()
self.n_neighbor = n_neighbor self.n_neighbor = n_neighbor
def forward(self, edges): def forward(self, edges):
pos = edges.src['pos'] - edges.dst['pos'] pos = edges.src["pos"] - edges.dst["pos"]
if 'feat' in edges.src: if "feat" in edges.src:
res = torch.cat([pos, edges.src['feat']], 1) res = torch.cat([pos, edges.src["feat"]], 1)
else: else:
res = pos res = pos
return {'agg_feat': res} return {"agg_feat": res}
class KNNConv(nn.Module): class KNNConv(nn.Module):
''' """
Feature aggregation Feature aggregation
''' """
def __init__(self, sizes, batch_size): def __init__(self, sizes, batch_size):
super(KNNConv, self).__init__() super(KNNConv, self).__init__()
...@@ -123,13 +134,16 @@ class KNNConv(nn.Module): ...@@ -123,13 +134,16 @@ class KNNConv(nn.Module):
self.conv = nn.ModuleList() self.conv = nn.ModuleList()
self.bn = nn.ModuleList() self.bn = nn.ModuleList()
for i in range(1, len(sizes)): for i in range(1, len(sizes)):
self.conv.append(nn.Conv2d(sizes[i-1], sizes[i], 1)) self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
self.bn.append(nn.BatchNorm2d(sizes[i])) self.bn.append(nn.BatchNorm2d(sizes[i]))
def forward(self, nodes): def forward(self, nodes):
shape = nodes.mailbox['agg_feat'].shape shape = nodes.mailbox["agg_feat"].shape
h = nodes.mailbox['agg_feat'].view( h = (
self.batch_size, -1, shape[1], shape[2]).permute(0, 3, 2, 1) nodes.mailbox["agg_feat"]
.view(self.batch_size, -1, shape[1], shape[2])
.permute(0, 3, 2, 1)
)
for conv, bn in zip(self.conv, self.bn): for conv, bn in zip(self.conv, self.bn):
h = conv(h) h = conv(h)
h = bn(h) h = bn(h)
...@@ -137,12 +151,12 @@ class KNNConv(nn.Module): ...@@ -137,12 +151,12 @@ class KNNConv(nn.Module):
h = torch.max(h, 2)[0] h = torch.max(h, 2)[0]
feat_dim = h.shape[1] feat_dim = h.shape[1]
h = h.permute(0, 2, 1).reshape(-1, feat_dim) h = h.permute(0, 2, 1).reshape(-1, feat_dim)
return {'new_feat': h} return {"new_feat": h}
def group_all(self, pos, feat): def group_all(self, pos, feat):
''' """
Feature aggregation and pooling for the non-sampling layer Feature aggregation and pooling for the non-sampling layer
''' """
if feat is not None: if feat is not None:
h = torch.cat([pos, feat], 2) h = torch.cat([pos, feat], 2)
else: else:
...@@ -177,12 +191,11 @@ class TransitionDown(nn.Module): ...@@ -177,12 +191,11 @@ class TransitionDown(nn.Module):
g = self.frnn_graph(pos, centroids, feat) g = self.frnn_graph(pos, centroids, feat)
g.update_all(self.message, self.conv) g.update_all(self.message, self.conv)
mask = g.ndata['center'] == 1 mask = g.ndata["center"] == 1
pos_dim = g.ndata['pos'].shape[-1] pos_dim = g.ndata["pos"].shape[-1]
feat_dim = g.ndata['new_feat'].shape[-1] feat_dim = g.ndata["new_feat"].shape[-1]
pos_res = g.ndata['pos'][mask].view(self.batch_size, -1, pos_dim) pos_res = g.ndata["pos"][mask].view(self.batch_size, -1, pos_dim)
feat_res = g.ndata['new_feat'][mask].view( feat_res = g.ndata["new_feat"][mask].view(self.batch_size, -1, feat_dim)
self.batch_size, -1, feat_dim)
return pos_res, feat_res return pos_res, feat_res
...@@ -198,7 +211,7 @@ class FeaturePropagation(nn.Module): ...@@ -198,7 +211,7 @@ class FeaturePropagation(nn.Module):
sizes = [input_dims] + sizes sizes = [input_dims] + sizes
for i in range(1, len(sizes)): for i in range(1, len(sizes)):
self.convs.append(nn.Conv1d(sizes[i-1], sizes[i], 1)) self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))
self.bns.append(nn.BatchNorm1d(sizes[i])) self.bns.append(nn.BatchNorm1d(sizes[i]))
def forward(self, x1, x2, feat1, feat2): def forward(self, x1, x2, feat1, feat2):
...@@ -225,8 +238,9 @@ class FeaturePropagation(nn.Module): ...@@ -225,8 +238,9 @@ class FeaturePropagation(nn.Module):
dist_recip = 1.0 / (dists + 1e-8) dist_recip = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True) norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm weight = dist_recip / norm
interpolated_feat = torch.sum(index_points( interpolated_feat = torch.sum(
feat2, idx) * weight.view(B, N, 3, 1), dim=2) index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2
)
if feat1 is not None: if feat1 is not None:
new_feat = torch.cat([feat1, interpolated_feat], dim=-1) new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
......
import numpy as np
import torch import torch
from helper import TransitionDown, TransitionUp, index_points, square_distance
from torch import nn from torch import nn
import numpy as np """
from helper import square_distance, index_points, TransitionDown, TransitionUp
'''
Part of the code are adapted from Part of the code are adapted from
https://github.com/qq456cvb/Point-Transformers https://github.com/qq456cvb/Point-Transformers
''' """
class PointTransformerBlock(nn.Module): class PointTransformerBlock(nn.Module):
...@@ -21,12 +19,12 @@ class PointTransformerBlock(nn.Module): ...@@ -21,12 +19,12 @@ class PointTransformerBlock(nn.Module):
self.fc_delta = nn.Sequential( self.fc_delta = nn.Sequential(
nn.Linear(3, transformer_dim), nn.Linear(3, transformer_dim),
nn.ReLU(), nn.ReLU(),
nn.Linear(transformer_dim, transformer_dim) nn.Linear(transformer_dim, transformer_dim),
) )
self.fc_gamma = nn.Sequential( self.fc_gamma = nn.Sequential(
nn.Linear(transformer_dim, transformer_dim), nn.Linear(transformer_dim, transformer_dim),
nn.ReLU(), nn.ReLU(),
nn.Linear(transformer_dim, transformer_dim) nn.Linear(transformer_dim, transformer_dim),
) )
self.w_qs = nn.Linear(transformer_dim, transformer_dim, bias=False) self.w_qs = nn.Linear(transformer_dim, transformer_dim, bias=False)
self.w_ks = nn.Linear(transformer_dim, transformer_dim, bias=False) self.w_ks = nn.Linear(transformer_dim, transformer_dim, bias=False)
...@@ -35,43 +33,71 @@ class PointTransformerBlock(nn.Module): ...@@ -35,43 +33,71 @@ class PointTransformerBlock(nn.Module):
def forward(self, x, pos): def forward(self, x, pos):
dists = square_distance(pos, pos) dists = square_distance(pos, pos)
knn_idx = dists.argsort()[:, :, :self.n_neighbors] # b x n x k knn_idx = dists.argsort()[:, :, : self.n_neighbors] # b x n x k
knn_pos = index_points(pos, knn_idx) knn_pos = index_points(pos, knn_idx)
h = self.fc1(x) h = self.fc1(x)
q, k, v = self.w_qs(h), index_points( q, k, v = (
self.w_ks(h), knn_idx), index_points(self.w_vs(h), knn_idx) self.w_qs(h),
index_points(self.w_ks(h), knn_idx),
index_points(self.w_vs(h), knn_idx),
)
pos_enc = self.fc_delta(pos[:, :, None] - knn_pos) # b x n x k x f pos_enc = self.fc_delta(pos[:, :, None] - knn_pos) # b x n x k x f
attn = self.fc_gamma(q[:, :, None] - k + pos_enc) attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
attn = torch.softmax(attn / np.sqrt(k.size(-1)), attn = torch.softmax(
dim=-2) # b x n x k x f attn / np.sqrt(k.size(-1)), dim=-2
) # b x n x k x f
res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc) res = torch.einsum("bmnf,bmnf->bmf", attn, v + pos_enc)
res = self.fc2(res) + x res = self.fc2(res) + x
return res, attn return res, attn
class PointTransformer(nn.Module): class PointTransformer(nn.Module):
def __init__(self, n_points, batch_size, feature_dim=3, n_blocks=4, downsampling_rate=4, hidden_dim=32, transformer_dim=None, n_neighbors=16): def __init__(
self,
n_points,
batch_size,
feature_dim=3,
n_blocks=4,
downsampling_rate=4,
hidden_dim=32,
transformer_dim=None,
n_neighbors=16,
):
super(PointTransformer, self).__init__() super(PointTransformer, self).__init__()
self.fc = nn.Sequential( self.fc = nn.Sequential(
nn.Linear(feature_dim, hidden_dim), nn.Linear(feature_dim, hidden_dim),
nn.ReLU(), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim) nn.Linear(hidden_dim, hidden_dim),
) )
self.ptb = PointTransformerBlock( self.ptb = PointTransformerBlock(
hidden_dim, n_neighbors, transformer_dim) hidden_dim, n_neighbors, transformer_dim
)
self.transition_downs = nn.ModuleList() self.transition_downs = nn.ModuleList()
self.transformers = nn.ModuleList() self.transformers = nn.ModuleList()
for i in range(n_blocks): for i in range(n_blocks):
block_hidden_dim = hidden_dim * 2 ** (i + 1) block_hidden_dim = hidden_dim * 2 ** (i + 1)
block_n_points = n_points // (downsampling_rate ** (i + 1)) block_n_points = n_points // (downsampling_rate ** (i + 1))
self.transition_downs.append(TransitionDown(block_n_points, batch_size, [ self.transition_downs.append(
block_hidden_dim // 2 + 3, block_hidden_dim, block_hidden_dim], n_neighbors=n_neighbors)) TransitionDown(
block_n_points,
batch_size,
[
block_hidden_dim // 2 + 3,
block_hidden_dim,
block_hidden_dim,
],
n_neighbors=n_neighbors,
)
)
self.transformers.append( self.transformers.append(
PointTransformerBlock(block_hidden_dim, n_neighbors, transformer_dim)) PointTransformerBlock(
block_hidden_dim, n_neighbors, transformer_dim
)
)
def forward(self, x): def forward(self, x):
if x.shape[-1] > 3: if x.shape[-1] > 3:
...@@ -93,16 +119,35 @@ class PointTransformer(nn.Module): ...@@ -93,16 +119,35 @@ class PointTransformer(nn.Module):
class PointTransformerCLS(nn.Module): class PointTransformerCLS(nn.Module):
def __init__(self, out_classes, batch_size, n_points=1024, feature_dim=3, n_blocks=4, downsampling_rate=4, hidden_dim=32, transformer_dim=None, n_neighbors=16): def __init__(
self,
out_classes,
batch_size,
n_points=1024,
feature_dim=3,
n_blocks=4,
downsampling_rate=4,
hidden_dim=32,
transformer_dim=None,
n_neighbors=16,
):
super(PointTransformerCLS, self).__init__() super(PointTransformerCLS, self).__init__()
self.backbone = PointTransformer( self.backbone = PointTransformer(
n_points, batch_size, feature_dim, n_blocks, downsampling_rate, hidden_dim, transformer_dim, n_neighbors) n_points,
batch_size,
feature_dim,
n_blocks,
downsampling_rate,
hidden_dim,
transformer_dim,
n_neighbors,
)
self.out = self.fc2 = nn.Sequential( self.out = self.fc2 = nn.Sequential(
nn.Linear(hidden_dim * 2 ** (n_blocks), 256), nn.Linear(hidden_dim * 2 ** (n_blocks), 256),
nn.ReLU(), nn.ReLU(),
nn.Linear(256, 64), nn.Linear(256, 64),
nn.ReLU(), nn.ReLU(),
nn.Linear(64, out_classes) nn.Linear(64, out_classes),
) )
def forward(self, x): def forward(self, x):
...@@ -112,37 +157,63 @@ class PointTransformerCLS(nn.Module): ...@@ -112,37 +157,63 @@ class PointTransformerCLS(nn.Module):
class PointTransformerSeg(nn.Module): class PointTransformerSeg(nn.Module):
def __init__(self, out_classes, batch_size, n_points=2048, feature_dim=3, n_blocks=4, downsampling_rate=4, hidden_dim=32, transformer_dim=None, n_neighbors=16): def __init__(
self,
out_classes,
batch_size,
n_points=2048,
feature_dim=3,
n_blocks=4,
downsampling_rate=4,
hidden_dim=32,
transformer_dim=None,
n_neighbors=16,
):
super().__init__() super().__init__()
self.backbone = PointTransformer( self.backbone = PointTransformer(
n_points, batch_size, feature_dim, n_blocks, downsampling_rate, hidden_dim, transformer_dim, n_neighbors) n_points,
batch_size,
feature_dim,
n_blocks,
downsampling_rate,
hidden_dim,
transformer_dim,
n_neighbors,
)
self.fc = nn.Sequential( self.fc = nn.Sequential(
nn.Linear(32 * 2 ** n_blocks, 512), nn.Linear(32 * 2**n_blocks, 512),
nn.ReLU(), nn.ReLU(),
nn.Linear(512, 512), nn.Linear(512, 512),
nn.ReLU(), nn.ReLU(),
nn.Linear(512, 32 * 2 ** n_blocks) nn.Linear(512, 32 * 2**n_blocks),
) )
self.ptb = PointTransformerBlock( self.ptb = PointTransformerBlock(
32 * 2 ** n_blocks, n_neighbors, transformer_dim) 32 * 2**n_blocks, n_neighbors, transformer_dim
)
self.n_blocks = n_blocks self.n_blocks = n_blocks
self.transition_ups = nn.ModuleList() self.transition_ups = nn.ModuleList()
self.transformers = nn.ModuleList() self.transformers = nn.ModuleList()
for i in reversed(range(n_blocks)): for i in reversed(range(n_blocks)):
block_hidden_dim = 32 * 2 ** i block_hidden_dim = 32 * 2**i
self.transition_ups.append( self.transition_ups.append(
TransitionUp(block_hidden_dim * 2, block_hidden_dim, block_hidden_dim)) TransitionUp(
self.transformers.append(PointTransformerBlock( block_hidden_dim * 2, block_hidden_dim, block_hidden_dim
block_hidden_dim, n_neighbors, transformer_dim)) )
)
self.transformers.append(
PointTransformerBlock(
block_hidden_dim, n_neighbors, transformer_dim
)
)
self.out = nn.Sequential( self.out = nn.Sequential(
nn.Linear(32+16, 64), nn.Linear(32 + 16, 64),
nn.ReLU(), nn.ReLU(),
nn.Linear(64, 64), nn.Linear(64, 64),
nn.ReLU(), nn.ReLU(),
nn.Linear(64, out_classes) nn.Linear(64, out_classes),
) )
def forward(self, x, cat_vec=None): def forward(self, x, cat_vec=None):
...@@ -152,8 +223,9 @@ class PointTransformerSeg(nn.Module): ...@@ -152,8 +223,9 @@ class PointTransformerSeg(nn.Module):
for i in range(self.n_blocks): for i in range(self.n_blocks):
h = self.transition_ups[i]( h = self.transition_ups[i](
pos, h, hidden_state[- i - 2][0], hidden_state[- i - 2][1]) pos, h, hidden_state[-i - 2][0], hidden_state[-i - 2][1]
pos = hidden_state[- i - 2][0] )
pos = hidden_state[-i - 2][0]
h, _ = self.transformers[i](h, pos) h, _ = self.transformers[i](h, pos)
return self.out(torch.cat([h, cat_vec], dim=-1)) return self.out(torch.cat([h, cat_vec], dim=-1))
......
''' """
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
''' """
import numpy as np import numpy as np
def normalize_data(batch_data): def normalize_data(batch_data):
""" Normalize the batch data, use coordinates of the block centered at origin, """Normalize the batch data, use coordinates of the block centered at origin,
Input: Input:
BxNxC array BxNxC array
Output: Output:
BxNxC array BxNxC array
""" """
B, N, C = batch_data.shape B, N, C = batch_data.shape
normal_data = np.zeros((B, N, C)) normal_data = np.zeros((B, N, C))
...@@ -16,237 +17,296 @@ def normalize_data(batch_data): ...@@ -16,237 +17,296 @@ def normalize_data(batch_data):
pc = batch_data[b] pc = batch_data[b]
centroid = np.mean(pc, axis=0) centroid = np.mean(pc, axis=0)
pc = pc - centroid pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m pc = pc / m
normal_data[b] = pc normal_data[b] = pc
return normal_data return normal_data
def shuffle_data(data, labels): def shuffle_data(data, labels):
""" Shuffle data and labels. """Shuffle data and labels.
Input: Input:
data: B,N,... numpy array data: B,N,... numpy array
label: B,... numpy array label: B,... numpy array
Return: Return:
shuffled data, label and shuffle indices shuffled data, label and shuffle indices
""" """
idx = np.arange(len(labels)) idx = np.arange(len(labels))
np.random.shuffle(idx) np.random.shuffle(idx)
return data[idx, ...], labels[idx], idx return data[idx, ...], labels[idx], idx
def shuffle_points(batch_data): def shuffle_points(batch_data):
""" Shuffle orders of points in each point cloud -- changes FPS behavior. """Shuffle orders of points in each point cloud -- changes FPS behavior.
Use the same shuffling idx for the entire batch. Use the same shuffling idx for the entire batch.
Input: Input:
BxNxC array BxNxC array
Output: Output:
BxNxC array BxNxC array
""" """
idx = np.arange(batch_data.shape[1]) idx = np.arange(batch_data.shape[1])
np.random.shuffle(idx) np.random.shuffle(idx)
return batch_data[:,idx,:] return batch_data[:, idx, :]
def rotate_point_cloud(batch_data): def rotate_point_cloud(batch_data):
""" Randomly rotate the point clouds to augument the dataset """Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction rotation is per shape based along up direction
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, rotated batch of point clouds BxNx3 array, rotated batch of point clouds
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle) cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle) sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval], rotation_matrix = np.array(
[0, 1, 0], [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
[-sinval, 0, cosval]]) )
shape_pc = batch_data[k, ...] shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) rotated_data[k, ...] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data return rotated_data
def rotate_point_cloud_z(batch_data): def rotate_point_cloud_z(batch_data):
""" Randomly rotate the point clouds to augument the dataset """Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction rotation is per shape based along up direction
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, rotated batch of point clouds BxNx3 array, rotated batch of point clouds
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle) cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle) sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, sinval, 0], rotation_matrix = np.array(
[-sinval, cosval, 0], [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]
[0, 0, 1]]) )
shape_pc = batch_data[k, ...] shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) rotated_data[k, ...] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data return rotated_data
def rotate_point_cloud_with_normal(batch_xyz_normal): def rotate_point_cloud_with_normal(batch_xyz_normal):
''' Randomly rotate XYZ, normal point cloud. """Randomly rotate XYZ, normal point cloud.
Input: Input:
batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
Output: Output:
B,N,6, rotated XYZ, normal point cloud B,N,6, rotated XYZ, normal point cloud
''' """
for k in range(batch_xyz_normal.shape[0]): for k in range(batch_xyz_normal.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle) cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle) sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval], rotation_matrix = np.array(
[0, 1, 0], [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
[-sinval, 0, cosval]]) )
shape_pc = batch_xyz_normal[k,:,0:3] shape_pc = batch_xyz_normal[k, :, 0:3]
shape_normal = batch_xyz_normal[k,:,3:6] shape_normal = batch_xyz_normal[k, :, 3:6]
batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) batch_xyz_normal[k, :, 0:3] = np.dot(
batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) shape_pc.reshape((-1, 3)), rotation_matrix
)
batch_xyz_normal[k, :, 3:6] = np.dot(
shape_normal.reshape((-1, 3)), rotation_matrix
)
return batch_xyz_normal return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations def rotate_perturbation_point_cloud_with_normal(
Input: batch_data, angle_sigma=0.06, angle_clip=0.18
BxNx6 array, original batch of point clouds and point normals ):
Return: """Randomly perturb the point clouds by small rotations
BxNx3 array, rotated batch of point clouds Input:
BxNx6 array, original batch of point clouds and point normals
Return:
BxNx3 array, rotated batch of point clouds
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) angles = np.clip(
Rx = np.array([[1,0,0], angle_sigma * np.random.randn(3), -angle_clip, angle_clip
[0,np.cos(angles[0]),-np.sin(angles[0])], )
[0,np.sin(angles[0]),np.cos(angles[0])]]) Rx = np.array(
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], [
[0,1,0], [1, 0, 0],
[-np.sin(angles[1]),0,np.cos(angles[1])]]) [0, np.cos(angles[0]), -np.sin(angles[0])],
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], [0, np.sin(angles[0]), np.cos(angles[0])],
[np.sin(angles[2]),np.cos(angles[2]),0], ]
[0,0,1]]) )
R = np.dot(Rz, np.dot(Ry,Rx)) Ry = np.array(
shape_pc = batch_data[k,:,0:3] [
shape_normal = batch_data[k,:,3:6] [np.cos(angles[1]), 0, np.sin(angles[1])],
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R) [0, 1, 0],
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R) [-np.sin(angles[1]), 0, np.cos(angles[1])],
]
)
Rz = np.array(
[
[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1],
]
)
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k, :, 3:6]
rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
return rotated_data return rotated_data
def rotate_point_cloud_by_angle(batch_data, rotation_angle): def rotate_point_cloud_by_angle(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle. """Rotate the point cloud along up direction with certain angle.
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, rotated batch of point clouds BxNx3 array, rotated batch of point clouds
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi # rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle) cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle) sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval], rotation_matrix = np.array(
[0, 1, 0], [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
[-sinval, 0, cosval]]) )
shape_pc = batch_data[k,:,0:3] shape_pc = batch_data[k, :, 0:3]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) rotated_data[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data return rotated_data
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle. """Rotate the point cloud along up direction with certain angle.
Input: Input:
BxNx6 array, original batch of point clouds with normal BxNx6 array, original batch of point clouds with normal
scalar, angle of rotation scalar, angle of rotation
Return: Return:
BxNx6 array, rotated batch of point clouds iwth normal BxNx6 array, rotated batch of point clouds iwth normal
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi # rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle) cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle) sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval], rotation_matrix = np.array(
[0, 1, 0], [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
[-sinval, 0, cosval]]) )
shape_pc = batch_data[k,:,0:3] shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k,:,3:6] shape_normal = batch_data[k, :, 3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) rotated_data[k, :, 0:3] = np.dot(
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix) shape_pc.reshape((-1, 3)), rotation_matrix
)
rotated_data[k, :, 3:6] = np.dot(
shape_normal.reshape((-1, 3)), rotation_matrix
)
return rotated_data return rotated_data
def rotate_perturbation_point_cloud(
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): batch_data, angle_sigma=0.06, angle_clip=0.18
""" Randomly perturb the point clouds by small rotations ):
Input: """Randomly perturb the point clouds by small rotations
BxNx3 array, original batch of point clouds Input:
Return: BxNx3 array, original batch of point clouds
BxNx3 array, rotated batch of point clouds Return:
BxNx3 array, rotated batch of point clouds
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) angles = np.clip(
Rx = np.array([[1,0,0], angle_sigma * np.random.randn(3), -angle_clip, angle_clip
[0,np.cos(angles[0]),-np.sin(angles[0])], )
[0,np.sin(angles[0]),np.cos(angles[0])]]) Rx = np.array(
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], [
[0,1,0], [1, 0, 0],
[-np.sin(angles[1]),0,np.cos(angles[1])]]) [0, np.cos(angles[0]), -np.sin(angles[0])],
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], [0, np.sin(angles[0]), np.cos(angles[0])],
[np.sin(angles[2]),np.cos(angles[2]),0], ]
[0,0,1]]) )
R = np.dot(Rz, np.dot(Ry,Rx)) Ry = np.array(
[
[np.cos(angles[1]), 0, np.sin(angles[1])],
[0, 1, 0],
[-np.sin(angles[1]), 0, np.cos(angles[1])],
]
)
Rz = np.array(
[
[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1],
]
)
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, ...] shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
return rotated_data return rotated_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
""" Randomly jitter points. jittering is per point. """Randomly jitter points. jittering is per point.
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, jittered batch of point clouds BxNx3 array, jittered batch of point clouds
""" """
B, N, C = batch_data.shape B, N, C = batch_data.shape
assert(clip > 0) assert clip > 0
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
jittered_data += batch_data jittered_data += batch_data
return jittered_data return jittered_data
def shift_point_cloud(batch_data, shift_range=0.1): def shift_point_cloud(batch_data, shift_range=0.1):
""" Randomly shift point cloud. Shift is per point cloud. """Randomly shift point cloud. Shift is per point cloud.
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, shifted batch of point clouds BxNx3 array, shifted batch of point clouds
""" """
B, N, C = batch_data.shape B, N, C = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B,3)) shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
for batch_index in range(B): for batch_index in range(B):
batch_data[batch_index,:,:] += shifts[batch_index,:] batch_data[batch_index, :, :] += shifts[batch_index, :]
return batch_data return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
""" Randomly scale the point cloud. Scale is per point cloud. """Randomly scale the point cloud. Scale is per point cloud.
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, scaled batch of point clouds BxNx3 array, scaled batch of point clouds
""" """
B, N, C = batch_data.shape B, N, C = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B) scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B): for batch_index in range(B):
batch_data[batch_index,:,:] *= scales[batch_index] batch_data[batch_index, :, :] *= scales[batch_index]
return batch_data return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875): def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
''' batch_pc: BxNx3 ''' """batch_pc: BxNx3"""
for b in range(batch_pc.shape[0]): for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] drop_idx = np.where(
if len(drop_idx)>0: np.random.random((batch_pc.shape[1])) <= dropout_ratio
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 # not need )[0]
batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point if len(drop_idx) > 0:
dropout_ratio = (
np.random.random() * max_dropout_ratio
) # 0~0.875 # not need
batch_pc[b, drop_idx, :] = batch_pc[
b, 0, :
] # set to the first point
return batch_pc return batch_pc
import numpy as np
import warnings
import os import os
import warnings
import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore")
def pc_normalize(pc): def pc_normalize(pc):
centroid = np.mean(pc, axis=0) centroid = np.mean(pc, axis=0)
...@@ -11,6 +14,7 @@ def pc_normalize(pc): ...@@ -11,6 +14,7 @@ def pc_normalize(pc):
pc = pc / m pc = pc / m
return pc return pc
def farthest_point_sample(point, npoint): def farthest_point_sample(point, npoint):
""" """
Farthest point sampler works as follows: Farthest point sampler works as follows:
...@@ -25,7 +29,7 @@ def farthest_point_sample(point, npoint): ...@@ -25,7 +29,7 @@ def farthest_point_sample(point, npoint):
centroids: sampled pointcloud index, [npoint, D] centroids: sampled pointcloud index, [npoint, D]
""" """
N, D = point.shape N, D = point.shape
xyz = point[:,:3] xyz = point[:, :3]
centroids = np.zeros((npoint,)) centroids = np.zeros((npoint,))
distance = np.ones((N,)) * 1e10 distance = np.ones((N,)) * 1e10
farthest = np.random.randint(0, N) farthest = np.random.randint(0, N)
...@@ -39,12 +43,20 @@ def farthest_point_sample(point, npoint): ...@@ -39,12 +43,20 @@ def farthest_point_sample(point, npoint):
point = point[centroids.astype(np.int32)] point = point[centroids.astype(np.int32)]
return point return point
class ModelNetDataLoader(Dataset): class ModelNetDataLoader(Dataset):
def __init__(self, root, npoint=1024, split='train', fps=False, def __init__(
normal_channel=True, cache_size=15000): self,
root,
npoint=1024,
split="train",
fps=False,
normal_channel=True,
cache_size=15000,
):
""" """
Input: Input:
root: the root path to the local data files root: the root path to the local data files
npoint: number of points from each cloud npoint: number of points from each cloud
split: which split of the data, 'train' or 'test' split: which split of the data, 'train' or 'test'
fps: whether to sample points with farthest point sampler fps: whether to sample points with farthest point sampler
...@@ -54,22 +66,34 @@ class ModelNetDataLoader(Dataset): ...@@ -54,22 +66,34 @@ class ModelNetDataLoader(Dataset):
self.root = root self.root = root
self.npoints = npoint self.npoints = npoint
self.fps = fps self.fps = fps
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")
self.cat = [line.rstrip() for line in open(self.catfile)] self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat)))) self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel self.normal_channel = normal_channel
shape_ids = {} shape_ids = {}
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))] shape_ids["train"] = [
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] line.rstrip()
for line in open(os.path.join(self.root, "modelnet40_train.txt"))
]
shape_ids["test"] = [
line.rstrip()
for line in open(os.path.join(self.root, "modelnet40_test.txt"))
]
assert (split == 'train' or split == 'test') assert split == "train" or split == "test"
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
# list of (shape_name, shape_txt_file_path) tuple # list of (shape_name, shape_txt_file_path) tuple
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i self.datapath = [
in range(len(shape_ids[split]))] (
print('The size of %s data is %d'%(split,len(self.datapath))) shape_names[i],
os.path.join(self.root, shape_names[i], shape_ids[split][i])
+ ".txt",
)
for i in range(len(shape_ids[split]))
]
print("The size of %s data is %d" % (split, len(self.datapath)))
self.cache_size = cache_size self.cache_size = cache_size
self.cache = {} self.cache = {}
...@@ -84,11 +108,11 @@ class ModelNetDataLoader(Dataset): ...@@ -84,11 +108,11 @@ class ModelNetDataLoader(Dataset):
fn = self.datapath[index] fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]] cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32) cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
if self.fps: if self.fps:
point_set = farthest_point_sample(point_set, self.npoints) point_set = farthest_point_sample(point_set, self.npoints)
else: else:
point_set = point_set[0:self.npoints,:] point_set = point_set[0 : self.npoints, :]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
...@@ -101,4 +125,4 @@ class ModelNetDataLoader(Dataset): ...@@ -101,4 +125,4 @@ class ModelNetDataLoader(Dataset):
return point_set, cls return point_set, cls
def __getitem__(self, index): def __getitem__(self, index):
return self._get_item(index) return self._get_item(index)
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment