Unverified commit be8763fa authored by Hongzhi (Steve), Chen; committed by GitHub

[Misc] Black auto fix. (#4679)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.geometry import farthest_point_sampler

"""
Part of the code is adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
def square_distance(src, dst):
    """
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src**2, -1).view(B, N, 1)
    dist += torch.sum(dst**2, -1).view(B, 1, M)
    return dist
def index_points(points, idx):
    """
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = (
        torch.arange(B, dtype=torch.long)
        .to(device)
        .view(view_shape)
        .repeat(repeat_shape)
    )
    new_points = points[batch_indices, idx, :]
    return new_points
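
For orientation (not part of the commit): square_distance relies on the identity ||s - d||^2 = ||s||^2 - 2 s.d + ||d||^2, and index_points gathers a different index set for every batch element. A minimal sanity check, with made-up shapes:

# Hypothetical sanity check for the two helpers above (shapes are
# illustrative only). square_distance should agree with an explicit
# pairwise computation via torch.cdist.
src = torch.randn(2, 5, 3)  # B x N x 3
dst = torch.randn(2, 7, 3)  # B x M x 3
assert torch.allclose(
    square_distance(src, dst), torch.cdist(src, dst) ** 2, atol=1e-4
)
idx = torch.randint(0, 7, (2, 4))  # B x S indices into dst
assert index_points(dst, idx).shape == (2, 4, 3)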
class KNearNeighbors(nn.Module):
    """
    Find the k nearest neighbors
    """

    def __init__(self, n_neighbor):
        super(KNearNeighbors, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
        """
        center_pos = index_points(pos, centroids)
        sqrdists = square_distance(center_pos, pos)
        group_idx = sqrdists.argsort(dim=-1)[:, :, : self.n_neighbor]
        return group_idx
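
A hedged usage sketch (tensor sizes are assumptions): sample centroids with DGL's farthest_point_sampler, then group each centroid's k nearest points with the module above.

# Sketch: FPS centroids + k-NN grouping (sizes are illustrative).
pos = torch.randn(2, 64, 3)  # B x N x 3 point cloud
centroids = farthest_point_sampler(pos, 16)  # B x 16 centroid indices
knn = KNearNeighbors(n_neighbor=8)
group_idx = knn(pos, centroids)  # B x 16 x 8 neighbor indices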
class KNNGraphBuilder(nn.Module):
    """
    Build NN graph
    """

    def __init__(self, n_neighbor):
        super(KNNGraphBuilder, self).__init__()
...@@ -76,46 +81,52 @@ class KNNGraphBuilder(nn.Module):
            center = torch.zeros((N)).to(dev)
            center[centroids[i]] = 1
            src = group_idx[i].contiguous().view(-1)
            dst = (
                centroids[i]
                .view(-1, 1)
                .repeat(
                    1, min(self.n_neighbor, src.shape[0] // centroids.shape[1])
                )
                .view(-1)
            )

            unified = torch.cat([src, dst])
            uniq, inv_idx = torch.unique(unified, return_inverse=True)
            src_idx = inv_idx[: src.shape[0]]
            dst_idx = inv_idx[src.shape[0] :]

            g = dgl.graph((src_idx, dst_idx))
            g.ndata["pos"] = pos[i][uniq]
            g.ndata["center"] = center[uniq]
            if feat is not None:
                g.ndata["feat"] = feat[i][uniq]
            glist.append(g)
        bg = dgl.batch(glist)
        return bg
class RelativePositionMessage(nn.Module):
    """
    Compute the input feature from neighbors
    """

    def __init__(self, n_neighbor):
        super(RelativePositionMessage, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        pos = edges.src["pos"] - edges.dst["pos"]
        if "feat" in edges.src:
            res = torch.cat([pos, edges.src["feat"]], 1)
        else:
            res = pos
        return {"agg_feat": res}
class KNNConv(nn.Module):
    """
    Feature aggregation
    """

    def __init__(self, sizes, batch_size):
        super(KNNConv, self).__init__()
...@@ -123,13 +134,16 @@ class KNNConv(nn.Module):
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for i in range(1, len(sizes)):
            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))

    def forward(self, nodes):
        shape = nodes.mailbox["agg_feat"].shape
        h = (
            nodes.mailbox["agg_feat"]
            .view(self.batch_size, -1, shape[1], shape[2])
            .permute(0, 3, 2, 1)
        )
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
...@@ -137,12 +151,12 @@ class KNNConv(nn.Module):
        h = torch.max(h, 2)[0]
        feat_dim = h.shape[1]
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {"new_feat": h}

    def group_all(self, pos, feat):
        """
        Feature aggregation and pooling for the non-sampling layer
        """
        if feat is not None:
            h = torch.cat([pos, feat], 2)
        else:
...@@ -177,12 +191,11 @@ class TransitionDown(nn.Module):
        g = self.frnn_graph(pos, centroids, feat)
        g.update_all(self.message, self.conv)

        mask = g.ndata["center"] == 1
        pos_dim = g.ndata["pos"].shape[-1]
        feat_dim = g.ndata["new_feat"].shape[-1]
        pos_res = g.ndata["pos"][mask].view(self.batch_size, -1, pos_dim)
        feat_res = g.ndata["new_feat"][mask].view(self.batch_size, -1, feat_dim)
        return pos_res, feat_res
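
For orientation (not part of the commit): TransitionDown wires the two modules above through DGL message passing, with RelativePositionMessage as the per-edge message function and KNNConv as the per-node reduce function reading nodes.mailbox["agg_feat"]. The same pattern in miniature, with toy user-defined functions:

# Toy update_all round trip (illustrative graph and functions).
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata["pos"] = torch.randn(3, 3)
g.update_all(
    lambda edges: {"agg_feat": edges.src["pos"] - edges.dst["pos"]},
    lambda nodes: {"new_feat": nodes.mailbox["agg_feat"].max(1)[0]},
)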
...@@ -198,7 +211,7 @@ class FeaturePropagation(nn.Module):
        sizes = [input_dims] + sizes
        for i in range(1, len(sizes)):
            self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))
            self.bns.append(nn.BatchNorm1d(sizes[i]))

    def forward(self, x1, x2, feat1, feat2):
...@@ -225,8 +238,9 @@ class FeaturePropagation(nn.Module):
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_feat = torch.sum(
                index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2
            )

        if feat1 is not None:
            new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
......
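
The interpolation in FeaturePropagation above is the standard three-nearest-neighbor inverse-distance weighting from PointNet++: each upsampled feature is sum_i w_i * f_i with w_i = (1/d_i) / sum_j (1/d_j). A standalone toy version of the same arithmetic (not the module itself):

# Toy inverse-distance interpolation: the closest neighbor dominates.
dists = torch.tensor([[[0.1, 0.2, 0.4]]])  # B x N x 3 neighbor distances
neigh_feat = torch.tensor([[[1.0, 2.0, 4.0]]])  # matching neighbor features
dist_recip = 1.0 / (dists + 1e-8)
weight = dist_recip / dist_recip.sum(dim=2, keepdim=True)
interp = (weight * neigh_feat).sum(dim=2)  # ~1.71, pulled toward feature 1.0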
import numpy as np
import torch
from helper import TransitionDown, TransitionUp, index_points, square_distance
from torch import nn

"""
Part of the code is adapted from
https://github.com/qq456cvb/Point-Transformers
"""
class PointTransformerBlock(nn.Module):
...@@ -21,12 +19,12 @@ class PointTransformerBlock(nn.Module):
        self.fc_delta = nn.Sequential(
            nn.Linear(3, transformer_dim),
            nn.ReLU(),
            nn.Linear(transformer_dim, transformer_dim),
        )
        self.fc_gamma = nn.Sequential(
            nn.Linear(transformer_dim, transformer_dim),
            nn.ReLU(),
            nn.Linear(transformer_dim, transformer_dim),
        )
        self.w_qs = nn.Linear(transformer_dim, transformer_dim, bias=False)
        self.w_ks = nn.Linear(transformer_dim, transformer_dim, bias=False)
...@@ -35,43 +33,71 @@ class PointTransformerBlock(nn.Module):
    def forward(self, x, pos):
        dists = square_distance(pos, pos)
        knn_idx = dists.argsort()[:, :, : self.n_neighbors]  # b x n x k
        knn_pos = index_points(pos, knn_idx)

        h = self.fc1(x)
        q, k, v = (
            self.w_qs(h),
            index_points(self.w_ks(h), knn_idx),
            index_points(self.w_vs(h), knn_idx),
        )

        pos_enc = self.fc_delta(pos[:, :, None] - knn_pos)  # b x n x k x f
        attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
        attn = torch.softmax(
            attn / np.sqrt(k.size(-1)), dim=-2
        )  # b x n x k x f

        res = torch.einsum("bmnf,bmnf->bmf", attn, v + pos_enc)
        res = self.fc2(res) + x
        return res, attn
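
This is the vector self-attention of Point Transformer: fc_gamma produces a separate attention weight per channel from q - k + pos_enc (softmax over the k neighbors, dim=-2), and the positional encoding is added to the values as well. A hedged smoke test; the widths assume fc1/fc2 map between the input size and transformer_dim, which matches how the block is constructed elsewhere in this file:

# Smoke test with assumed sizes: 32 input channels, 4 neighbors,
# transformer_dim=64. res keeps the input width; attn is per-neighbor,
# per-channel.
block = PointTransformerBlock(32, 4, 64)
x = torch.randn(2, 16, 32)  # b x n x f features
pos = torch.randn(2, 16, 3)  # b x n x 3 coordinates
res, attn = block(x, pos)  # res: 2 x 16 x 32, attn: 2 x 16 x 4 x 64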
class PointTransformer(nn.Module):
    def __init__(
        self,
        n_points,
        batch_size,
        feature_dim=3,
        n_blocks=4,
        downsampling_rate=4,
        hidden_dim=32,
        transformer_dim=None,
        n_neighbors=16,
    ):
        super(PointTransformer, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.ptb = PointTransformerBlock(
            hidden_dim, n_neighbors, transformer_dim
        )
        self.transition_downs = nn.ModuleList()
        self.transformers = nn.ModuleList()
        for i in range(n_blocks):
            block_hidden_dim = hidden_dim * 2 ** (i + 1)
            block_n_points = n_points // (downsampling_rate ** (i + 1))
            self.transition_downs.append(
                TransitionDown(
                    block_n_points,
                    batch_size,
                    [
                        block_hidden_dim // 2 + 3,
                        block_hidden_dim,
                        block_hidden_dim,
                    ],
                    n_neighbors=n_neighbors,
                )
            )
            self.transformers.append(
                PointTransformerBlock(
                    block_hidden_dim, n_neighbors, transformer_dim
                )
            )
    def forward(self, x):
        if x.shape[-1] > 3:
...@@ -93,16 +119,35 @@ class PointTransformer(nn.Module):
class PointTransformerCLS(nn.Module):
    def __init__(
        self,
        out_classes,
        batch_size,
        n_points=1024,
        feature_dim=3,
        n_blocks=4,
        downsampling_rate=4,
        hidden_dim=32,
        transformer_dim=None,
        n_neighbors=16,
    ):
        super(PointTransformerCLS, self).__init__()
        self.backbone = PointTransformer(
            n_points,
            batch_size,
            feature_dim,
            n_blocks,
            downsampling_rate,
            hidden_dim,
            transformer_dim,
            n_neighbors,
        )
        self.out = self.fc2 = nn.Sequential(
            nn.Linear(hidden_dim * 2 ** (n_blocks), 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, out_classes),
        )

    def forward(self, x):
...@@ -112,37 +157,63 @@ class PointTransformerCLS(nn.Module):
class PointTransformerSeg(nn.Module):
    def __init__(
        self,
        out_classes,
        batch_size,
        n_points=2048,
        feature_dim=3,
        n_blocks=4,
        downsampling_rate=4,
        hidden_dim=32,
        transformer_dim=None,
        n_neighbors=16,
    ):
        super().__init__()
        self.backbone = PointTransformer(
            n_points,
            batch_size,
            feature_dim,
            n_blocks,
            downsampling_rate,
            hidden_dim,
            transformer_dim,
            n_neighbors,
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * 2**n_blocks, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 32 * 2**n_blocks),
        )
        self.ptb = PointTransformerBlock(
            32 * 2**n_blocks, n_neighbors, transformer_dim
        )
        self.n_blocks = n_blocks
        self.transition_ups = nn.ModuleList()
        self.transformers = nn.ModuleList()
        for i in reversed(range(n_blocks)):
            block_hidden_dim = 32 * 2**i
            self.transition_ups.append(
                TransitionUp(
                    block_hidden_dim * 2, block_hidden_dim, block_hidden_dim
                )
            )
            self.transformers.append(
                PointTransformerBlock(
                    block_hidden_dim, n_neighbors, transformer_dim
                )
            )

        self.out = nn.Sequential(
            nn.Linear(32 + 16, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_classes),
        )
    def forward(self, x, cat_vec=None):
...@@ -152,8 +223,9 @@ class PointTransformerSeg(nn.Module):
        for i in range(self.n_blocks):
            h = self.transition_ups[i](
                pos, h, hidden_state[-i - 2][0], hidden_state[-i - 2][1]
            )
            pos = hidden_state[-i - 2][0]
            h, _ = self.transformers[i](h, pos)
        return self.out(torch.cat([h, cat_vec], dim=-1))
......
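
A hypothetical smoke test for the classification model (CPU, sizes assumed; batch_size must match the actual batch because of the reshapes inside TransitionDown):

# Two clouds of 1024 xyz points -> 40 class logits (assumed setup).
model = PointTransformerCLS(out_classes=40, batch_size=2)
pts = torch.randn(2, 1024, 3)
logits = model(pts)  # expected shape: 2 x 40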
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
"""
import numpy as np
def normalize_data(batch_data):
    """Normalize the batch data, use coordinates of the block centered at origin,
    Input:
        BxNxC array
    Output:
...@@ -16,14 +17,14 @@ def normalize_data(batch_data):
        pc = batch_data[b]
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
        pc = pc / m
        normal_data[b] = pc
    return normal_data
def shuffle_data(data, labels):
    """Shuffle data and labels.
    Input:
        data: B,N,... numpy array
        label: B,... numpy array
...@@ -34,8 +35,9 @@ def shuffle_data(data, labels):
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx
def shuffle_points(batch_data):
    """Shuffle orders of points in each point cloud -- changes FPS behavior.
    Use the same shuffling idx for the entire batch.
    Input:
        BxNxC array
...@@ -44,10 +46,11 @@ def shuffle_points(batch_data):
    """
    idx = np.arange(batch_data.shape[1])
    np.random.shuffle(idx)
    return batch_data[:, idx, :]
def rotate_point_cloud(batch_data):
    """Randomly rotate the point clouds to augment the dataset;
    rotation is per shape, about the up direction.
    Input:
        BxNx3 array, original batch of point clouds
...@@ -59,15 +62,18 @@ def rotate_point_cloud(batch_data):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data
def rotate_point_cloud_z(batch_data):
    """Randomly rotate the point clouds to augment the dataset;
    rotation is per shape, about the up direction.
    Input:
        BxNx3 array, original batch of point clouds
...@@ -79,35 +85,45 @@ def rotate_point_cloud_z(batch_data):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]
        )
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data
def rotate_point_cloud_with_normal(batch_xyz_normal):
    """Randomly rotate XYZ, normal point cloud.
    Input:
        batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
    Output:
        B,N,6, rotated XYZ, normal point cloud
    """
    for k in range(batch_xyz_normal.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_xyz_normal[k, :, 0:3]
        shape_normal = batch_xyz_normal[k, :, 3:6]
        batch_xyz_normal[k, :, 0:3] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
        batch_xyz_normal[k, :, 3:6] = np.dot(
            shape_normal.reshape((-1, 3)), rotation_matrix
        )
    return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(
    batch_data, angle_sigma=0.06, angle_clip=0.18
):
    """Randomly perturb the point clouds by small rotations
    Input:
        BxNx6 array, original batch of point clouds and point normals
    Return:
...@@ -115,26 +131,40 @@ def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, an
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angles = np.clip(
            angle_sigma * np.random.randn(3), -angle_clip, angle_clip
        )
        Rx = np.array(
            [
                [1, 0, 0],
                [0, np.cos(angles[0]), -np.sin(angles[0])],
                [0, np.sin(angles[0]), np.cos(angles[0])],
            ]
        )
        Ry = np.array(
            [
                [np.cos(angles[1]), 0, np.sin(angles[1])],
                [0, 1, 0],
                [-np.sin(angles[1]), 0, np.cos(angles[1])],
            ]
        )
        Rz = np.array(
            [
                [np.cos(angles[2]), -np.sin(angles[2]), 0],
                [np.sin(angles[2]), np.cos(angles[2]), 0],
                [0, 0, 1],
            ]
        )
        R = np.dot(Rz, np.dot(Ry, Rx))
        shape_pc = batch_data[k, :, 0:3]
        shape_normal = batch_data[k, :, 3:6]
        rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
        rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
    return rotated_data
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """Rotate the point cloud along up direction with certain angle.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
...@@ -142,18 +172,21 @@ def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        # rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_data[k, :, 0:3]
        rotated_data[k, :, 0:3] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
    """Rotate the point cloud along up direction with certain angle.
    Input:
        BxNx6 array, original batch of point clouds with normal
        scalar, angle of rotation
...@@ -162,22 +195,27 @@ def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        # rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_data[k, :, 0:3]
        shape_normal = batch_data[k, :, 3:6]
        rotated_data[k, :, 0:3] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
        rotated_data[k, :, 3:6] = np.dot(
            shape_normal.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data
def rotate_perturbation_point_cloud(
    batch_data, angle_sigma=0.06, angle_clip=0.18
):
    """Randomly perturb the point clouds by small rotations
    Input:
        BxNx3 array, original batch of point clouds
    Return:
...@@ -185,51 +223,66 @@ def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.1
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angles = np.clip(
            angle_sigma * np.random.randn(3), -angle_clip, angle_clip
        )
        Rx = np.array(
            [
                [1, 0, 0],
                [0, np.cos(angles[0]), -np.sin(angles[0])],
                [0, np.sin(angles[0]), np.cos(angles[0])],
            ]
        )
        Ry = np.array(
            [
                [np.cos(angles[1]), 0, np.sin(angles[1])],
                [0, 1, 0],
                [-np.sin(angles[1]), 0, np.cos(angles[1])],
            ]
        )
        Rz = np.array(
            [
                [np.cos(angles[2]), -np.sin(angles[2]), 0],
                [np.sin(angles[2]), np.cos(angles[2]), 0],
                [0, 0, 1],
            ]
        )
        R = np.dot(Rz, np.dot(Ry, Rx))
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
    return rotated_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """Randomly jitter points. Jittering is per point.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, jittered batch of point clouds
    """
    B, N, C = batch_data.shape
    assert clip > 0
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data
def shift_point_cloud(batch_data, shift_range=0.1):
    """Randomly shift point cloud. Shift is per point cloud.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, shifted batch of point clouds
    """
    B, N, C = batch_data.shape
    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
    for batch_index in range(B):
        batch_data[batch_index, :, :] += shifts[batch_index, :]
    return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """Randomly scale the point cloud. Scale is per point cloud.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
...@@ -238,15 +291,22 @@ def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    B, N, C = batch_data.shape
    scales = np.random.uniform(scale_low, scale_high, B)
    for batch_index in range(B):
        batch_data[batch_index, :, :] *= scales[batch_index]
    return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """batch_pc: BxNx3"""
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio  # 0~0.875
        drop_idx = np.where(
            np.random.random((batch_pc.shape[1])) <= dropout_ratio
        )[0]
        if len(drop_idx) > 0:
            # collapse dropped points onto the first point
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]
    return batch_pc
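
Taken together, these augmentations are applied as a chain; the training script below uses exactly this order on each batch. A minimal sketch with a made-up batch:

# Assumed batch: 16 clouds of 1024 xyz points.
batch = np.random.rand(16, 1024, 3).astype(np.float32)
batch = random_point_dropout(batch)
batch[:, :, 0:3] = random_scale_point_cloud(batch[:, :, 0:3])
batch[:, :, 0:3] = jitter_point_cloud(batch[:, :, 0:3])
batch[:, :, 0:3] = shift_point_cloud(batch[:, :, 0:3])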
import argparse
import os
import time
from functools import partial

import provider
import torch
import torch.nn as nn
import tqdm
from ModelNetDataLoader import ModelNetDataLoader
from point_transformer import PointTransformerCLS
from torch.utils.data import DataLoader

from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False

parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=200)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--opt", type=str, default="adam")
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size

data_filename = "modelnet40_normal_resampled.zip"
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
    get_download_dir(), "modelnet40_normal_resampled"
)
if not os.path.exists(local_path):
    download(
        "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
        download_path,
        verify_ssl=False,
    )
    from zipfile import ZipFile

    with ZipFile(download_path) as z:
        z.extractall(path=get_download_dir())
...@@ -43,7 +50,8 @@ CustomDataLoader = partial(
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev):
...@@ -60,8 +68,7 @@ def train(net, opt, scheduler, train_loader, dev):
        for data, label in tq:
            data = data.data.numpy()
            data = provider.random_point_dropout(data)
            data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
            data = torch.tensor(data)
...@@ -84,11 +91,19 @@ def train(net, opt, scheduler, train_loader, dev):
            total_loss += loss
            total_correct += correct

            tq.set_postfix(
                {
                    "AvgLoss": "%.5f" % (total_loss / num_batches),
                    "AvgAcc": "%.5f" % (total_correct / count),
                }
            )
    print(
        "[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
            total_loss / num_batches,
            total_correct / count,
            time.time() - start_time,
        )
    )
    scheduler.step()
...@@ -111,10 +126,12 @@ def evaluate(net, test_loader, dev):
            total_correct += correct
            count += num_examples

            tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})
    print(
        "[Test] AvgAcc: {:.5}, Time: {:.5}s".format(
            total_correct / count, time.time() - start_time
        )
    )
    return total_correct / count
...@@ -125,13 +142,15 @@ net = net.to(dev)
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

if args.opt == "sgd":
    # The optimizer strategy described in the paper:
    opt = torch.optim.SGD(
        net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        opt, milestones=[120, 160], gamma=0.1
    )
elif args.opt == "adam":
    # The optimizer strategy proposed by
    # https://github.com/qq456cvb/Point-Transformers:
    opt = torch.optim.Adam(
...@@ -139,16 +158,26 @@ elif args.opt == 'adam':
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=1e-4,
    )
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.3)
train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=True,
)
best_test_acc = 0

...@@ -161,6 +190,5 @@ for epoch in range(args.num_epochs):
        best_test_acc = test_acc
        if args.save_model_path:
            torch.save(net.state_dict(), args.save_model_path)
    print("Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc))
    print()
import os
import warnings

import numpy as np
from torch.utils.data import Dataset

warnings.filterwarnings("ignore")
def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
...@@ -11,6 +14,7 @@ def pc_normalize(pc):
    pc = pc / m
    return pc
def farthest_point_sample(point, npoint):
    """
    Farthest point sampler works as follows:
...@@ -25,7 +29,7 @@ def farthest_point_sample(point, npoint):
        centroids: sampled pointcloud index, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:, :3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
...@@ -39,9 +43,17 @@ def farthest_point_sample(point, npoint):
    point = point[centroids.astype(np.int32)]
    return point
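
A hypothetical call (array sizes are illustrative): downsample a cloud of 2048 points with xyz plus normals to 512 representative points.

cloud = np.random.rand(2048, 6)  # N x D, xyz in the first 3 columns
sampled = farthest_point_sample(cloud, 512)  # 512 x 6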
class ModelNetDataLoader(Dataset):
    def __init__(
        self,
        root,
        npoint=1024,
        split="train",
        fps=False,
        normal_channel=True,
        cache_size=15000,
    ):
""" """
Input: Input:
root: the root path to the local data files root: the root path to the local data files
...@@ -54,22 +66,34 @@ class ModelNetDataLoader(Dataset): ...@@ -54,22 +66,34 @@ class ModelNetDataLoader(Dataset):
        self.root = root
        self.npoints = npoint
        self.fps = fps
        self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")

        self.cat = [line.rstrip() for line in open(self.catfile)]
        self.classes = dict(zip(self.cat, range(len(self.cat))))
        self.normal_channel = normal_channel

        shape_ids = {}
        shape_ids["train"] = [
            line.rstrip()
            for line in open(os.path.join(self.root, "modelnet40_train.txt"))
        ]
        shape_ids["test"] = [
            line.rstrip()
            for line in open(os.path.join(self.root, "modelnet40_test.txt"))
        ]

        assert split == "train" or split == "test"
        shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
        # list of (shape_name, shape_txt_file_path) tuples
        self.datapath = [
            (
                shape_names[i],
                os.path.join(self.root, shape_names[i], shape_ids[split][i])
                + ".txt",
            )
            for i in range(len(shape_ids[split]))
        ]
        print("The size of %s data is %d" % (split, len(self.datapath)))

        self.cache_size = cache_size
        self.cache = {}
...@@ -84,11 +108,11 @@ class ModelNetDataLoader(Dataset):
        fn = self.datapath[index]
        cls = self.classes[self.datapath[index][0]]
        cls = np.array([cls]).astype(np.int32)
        point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
        if self.fps:
            point_set = farthest_point_sample(point_set, self.npoints)
        else:
            point_set = point_set[0 : self.npoints, :]

        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
......
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class PointNetCls(nn.Module):
    def __init__(
        self,
        output_classes,
        input_dims=3,
        conv1_dim=64,
        dropout_prob=0.5,
        use_transform=True,
    ):
        super(PointNetCls, self).__init__()
        self.input_dims = input_dims
        self.conv1 = nn.ModuleList()
...@@ -85,6 +92,7 @@ class PointNetCls(nn.Module):
        out = self.mlp_out(h)
        return out
class TransformNet(nn.Module):
    def __init__(self, input_dims=3, conv1_dim=64):
        super(TransformNet, self).__init__()
...@@ -127,8 +135,14 @@ class TransformNet(nn.Module):
        out = self.mlp_out(h)

        iden = Variable(
            torch.from_numpy(
                np.eye(self.input_dims).flatten().astype(np.float32)
            )
        )
        iden = iden.view(1, self.input_dims * self.input_dims).repeat(
            batch_size, 1
        )
        if out.is_cuda:
            iden = iden.cuda()
        out = out + iden
......
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GraphConv


class GCN(nn.Module):
    def __init__(self, g, in_feats, n_hidden, n_classes, activation, dropout):
        super(GCN, self).__init__()
......