"docs/vscode:/vscode.git/clone" did not exist on "4f48476dd6336f35489378cf38c0852a48f92289"
Unverified commit be8763fa, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Black auto fix. (#4679)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.geometry import farthest_point_sampler

"""
Part of the code is adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
def square_distance(src, dst):
    """
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src**2, -1).view(B, N, 1)
    dist += torch.sum(dst**2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = (
        torch.arange(B, dtype=torch.long)
        .to(device)
        .view(view_shape)
        .repeat(repeat_shape)
    )
    new_points = points[batch_indices, idx, :]
    return new_points
class KNearNeighbors(nn.Module):
    """
    Find the k nearest neighbors
    """

    def __init__(self, n_neighbor):
        super(KNearNeighbors, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
        """
        center_pos = index_points(pos, centroids)
        sqrdists = square_distance(center_pos, pos)
        group_idx = sqrdists.argsort(dim=-1)[:, :, : self.n_neighbor]
        return group_idx


class KNNGraphBuilder(nn.Module):
    """
    Build NN graph
    """

    def __init__(self, n_neighbor):
        super(KNNGraphBuilder, self).__init__()
@@ -76,46 +81,52 @@ class KNNGraphBuilder(nn.Module):
            center = torch.zeros((N)).to(dev)
            center[centroids[i]] = 1
            src = group_idx[i].contiguous().view(-1)
            dst = (
                centroids[i]
                .view(-1, 1)
                .repeat(
                    1, min(self.n_neighbor, src.shape[0] // centroids.shape[1])
                )
                .view(-1)
            )

            unified = torch.cat([src, dst])
            uniq, inv_idx = torch.unique(unified, return_inverse=True)
            src_idx = inv_idx[: src.shape[0]]
            dst_idx = inv_idx[src.shape[0] :]

            g = dgl.graph((src_idx, dst_idx))
            g.ndata["pos"] = pos[i][uniq]
            g.ndata["center"] = center[uniq]
            if feat is not None:
                g.ndata["feat"] = feat[i][uniq]
            glist.append(g)
        bg = dgl.batch(glist)
        return bg
class RelativePositionMessage(nn.Module):
    """
    Compute the input feature from neighbors
    """

    def __init__(self, n_neighbor):
        super(RelativePositionMessage, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        pos = edges.src["pos"] - edges.dst["pos"]
        if "feat" in edges.src:
            res = torch.cat([pos, edges.src["feat"]], 1)
        else:
            res = pos
        return {"agg_feat": res}
class KNNConv(nn.Module):
    """
    Feature aggregation
    """

    def __init__(self, sizes, batch_size):
        super(KNNConv, self).__init__()
@@ -123,13 +134,16 @@ class KNNConv(nn.Module):
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for i in range(1, len(sizes)):
            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))

    def forward(self, nodes):
        shape = nodes.mailbox["agg_feat"].shape
        h = (
            nodes.mailbox["agg_feat"]
            .view(self.batch_size, -1, shape[1], shape[2])
            .permute(0, 3, 2, 1)
        )
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
@@ -137,12 +151,12 @@ class KNNConv(nn.Module):
        h = torch.max(h, 2)[0]
        feat_dim = h.shape[1]
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {"new_feat": h}
    def group_all(self, pos, feat):
        """
        Feature aggregation and pooling for the non-sampling layer
        """
        if feat is not None:
            h = torch.cat([pos, feat], 2)
        else:
@@ -177,12 +191,11 @@ class TransitionDown(nn.Module):
        g = self.frnn_graph(pos, centroids, feat)
        g.update_all(self.message, self.conv)

        mask = g.ndata["center"] == 1
        pos_dim = g.ndata["pos"].shape[-1]
        feat_dim = g.ndata["new_feat"].shape[-1]
        pos_res = g.ndata["pos"][mask].view(self.batch_size, -1, pos_dim)
        feat_res = g.ndata["new_feat"][mask].view(self.batch_size, -1, feat_dim)
        return pos_res, feat_res
@@ -198,7 +211,7 @@ class FeaturePropagation(nn.Module):
        sizes = [input_dims] + sizes
        for i in range(1, len(sizes)):
            self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))
            self.bns.append(nn.BatchNorm1d(sizes[i]))

    def forward(self, x1, x2, feat1, feat2):
@@ -225,8 +238,9 @@ class FeaturePropagation(nn.Module):
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_feat = torch.sum(
                index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2
            )

        if feat1 is not None:
            new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
...
import numpy as np
import torch
from helper import TransitionDown, TransitionUp, index_points, square_distance
from torch import nn

"""
Part of the code is adapted from
https://github.com/qq456cvb/Point-Transformers
"""
class PointTransformerBlock(nn.Module):
@@ -21,12 +19,12 @@ class PointTransformerBlock(nn.Module):
        self.fc_delta = nn.Sequential(
            nn.Linear(3, transformer_dim),
            nn.ReLU(),
            nn.Linear(transformer_dim, transformer_dim),
        )
        self.fc_gamma = nn.Sequential(
            nn.Linear(transformer_dim, transformer_dim),
            nn.ReLU(),
            nn.Linear(transformer_dim, transformer_dim),
        )
        self.w_qs = nn.Linear(transformer_dim, transformer_dim, bias=False)
        self.w_ks = nn.Linear(transformer_dim, transformer_dim, bias=False)
@@ -35,43 +33,71 @@ class PointTransformerBlock(nn.Module):
    def forward(self, x, pos):
        dists = square_distance(pos, pos)
        knn_idx = dists.argsort()[:, :, : self.n_neighbors]  # b x n x k
        knn_pos = index_points(pos, knn_idx)

        h = self.fc1(x)
        q, k, v = (
            self.w_qs(h),
            index_points(self.w_ks(h), knn_idx),
            index_points(self.w_vs(h), knn_idx),
        )

        pos_enc = self.fc_delta(pos[:, :, None] - knn_pos)  # b x n x k x f
        attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
        attn = torch.softmax(
            attn / np.sqrt(k.size(-1)), dim=-2
        )  # b x n x k x f

        res = torch.einsum("bmnf,bmnf->bmf", attn, v + pos_enc)
        res = self.fc2(res) + x
        return res, attn
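A toy forward pass through the block; the argument order (input_dim, n_neighbors, transformer_dim) is assumed from the constructor calls later in this file, and all shapes are illustrative:

block = PointTransformerBlock(32, 16, 32)
x = torch.randn(2, 256, 32)  # (B, N, C) per-point features
pos = torch.randn(2, 256, 3)  # (B, N, 3) coordinates
out, attn = block(x, pos)  # out: (B, N, C); attn: (B, N, k, f)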
class PointTransformer(nn.Module):
    def __init__(
        self,
        n_points,
        batch_size,
        feature_dim=3,
        n_blocks=4,
        downsampling_rate=4,
        hidden_dim=32,
        transformer_dim=None,
        n_neighbors=16,
    ):
        super(PointTransformer, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.ptb = PointTransformerBlock(
            hidden_dim, n_neighbors, transformer_dim
        )
        self.transition_downs = nn.ModuleList()
        self.transformers = nn.ModuleList()
        for i in range(n_blocks):
            block_hidden_dim = hidden_dim * 2 ** (i + 1)
            block_n_points = n_points // (downsampling_rate ** (i + 1))
            self.transition_downs.append(
                TransitionDown(
                    block_n_points,
                    batch_size,
                    [
                        block_hidden_dim // 2 + 3,
                        block_hidden_dim,
                        block_hidden_dim,
                    ],
                    n_neighbors=n_neighbors,
                )
            )
            self.transformers.append(
                PointTransformerBlock(
                    block_hidden_dim, n_neighbors, transformer_dim
                )
            )
    def forward(self, x):
        if x.shape[-1] > 3:
@@ -93,16 +119,35 @@ class PointTransformer(nn.Module):


class PointTransformerCLS(nn.Module):
    def __init__(
        self,
        out_classes,
        batch_size,
        n_points=1024,
        feature_dim=3,
        n_blocks=4,
        downsampling_rate=4,
        hidden_dim=32,
        transformer_dim=None,
        n_neighbors=16,
    ):
        super(PointTransformerCLS, self).__init__()
        self.backbone = PointTransformer(
            n_points,
            batch_size,
            feature_dim,
            n_blocks,
            downsampling_rate,
            hidden_dim,
            transformer_dim,
            n_neighbors,
        )
        self.out = self.fc2 = nn.Sequential(
            nn.Linear(hidden_dim * 2 ** (n_blocks), 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, out_classes),
        )
    def forward(self, x):
@@ -112,37 +157,63 @@ class PointTransformerCLS(nn.Module):


class PointTransformerSeg(nn.Module):
    def __init__(
        self,
        out_classes,
        batch_size,
        n_points=2048,
        feature_dim=3,
        n_blocks=4,
        downsampling_rate=4,
        hidden_dim=32,
        transformer_dim=None,
        n_neighbors=16,
    ):
        super().__init__()
        self.backbone = PointTransformer(
            n_points,
            batch_size,
            feature_dim,
            n_blocks,
            downsampling_rate,
            hidden_dim,
            transformer_dim,
            n_neighbors,
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * 2**n_blocks, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 32 * 2**n_blocks),
        )
        self.ptb = PointTransformerBlock(
            32 * 2**n_blocks, n_neighbors, transformer_dim
        )
        self.n_blocks = n_blocks
        self.transition_ups = nn.ModuleList()
        self.transformers = nn.ModuleList()
        for i in reversed(range(n_blocks)):
            block_hidden_dim = 32 * 2**i
            self.transition_ups.append(
                TransitionUp(
                    block_hidden_dim * 2, block_hidden_dim, block_hidden_dim
                )
            )
            self.transformers.append(
                PointTransformerBlock(
                    block_hidden_dim, n_neighbors, transformer_dim
                )
            )

        self.out = nn.Sequential(
            nn.Linear(32 + 16, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_classes),
        )
    def forward(self, x, cat_vec=None):
@@ -152,8 +223,9 @@ class PointTransformerSeg(nn.Module):
        for i in range(self.n_blocks):
            h = self.transition_ups[i](
                pos, h, hidden_state[-i - 2][0], hidden_state[-i - 2][1]
            )
            pos = hidden_state[-i - 2][0]
            h, _ = self.transformers[i](h, pos)
        return self.out(torch.cat([h, cat_vec], dim=-1))
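An I/O sketch for the segmentation model, assuming the ShapeNet setup used by the training script below (2048 points, a 16-category one-hot repeated per point; out_classes=50 is an assumption, not stated in this diff):

net = PointTransformerSeg(out_classes=50, batch_size=2)
data = torch.randn(2, 2048, 3)  # (B, N, 3) coordinates
cat_vec = torch.zeros(2, 2048, 16)  # per-point one-hot object category
cat_vec[:, :, 0] = 1
logits = net(data, cat_vec)  # (B, N, out_classes)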
...
''' """
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
''' """
import numpy as np import numpy as np
def normalize_data(batch_data): def normalize_data(batch_data):
""" Normalize the batch data, use coordinates of the block centered at origin, """Normalize the batch data, use coordinates of the block centered at origin,
Input: Input:
BxNxC array BxNxC array
Output: Output:
BxNxC array BxNxC array
""" """
B, N, C = batch_data.shape B, N, C = batch_data.shape
normal_data = np.zeros((B, N, C)) normal_data = np.zeros((B, N, C))
...@@ -16,237 +17,296 @@ def normalize_data(batch_data): ...@@ -16,237 +17,296 @@ def normalize_data(batch_data):
pc = batch_data[b] pc = batch_data[b]
centroid = np.mean(pc, axis=0) centroid = np.mean(pc, axis=0)
pc = pc - centroid pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m pc = pc / m
normal_data[b] = pc normal_data[b] = pc
return normal_data return normal_data
def shuffle_data(data, labels):
    """Shuffle data and labels.
    Input:
        data: B,N,... numpy array
        label: B,... numpy array
    Return:
        shuffled data, label and shuffle indices
    """
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx


def shuffle_points(batch_data):
    """Shuffle orders of points in each point cloud -- changes FPS behavior.
    Use the same shuffling idx for the entire batch.
    Input:
        BxNxC array
    Output:
        BxNxC array
    """
    idx = np.arange(batch_data.shape[1])
    np.random.shuffle(idx)
    return batch_data[:, idx, :]
def rotate_point_cloud(batch_data):
    """Randomly rotate the point clouds to augment the dataset
    rotation is per shape based along up direction
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data


def rotate_point_cloud_z(batch_data):
    """Randomly rotate the point clouds to augment the dataset
    rotation is per shape based along up direction
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]
        )
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data
def rotate_point_cloud_with_normal(batch_xyz_normal):
    """Randomly rotate XYZ, normal point cloud.
    Input:
        batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 are normals
    Output:
        B,N,6, rotated XYZ, normal point cloud
    """
    for k in range(batch_xyz_normal.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_xyz_normal[k, :, 0:3]
        shape_normal = batch_xyz_normal[k, :, 3:6]
        batch_xyz_normal[k, :, 0:3] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
        batch_xyz_normal[k, :, 3:6] = np.dot(
            shape_normal.reshape((-1, 3)), rotation_matrix
        )
    return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(
    batch_data, angle_sigma=0.06, angle_clip=0.18
):
    """Randomly perturb the point clouds by small rotations
    Input:
        BxNx6 array, original batch of point clouds and point normals
    Return:
        BxNx6 array, rotated batch of point clouds with normals
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angles = np.clip(
            angle_sigma * np.random.randn(3), -angle_clip, angle_clip
        )
        Rx = np.array(
            [
                [1, 0, 0],
                [0, np.cos(angles[0]), -np.sin(angles[0])],
                [0, np.sin(angles[0]), np.cos(angles[0])],
            ]
        )
        Ry = np.array(
            [
                [np.cos(angles[1]), 0, np.sin(angles[1])],
                [0, 1, 0],
                [-np.sin(angles[1]), 0, np.cos(angles[1])],
            ]
        )
        Rz = np.array(
            [
                [np.cos(angles[2]), -np.sin(angles[2]), 0],
                [np.sin(angles[2]), np.cos(angles[2]), 0],
                [0, 0, 1],
            ]
        )
        R = np.dot(Rz, np.dot(Ry, Rx))
        shape_pc = batch_data[k, :, 0:3]
        shape_normal = batch_data[k, :, 3:6]
        rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
        rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
    return rotated_data
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """Rotate the point cloud along up direction with certain angle.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        # rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_data[k, :, 0:3]
        rotated_data[k, :, 0:3] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
    """Rotate the point cloud along up direction with certain angle.
    Input:
        BxNx6 array, original batch of point clouds with normal
        scalar, angle of rotation
    Return:
        BxNx6 array, rotated batch of point clouds with normal
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        # rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_data[k, :, 0:3]
        shape_normal = batch_data[k, :, 3:6]
        rotated_data[k, :, 0:3] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
        rotated_data[k, :, 3:6] = np.dot(
            shape_normal.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data
def rotate_perturbation_point_cloud(
    batch_data, angle_sigma=0.06, angle_clip=0.18
):
    """Randomly perturb the point clouds by small rotations
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angles = np.clip(
            angle_sigma * np.random.randn(3), -angle_clip, angle_clip
        )
        Rx = np.array(
            [
                [1, 0, 0],
                [0, np.cos(angles[0]), -np.sin(angles[0])],
                [0, np.sin(angles[0]), np.cos(angles[0])],
            ]
        )
        Ry = np.array(
            [
                [np.cos(angles[1]), 0, np.sin(angles[1])],
                [0, 1, 0],
                [-np.sin(angles[1]), 0, np.cos(angles[1])],
            ]
        )
        Rz = np.array(
            [
                [np.cos(angles[2]), -np.sin(angles[2]), 0],
                [np.sin(angles[2]), np.cos(angles[2]), 0],
                [0, 0, 1],
            ]
        )
        R = np.dot(Rz, np.dot(Ry, Rx))
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
    return rotated_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """Randomly jitter points. jittering is per point.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, jittered batch of point clouds
    """
    B, N, C = batch_data.shape
    assert clip > 0
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data


def shift_point_cloud(batch_data, shift_range=0.1):
    """Randomly shift point cloud. Shift is per point cloud.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, shifted batch of point clouds
    """
    B, N, C = batch_data.shape
    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
    for batch_index in range(B):
        batch_data[batch_index, :, :] += shifts[batch_index, :]
    return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """Randomly scale the point cloud. Scale is per point cloud.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, scaled batch of point clouds
    """
    B, N, C = batch_data.shape
    scales = np.random.uniform(scale_low, scale_high, B)
    for batch_index in range(B):
        batch_data[batch_index, :, :] *= scales[batch_index]
    return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """batch_pc: BxNx3"""
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio  # 0~0.875
        drop_idx = np.where(
            np.random.random((batch_pc.shape[1])) <= dropout_ratio
        )[0]
        if len(drop_idx) > 0:
            dropout_ratio = (
                np.random.random() * max_dropout_ratio
            )  # 0~0.875, not needed
            batch_pc[b, drop_idx, :] = batch_pc[
                b, 0, :
            ]  # set to the first point
    return batch_pc
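For reference, this is the augmentation chain the classification training loop below applies to each batch (array shape assumed for illustration):

batch = np.random.rand(16, 1024, 6).astype(np.float32)  # xyz + normals
batch = random_point_dropout(batch)
batch[:, :, 0:3] = random_scale_point_cloud(batch[:, :, 0:3])
batch[:, :, 0:3] = jitter_point_cloud(batch[:, :, 0:3])
batch[:, :, 0:3] = shift_point_cloud(batch[:, :, 0:3])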
import argparse
import os
import time
from functools import partial

import provider
import torch
import torch.nn as nn
import tqdm
from ModelNetDataLoader import ModelNetDataLoader
from point_transformer import PointTransformerCLS
from torch.utils.data import DataLoader

from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False

parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=200)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--opt", type=str, default="adam")
args = parser.parse_args()
num_workers = args.num_workers
batch_size = args.batch_size

data_filename = "modelnet40_normal_resampled.zip"
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
    get_download_dir(), "modelnet40_normal_resampled"
)

if not os.path.exists(local_path):
    download(
        "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
        download_path,
        verify_ssl=False,
    )
    from zipfile import ZipFile

    with ZipFile(download_path) as z:
        z.extractall(path=get_download_dir())
@@ -43,7 +50,8 @@ CustomDataLoader = partial(
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev):
@@ -60,8 +68,7 @@ def train(net, opt, scheduler, train_loader, dev):
        for data, label in tq:
            data = data.data.numpy()
            data = provider.random_point_dropout(data)
            data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
            data = torch.tensor(data)
@@ -84,11 +91,19 @@ def train(net, opt, scheduler, train_loader, dev):
            total_loss += loss
            total_correct += correct

            tq.set_postfix(
                {
                    "AvgLoss": "%.5f" % (total_loss / num_batches),
                    "AvgAcc": "%.5f" % (total_correct / count),
                }
            )
    print(
        "[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
            total_loss / num_batches,
            total_correct / count,
            time.time() - start_time,
        )
    )
    scheduler.step()
@@ -111,10 +126,12 @@ def evaluate(net, test_loader, dev):
            total_correct += correct
            count += num_examples

            tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})
    print(
        "[Test] AvgAcc: {:.5}, Time: {:.5}s".format(
            total_correct / count, time.time() - start_time
        )
    )
    return total_correct / count
@@ -125,13 +142,15 @@ net = net.to(dev)
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

if args.opt == "sgd":
    # The optimizer strategy described in the paper:
    opt = torch.optim.SGD(
        net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        opt, milestones=[120, 160], gamma=0.1
    )
elif args.opt == "adam":
    # The optimizer strategy proposed by
    # https://github.com/qq456cvb/Point-Transformers:
    opt = torch.optim.Adam(
@@ -139,16 +158,26 @@ elif args.opt == 'adam':
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=1e-4,
    )
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.3)
train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=True,
)

best_test_acc = 0
@@ -161,6 +190,5 @@ for epoch in range(args.num_epochs):
        best_test_acc = test_acc
        if args.save_model_path:
            torch.save(net.state_dict(), args.save_model_path)
    print("Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc))
    print()
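A hypothetical invocation of this classification script (the filename is not shown on this page and is assumed):

python train_cls.py --batch-size 16 --num-epochs 200 --opt adam --save-model-path cls.pth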
import argparse
import time
from functools import partial

import numpy as np
import torch
import torch.optim as optim
import tqdm
from point_transformer import PartSegLoss, PointTransformerSeg
from ShapeNet import ShapeNet
from torch.utils.data import DataLoader

import dgl
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=250)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--tensorboard", action="store_true")
parser.add_argument("--opt", type=str, default="adam")
args = parser.parse_args()

num_workers = args.num_workers
@@ -37,7 +37,8 @@ CustomDataLoader = partial(
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev):
@@ -58,8 +59,11 @@ def train(net, opt, scheduler, train_loader, dev):
            opt.zero_grad()
            cat_ind = [category_list.index(c) for c in cat]
            # A one-hot encoding for the object category
            cat_tensor = (
                torch.tensor(eye_mat[cat_ind])
                .to(dev, dtype=torch.float)
                .repeat(1, 2048)
            )
            cat_tensor = cat_tensor.view(num_examples, -1, 16)
            logits = net(data, cat_tensor).permute(0, 2, 1)
            loss = L(logits, label)
@@ -78,14 +82,17 @@ def train(net, opt, scheduler, train_loader, dev):
            AvgLoss = total_loss / num_batches
            AvgAcc = total_correct / count
            tq.set_postfix(
                {"AvgLoss": "%.5f" % AvgLoss, "AvgAcc": "%.5f" % AvgAcc}
            )
    scheduler.step()
    end = time.time()
    print(
        "[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
            total_loss / num_batches, total_correct / count, end - start
        )
    )
    return data, preds, AvgLoss, AvgAcc, end - start
def mIoU(preds, label, cat, cat_miou, seg_classes):
@@ -128,27 +135,39 @@ def evaluate(net, test_loader, dev, per_cat_verbose=False):
            data = data.to(dev, dtype=torch.float)
            label = label.to(dev, dtype=torch.long)
            cat_ind = [category_list.index(c) for c in cat]
            cat_tensor = (
                torch.tensor(eye_mat[cat_ind])
                .to(dev, dtype=torch.float)
                .repeat(1, 2048)
            )
            cat_tensor = cat_tensor.view(num_examples, -1, 16)
            logits = net(data, cat_tensor).permute(0, 2, 1)
            _, preds = logits.max(1)
            cat_miou = mIoU(
                preds.cpu().numpy(),
                label.view(num_examples, -1).cpu().numpy(),
                cat,
                cat_miou,
                shapenet.seg_classes,
            )
            for _, v in cat_miou.items():
                if v[1] > 0:
                    miou += v[0]
                    count += v[1]
                    per_cat_miou += v[0] / v[1]
                    per_cat_count += 1
            tq.set_postfix(
                {
                    "mIoU": "%.5f" % (miou / count),
                    "per Category mIoU": "%.5f"
                    % (per_cat_miou / per_cat_count),
                }
            )
    print(
        "[Test] mIoU: %.5f, per Category mIoU: %.5f"
        % (miou / count, per_cat_miou / per_cat_count)
    )
    if per_cat_verbose:
        print("-" * 60)
        print("Per-Category mIoU:")
@@ -168,13 +187,15 @@ net = net.to(dev)
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

if args.opt == "sgd":
    # The optimizer strategy described in the paper:
    opt = torch.optim.SGD(
        net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        opt, milestones=[120, 160], gamma=0.1
    )
elif args.opt == "adam":
    # The optimizer strategy proposed by
    # https://github.com/qq456cvb/Point-Transformers:
    opt = torch.optim.Adam(
@@ -182,7 +203,7 @@ elif args.opt == 'adam':
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=1e-4,
    )
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.3)
@@ -198,20 +219,63 @@ if args.tensorboard:
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import datasets, transforms

    writer = SummaryWriter()
# Select 50 distinct colors for different parts
color_map = torch.tensor(
    [
        [47, 79, 79],
        [139, 69, 19],
        [112, 128, 144],
        [85, 107, 47],
        [139, 0, 0],
        [128, 128, 0],
        [72, 61, 139],
        [0, 128, 0],
        [188, 143, 143],
        [60, 179, 113],
        [205, 133, 63],
        [0, 139, 139],
        [70, 130, 180],
        [205, 92, 92],
        [154, 205, 50],
        [0, 0, 139],
        [50, 205, 50],
        [250, 250, 250],
        [218, 165, 32],
        [139, 0, 139],
        [10, 10, 10],
        [176, 48, 96],
        [72, 209, 204],
        [153, 50, 204],
        [255, 69, 0],
        [255, 145, 0],
        [0, 0, 205],
        [255, 255, 0],
        [0, 255, 0],
        [233, 150, 122],
        [220, 20, 60],
        [0, 191, 255],
        [160, 32, 240],
        [192, 192, 192],
        [173, 255, 47],
        [218, 112, 214],
        [216, 191, 216],
        [255, 127, 80],
        [255, 0, 255],
        [100, 149, 237],
        [128, 128, 128],
        [221, 160, 221],
        [144, 238, 144],
        [123, 104, 238],
        [255, 160, 122],
        [175, 238, 238],
        [238, 130, 238],
        [127, 255, 212],
        [255, 218, 185],
        [255, 105, 180],
    ]
)
# paint each point according to its pred
@@ -227,28 +291,38 @@ best_test_per_cat_miou = 0
for epoch in range(args.num_epochs):
    print("Epoch #{}: ".format(epoch))
    data, preds, AvgLoss, AvgAcc, training_time = train(
        net, opt, scheduler, train_loader, dev
    )
    if (epoch + 1) % 5 == 0 or epoch == 0:
        test_miou, test_per_cat_miou = evaluate(net, test_loader, dev, True)
        if test_miou > best_test_miou:
            best_test_miou = test_miou
            best_test_per_cat_miou = test_per_cat_miou
            if args.save_model_path:
                torch.save(net.state_dict(), args.save_model_path)
        print(
            "Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)"
            % (
                test_miou,
                best_test_miou,
                test_per_cat_miou,
                best_test_per_cat_miou,
            )
        )
    # Tensorboard
    if args.tensorboard:
        colored = paint(preds)
        writer.add_mesh(
            "data", vertices=data, colors=colored, global_step=epoch
        )
        writer.add_scalar(
            "training time for one epoch", training_time, global_step=epoch
        )
        writer.add_scalar("AvgLoss", AvgLoss, global_step=epoch)
        writer.add_scalar("AvgAcc", AvgAcc, global_step=epoch)
        if (epoch + 1) % 5 == 0:
            writer.add_scalar("test mIoU", test_miou, global_step=epoch)
            writer.add_scalar(
                "best test mIoU", best_test_miou, global_step=epoch
            )
    print()
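A hypothetical invocation of this part-segmentation script (the filename is assumed, not shown on this page):

python train_seg.py --batch-size 16 --num-epochs 250 --opt adam --tensorboard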
import os
import warnings

import numpy as np
from torch.utils.data import Dataset

warnings.filterwarnings("ignore")
def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
@@ -11,6 +14,7 @@ def pc_normalize(pc):
    pc = pc / m
    return pc
def farthest_point_sample(point, npoint):
    """
    Farthest point sampler works as follows:
@@ -25,7 +29,7 @@ def farthest_point_sample(point, npoint):
        centroids: sampled pointcloud index, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:, :3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
@@ -39,12 +43,20 @@ def farthest_point_sample(point, npoint):
    point = point[centroids.astype(np.int32)]
    return point
class ModelNetDataLoader(Dataset):
    def __init__(
        self,
        root,
        npoint=1024,
        split="train",
        fps=False,
        normal_channel=True,
        cache_size=15000,
    ):
        """
        Input:
            root: the root path to the local data files
            npoint: number of points from each cloud
            split: which split of the data, 'train' or 'test'
            fps: whether to sample points with farthest point sampler
@@ -54,22 +66,34 @@ class ModelNetDataLoader(Dataset):
        self.root = root
        self.npoints = npoint
        self.fps = fps
        self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")

        self.cat = [line.rstrip() for line in open(self.catfile)]
        self.classes = dict(zip(self.cat, range(len(self.cat))))
        self.normal_channel = normal_channel

        shape_ids = {}
        shape_ids["train"] = [
            line.rstrip()
            for line in open(os.path.join(self.root, "modelnet40_train.txt"))
        ]
        shape_ids["test"] = [
            line.rstrip()
            for line in open(os.path.join(self.root, "modelnet40_test.txt"))
        ]

        assert split == "train" or split == "test"
        shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
        # list of (shape_name, shape_txt_file_path) tuple
        self.datapath = [
            (
                shape_names[i],
                os.path.join(self.root, shape_names[i], shape_ids[split][i])
                + ".txt",
            )
            for i in range(len(shape_ids[split]))
        ]
        print("The size of %s data is %d" % (split, len(self.datapath)))

        self.cache_size = cache_size
        self.cache = {}
@@ -84,11 +108,11 @@ class ModelNetDataLoader(Dataset):
        fn = self.datapath[index]
        cls = self.classes[self.datapath[index][0]]
        cls = np.array([cls]).astype(np.int32)
        point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
        if self.fps:
            point_set = farthest_point_sample(point_set, self.npoints)
        else:
            point_set = point_set[0 : self.npoints, :]
        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
@@ -101,4 +125,4 @@ class ModelNetDataLoader(Dataset):
        return point_set, cls

    def __getitem__(self, index):
        return self._get_item(index)
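

# Usage sketch (illustrative; the directory name below matches the training
# script's default download location and is an assumption here):
if __name__ == "__main__":
    _root = "modelnet40_normal_resampled"
    if os.path.exists(_root):
        _ds = ModelNetDataLoader(_root, npoint=1024, split="train")
        _points, _label = _ds[0]
        print(_points.shape, _label.shape)  # (1024, C), (1,)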


import json
import os
from zipfile import ZipFile

import numpy as np
import tqdm
from scipy.sparse import csr_matrix
from torch.utils.data import Dataset

import dgl
from dgl.data.utils import download, get_download_dir

class ShapeNet(object):
    def __init__(self, num_points=2048, normal_channel=True):
        self.num_points = num_points
@@ -13,8 +18,13 @@ class ShapeNet(object):
        SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
        download_path = get_download_dir()
        data_filename = (
            "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
        )
        data_path = os.path.join(
            download_path,
            "shapenetcore_partanno_segmentation_benchmark_v0_normal",
        )
        if not os.path.exists(data_path):
            local_path = os.path.join(download_path, data_filename)
            if not os.path.exists(local_path):
@@ -24,52 +34,72 @@ class ShapeNet(object):
        synset_file = "synsetoffset2category.txt"
        with open(os.path.join(data_path, synset_file)) as f:
            synset = [t.split("\n")[0].split("\t") for t in f.readlines()]
        self.synset_dict = {}
        for syn in synset:
            self.synset_dict[syn[1]] = syn[0]

        self.seg_classes = {
            "Airplane": [0, 1, 2, 3],
            "Bag": [4, 5],
            "Cap": [6, 7],
            "Car": [8, 9, 10, 11],
            "Chair": [12, 13, 14, 15],
            "Earphone": [16, 17, 18],
            "Guitar": [19, 20, 21],
            "Knife": [22, 23],
            "Lamp": [24, 25, 26, 27],
            "Laptop": [28, 29],
            "Motorbike": [30, 31, 32, 33, 34, 35],
            "Mug": [36, 37],
            "Pistol": [38, 39, 40],
            "Rocket": [41, 42, 43],
            "Skateboard": [44, 45, 46],
            "Table": [47, 48, 49],
        }

        train_split_json = "shuffled_train_file_list.json"
        val_split_json = "shuffled_val_file_list.json"
        test_split_json = "shuffled_test_file_list.json"
        split_path = os.path.join(data_path, "train_test_split")

        with open(os.path.join(split_path, train_split_json)) as f:
            tmp = f.read()
            self.train_file_list = [
                os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
                for t in json.loads(tmp)
            ]
        with open(os.path.join(split_path, val_split_json)) as f:
            tmp = f.read()
            self.val_file_list = [
                os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
                for t in json.loads(tmp)
            ]
        with open(os.path.join(split_path, test_split_json)) as f:
            tmp = f.read()
            self.test_file_list = [
                os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
                for t in json.loads(tmp)
            ]
    def train(self):
        return ShapeNetDataset(
            self, "train", self.num_points, self.normal_channel
        )

    def valid(self):
        return ShapeNetDataset(
            self, "valid", self.num_points, self.normal_channel
        )

    def trainval(self):
        return ShapeNetDataset(
            self, "trainval", self.num_points, self.normal_channel
        )

    def test(self):
        return ShapeNetDataset(
            self, "test", self.num_points, self.normal_channel
        )

class ShapeNetDataset(Dataset):
    def __init__(self, shapenet, mode, num_points, normal_channel=True):
@@ -81,13 +111,13 @@ class ShapeNetDataset(Dataset):
        else:
            self.dim = 6

        if mode == "train":
            self.file_list = shapenet.train_file_list
        elif mode == "valid":
            self.file_list = shapenet.val_file_list
        elif mode == "test":
            self.file_list = shapenet.test_file_list
        elif mode == "trainval":
            self.file_list = shapenet.train_file_list + shapenet.val_file_list
        else:
            # raising a bare string is a TypeError in Python 3
            raise ValueError("Not supported `mode`")
@@ -95,32 +125,36 @@ class ShapeNetDataset(Dataset):
        data_list = []
        label_list = []
        category_list = []
        print("Loading data from split " + self.mode)
        for fn in tqdm.tqdm(self.file_list, ascii=True):
            with open(fn) as f:
                data = np.array(
                    [t.split("\n")[0].split(" ") for t in f.readlines()]
                ).astype(np.float64)  # np.float was removed in NumPy 1.24
                data_list.append(data[:, 0 : self.dim])
                label_list.append(data[:, 6].astype(np.int64))
                category_list.append(shapenet.synset_dict[fn.split("/")[-2]])
        self.data = data_list
        self.label = label_list
        self.category = category_list

    def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2), size=3):
        xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])
        xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])
        x = np.add(np.multiply(x, xyz1), xyz2).astype("float32")
        return x

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        inds = np.random.choice(
            self.data[i].shape[0], self.num_points, replace=True
        )
        x = self.data[i][inds, : self.dim]
        y = self.label[i][inds]
        cat = self.category[i]
        if self.mode == "train":
            x = self.translate(x, size=self.dim)
        x = x.astype(np.float64)
        y = y.astype(np.int64)


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import dgl
import dgl.function as fn
from dgl.geometry import (
    farthest_point_sampler,
)  # dgl.geometry.pytorch -> dgl.geometry

"""
Part of the code is adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""


def square_distance(src, dst):
    """
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src**2, -1).view(B, N, 1)
    dist += torch.sum(dst**2, -1).view(B, 1, M)
    return dist
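

# Sanity-check sketch (illustrative, not from the original file):
# square_distance expands ||s - d||^2 = ||s||^2 + ||d||^2 - 2 s.d, so it
# should match the direct pairwise computation up to float error.
if __name__ == "__main__":
    _src, _dst = torch.rand(2, 5, 3), torch.rand(2, 7, 3)
    _direct = ((_src[:, :, None, :] - _dst[:, None, :, :]) ** 2).sum(-1)
    assert torch.allclose(square_distance(_src, _dst), _direct, atol=1e-5)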

def index_points(points, idx):
    """
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = (
        torch.arange(B, dtype=torch.long)
        .to(device)
        .view(view_shape)
        .repeat(repeat_shape)
    )
    new_points = points[batch_indices, idx, :]
    return new_points

class FixedRadiusNearNeighbors(nn.Module):
    """
    Ball Query - Find the neighbors within a fixed radius
    """

    def __init__(self, radius, n_neighbor):
        super(FixedRadiusNearNeighbors, self).__init__()
        self.radius = radius
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
        """
        device = pos.device
        B, N, _ = pos.shape
        center_pos = index_points(pos, centroids)
        _, S, _ = center_pos.shape
        group_idx = (
            torch.arange(N, dtype=torch.long)
            .to(device)
            .view(1, 1, N)
            .repeat([B, S, 1])
        )
        sqrdists = square_distance(center_pos, pos)
        group_idx[sqrdists > self.radius**2] = N
        group_idx = group_idx.sort(dim=-1)[0][:, :, : self.n_neighbor]
        group_first = (
            group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, self.n_neighbor])
        )
        mask = group_idx == N
        group_idx[mask] = group_first[mask]
        return group_idx
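

# Usage sketch (illustrative): ball query around two centroids per cloud;
# out-of-radius slots are filled with the first in-radius neighbor, so the
# result is always a dense (B, S, n_neighbor) index tensor.
if __name__ == "__main__":
    _pos = torch.rand(2, 16, 3)
    _centroids = torch.zeros(2, 2, dtype=torch.long)
    _frnn = FixedRadiusNearNeighbors(radius=0.5, n_neighbor=4)
    print(_frnn(_pos, _centroids).shape)  # torch.Size([2, 2, 4])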

class FixedRadiusNNGraph(nn.Module):
    """
    Build NN graph
    """

    def __init__(self, radius, n_neighbor):
        super(FixedRadiusNNGraph, self).__init__()
        self.radius = radius
@@ -86,50 +107,58 @@ class FixedRadiusNNGraph(nn.Module):
            unified = torch.cat([src, dst])
            uniq, inv_idx = torch.unique(unified, return_inverse=True)
            src_idx = inv_idx[: src.shape[0]]
            dst_idx = inv_idx[src.shape[0] :]

            g = dgl.graph((src_idx, dst_idx))
            g.ndata["pos"] = pos[i][uniq]
            g.ndata["center"] = center[uniq]
            if feat is not None:
                g.ndata["feat"] = feat[i][uniq]
            glist.append(g)
        bg = dgl.batch(glist)
        return bg
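

# Usage sketch (illustrative; it assumes the elided body of forward()
# derives src/dst edges from the ball query as in the full file): build a
# batched DGL graph from two clouds with 4 FPS centroids each.
if __name__ == "__main__":
    _pos = torch.rand(2, 32, 3)
    _centroids = farthest_point_sampler(_pos, 4)
    _builder = FixedRadiusNNGraph(radius=0.4, n_neighbor=8)
    _bg = _builder(_pos, _centroids, None)
    print(_bg.batch_size, _bg.ndata["pos"].shape)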

class RelativePositionMessage(nn.Module):
    """
    Compute the input feature from neighbors
    """

    def __init__(self, n_neighbor):
        super(RelativePositionMessage, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        pos = edges.src["pos"] - edges.dst["pos"]
        if "feat" in edges.src:
            res = torch.cat([pos, edges.src["feat"]], 1)
        else:
            res = pos
        return {"agg_feat": res}

class PointNetConv(nn.Module):
    """
    Feature aggregation
    """

    def __init__(self, sizes, batch_size):
        super(PointNetConv, self).__init__()
        self.batch_size = batch_size
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for i in range(1, len(sizes)):
            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))

    def forward(self, nodes):
        shape = nodes.mailbox["agg_feat"].shape
        h = (
            nodes.mailbox["agg_feat"]
            .view(self.batch_size, -1, shape[1], shape[2])
            .permute(0, 3, 2, 1)
        )
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
@@ -137,12 +166,12 @@ class PointNetConv(nn.Module):
        h = torch.max(h, 2)[0]
        feat_dim = h.shape[1]
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {"new_feat": h}

    def group_all(self, pos, feat):
        """
        Feature aggregation and pooling for the non-sampling layer
        """
        if feat is not None:
            h = torch.cat([pos, feat], 2)
        else:
@@ -158,12 +187,21 @@ class PointNetConv(nn.Module):
        h = torch.max(h[:, :, :, 0], 2)[0]  # [B,D]
        return new_pos, h

class SAModule(nn.Module):
    """
    The Set Abstraction Layer
    """

    def __init__(
        self,
        npoints,
        batch_size,
        radius,
        mlp_sizes,
        n_neighbor=64,
        group_all=False,
    ):
        super(SAModule, self).__init__()
        self.group_all = group_all
        if not group_all:
@@ -181,18 +219,22 @@ class SAModule(nn.Module):
        g = self.frnn_graph(pos, centroids, feat)
        g.update_all(self.message, self.conv)

        mask = g.ndata["center"] == 1
        pos_dim = g.ndata["pos"].shape[-1]
        feat_dim = g.ndata["new_feat"].shape[-1]
        pos_res = g.ndata["pos"][mask].view(self.batch_size, -1, pos_dim)
        feat_res = g.ndata["new_feat"][mask].view(self.batch_size, -1, feat_dim)
        return pos_res, feat_res

class SAMSGModule(nn.Module):
    """
    The Set Abstraction Multi-Scale grouping Layer
    """

    def __init__(
        self, npoints, batch_size, radius_list, n_neighbor_list, mlp_sizes_list
    ):
        super(SAMSGModule, self).__init__()
        self.batch_size = batch_size
        self.group_size = len(radius_list)
@@ -202,9 +244,12 @@ class SAMSGModule(nn.Module):
        self.message_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        for i in range(self.group_size):
            self.frnn_graph_list.append(
                FixedRadiusNNGraph(radius_list[i], n_neighbor_list[i])
            )
            self.message_list.append(
                RelativePositionMessage(n_neighbor_list[i])
            )
            self.conv_list.append(PointNetConv(mlp_sizes_list[i], batch_size))

    def forward(self, pos, feat):
@@ -214,21 +259,27 @@ class SAMSGModule(nn.Module):
        for i in range(self.group_size):
            g = self.frnn_graph_list[i](pos, centroids, feat)
            g.update_all(self.message_list[i], self.conv_list[i])

            mask = g.ndata["center"] == 1
            pos_dim = g.ndata["pos"].shape[-1]
            feat_dim = g.ndata["new_feat"].shape[-1]
            if i == 0:
                pos_res = g.ndata["pos"][mask].view(
                    self.batch_size, -1, pos_dim
                )
            feat_res = g.ndata["new_feat"][mask].view(
                self.batch_size, -1, feat_dim
            )
            feat_res_list.append(feat_res)

        feat_res = torch.cat(feat_res_list, 2)
        return pos_res, feat_res

class PointNet2FP(nn.Module):
    """
    The Feature Propagation Layer
    """

    def __init__(self, input_dims, sizes):
        super(PointNet2FP, self).__init__()
        self.convs = nn.ModuleList()
@@ -236,7 +287,7 @@ class PointNet2FP(nn.Module):
        sizes = [input_dims] + sizes
        for i in range(1, len(sizes)):
            self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))
            self.bns.append(nn.BatchNorm1d(sizes[i]))

    def forward(self, x1, x2, feat1, feat2):
@@ -263,7 +314,9 @@ class PointNet2FP(nn.Module):
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_feat = torch.sum(
                index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2
            )

        if feat1 is not None:
            new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
@@ -278,14 +331,21 @@ class PointNet2FP(nn.Module):

class PointNet2SSGCls(nn.Module):
    def __init__(
        self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
    ):
        super(PointNet2SSGCls, self).__init__()
        self.input_dims = input_dims

        self.sa_module1 = SAModule(
            512, batch_size, 0.2, [input_dims, 64, 64, 128]
        )
        self.sa_module2 = SAModule(
            128, batch_size, 0.4, [128 + 3, 128, 128, 256]
        )
        self.sa_module3 = SAModule(
            None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True
        )

        self.mlp1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
@@ -320,19 +380,39 @@ class PointNet2SSGCls(nn.Module):
        out = self.mlp_out(h)
        return out

class PointNet2MSGCls(nn.Module):
    def __init__(
        self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
    ):
        super(PointNet2MSGCls, self).__init__()
        self.input_dims = input_dims

        self.sa_msg_module1 = SAMSGModule(
            512,
            batch_size,
            [0.1, 0.2, 0.4],
            [16, 32, 128],
            [
                [input_dims, 32, 32, 64],
                [input_dims, 64, 64, 128],
                [input_dims, 64, 96, 128],
            ],
        )
        self.sa_msg_module2 = SAMSGModule(
            128,
            batch_size,
            [0.2, 0.4, 0.8],
            [32, 64, 128],
            [
                [320 + 3, 64, 64, 128],
                [320 + 3, 128, 128, 256],
                [320 + 3, 128, 128, 256],
            ],
        )
        self.sa_module3 = SAModule(
            None, batch_size, None, [640 + 3, 256, 512, 1024], group_all=True
        )

        self.mlp1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from pointnet2 import PointNet2FP, SAModule, SAMSGModule

class PointNet2SSGPartSeg(nn.Module):
    def __init__(self, output_classes, batch_size, input_dims=6):
        super(PointNet2SSGPartSeg, self).__init__()
        # if normal_channel == true, input_dims = 6+3
        self.input_dims = input_dims

        self.sa_module1 = SAModule(
            512, batch_size, 0.2, [input_dims, 64, 64, 128], n_neighbor=32
        )
        self.sa_module2 = SAModule(
            128, batch_size, 0.4, [128 + 3, 128, 128, 256]
        )
        self.sa_module3 = SAModule(
            None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True
        )
        self.fp3 = PointNet2FP(1280, [256, 256])
        self.fp2 = PointNet2FP(384, [256, 128])
        # if normal_channel == true, 128+16+6+3
        self.fp1 = PointNet2FP(128 + 16 + 6, [128, 128, 128])

        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
@@ -38,7 +44,9 @@ class PointNet2SSGPartSeg(nn.Module):
        l2_pos, l2_feat = self.sa_module2(l1_pos, l1_feat)
        l3_pos, l3_feat = self.sa_module3(l2_pos, l2_feat)  # [B, N, C], [B, D]
        # Feature Propagation layers
        l2_feat = self.fp3(
            l2_pos, l3_pos, l2_feat, l3_feat.unsqueeze(1)
        )  # l2_feat: [B, D, N]
        l1_feat = self.fp2(l1_pos, l2_pos, l1_feat, l2_feat.permute(0, 2, 1))
        l0_feat = torch.cat([cat_vec.permute(0, 2, 1), l0_pos, l0_feat], 2)
        l0_feat = self.fp1(l0_pos, l1_pos, l0_feat, l1_feat.permute(0, 2, 1))
@@ -53,13 +61,30 @@ class PointNet2MSGPartSeg(nn.Module):
    def __init__(self, output_classes, batch_size, input_dims=6):
        super(PointNet2MSGPartSeg, self).__init__()

        self.sa_msg_module1 = SAMSGModule(
            512,
            batch_size,
            [0.1, 0.2, 0.4],
            [32, 64, 128],
            [
                [input_dims, 32, 32, 64],
                [input_dims, 64, 64, 128],
                [input_dims, 64, 96, 128],
            ],
        )
        self.sa_msg_module2 = SAMSGModule(
            128,
            batch_size,
            [0.4, 0.8],
            [64, 128],
            [
                [128 + 128 + 64 + 3, 128, 128, 256],
                [128 + 128 + 64 + 3, 128, 196, 256],
            ],
        )
        self.sa_module3 = SAModule(
            None, batch_size, None, [512 + 3, 256, 512, 1024], group_all=True
        )
        self.fp3 = PointNet2FP(1536, [256, 256])
        self.fp2 = PointNet2FP(576, [256, 128])


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class PointNetCls(nn.Module):
    def __init__(
        self,
        output_classes,
        input_dims=3,
        conv1_dim=64,
        dropout_prob=0.5,
        use_transform=True,
    ):
        super(PointNetCls, self).__init__()
        self.input_dims = input_dims
        self.conv1 = nn.ModuleList()
@@ -85,6 +92,7 @@ class PointNetCls(nn.Module):
        out = self.mlp_out(h)
        return out

class TransformNet(nn.Module):
    def __init__(self, input_dims=3, conv1_dim=64):
        super(TransformNet, self).__init__()
@@ -118,7 +126,7 @@ class TransformNet(nn.Module):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)

        h = self.maxpool(h).view(-1, self.pool_feat_len)
        for mlp, bn in zip(self.mlp2, self.bn2):
            h = mlp(h)
@@ -127,8 +135,14 @@ class TransformNet(nn.Module):
        out = self.mlp_out(h)

        iden = Variable(
            torch.from_numpy(
                np.eye(self.input_dims).flatten().astype(np.float32)
            )
        )
        iden = iden.view(1, self.input_dims * self.input_dims).repeat(
            batch_size, 1
        )
        if out.is_cuda:
            iden = iden.cuda()
        out = out + iden


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class PointNetPartSeg(nn.Module):
    def __init__(
        self, output_classes, input_dims=3, num_points=2048, use_transform=True
    ):
        super(PointNetPartSeg, self).__init__()
        self.input_dims = input_dims
@@ -33,7 +35,7 @@ class PointNetPartSeg(nn.Module):
        self.pool_feat_len = 2048

        self.conv3 = nn.ModuleList()
        self.conv3.append(nn.Conv1d(2048 + 64 + 128 * 3 + 512 + 16, 256, 1))
        self.conv3.append(nn.Conv1d(256, 256, 1))
        self.conv3.append(nn.Conv1d(256, 128, 1))
@@ -98,6 +100,7 @@ class PointNetPartSeg(nn.Module):
        out = self.conv_out(h)
        return out

class TransformNet(nn.Module):
    def __init__(self, input_dims=3, num_points=2048):
        super(TransformNet, self).__init__()
@@ -131,7 +134,7 @@ class TransformNet(nn.Module):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)

        h = self.maxpool(h).view(-1, self.pool_feat_len)
        for mlp, bn in zip(self.mlp2, self.bn2):
            h = mlp(h)
@@ -140,20 +143,27 @@ class TransformNet(nn.Module):
        out = self.mlp_out(h)

        iden = Variable(
            torch.from_numpy(
                np.eye(self.input_dims).flatten().astype(np.float32)
            )
        )
        iden = iden.view(1, self.input_dims * self.input_dims).repeat(
            batch_size, 1
        )
        if out.is_cuda:
            iden = iden.cuda()
        out = out + iden
        out = out.view(-1, self.input_dims, self.input_dims)
        return out

class PartSegLoss(nn.Module):
    def __init__(self, eps=0.2):
        super(PartSegLoss, self).__init__()
        self.eps = eps
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits, y):
        num_classes = logits.shape[1]
        logits = logits.permute(0, 2, 1).contiguous().view(-1, num_classes)


"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
"""
import numpy as np


def normalize_data(batch_data):
    """Normalize the batch data: use coordinates of the block centered at the origin.
    Input:
        BxNxC array
    Output:
        BxNxC array
    """
    B, N, C = batch_data.shape
    normal_data = np.zeros((B, N, C))
@@ -16,237 +17,296 @@ def normalize_data(batch_data):
        pc = batch_data[b]
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
        pc = pc / m
        normal_data[b] = pc
    return normal_data

def shuffle_data(data, labels):
    """Shuffle data and labels.
    Input:
        data: B,N,... numpy array
        label: B,... numpy array
    Return:
        shuffled data, label and shuffle indices
    """
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx

def shuffle_points(batch_data):
    """Shuffle orders of points in each point cloud -- changes FPS behavior.
    Use the same shuffling idx for the entire batch.
    Input:
        BxNxC array
    Output:
        BxNxC array
    """
    idx = np.arange(batch_data.shape[1])
    np.random.shuffle(idx)
    return batch_data[:, idx, :]

def rotate_point_cloud(batch_data):
    """Randomly rotate the point clouds to augment the dataset;
    rotation is per shape, along the up direction.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data
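

# Sanity-check sketch (illustrative): the rotation matrix above is
# orthogonal, so rotating must preserve every point's distance from the
# origin up to float error.
if __name__ == "__main__":
    _batch = np.random.rand(2, 64, 3).astype(np.float32)
    _rot = rotate_point_cloud(_batch)
    assert np.allclose(
        np.linalg.norm(_batch, axis=2),
        np.linalg.norm(_rot, axis=2),
        atol=1e-5,
    )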

def rotate_point_cloud_z(batch_data):
    """Randomly rotate the point clouds to augment the dataset;
    rotation is per shape, about the z axis.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]
        )
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data

def rotate_point_cloud_with_normal(batch_xyz_normal):
    """Randomly rotate XYZ, normal point cloud.
    Input:
        batch_xyz_normal: B,N,6, first three channels are XYZ, last three are the normals
    Output:
        B,N,6, rotated XYZ, normal point cloud
    """
    for k in range(batch_xyz_normal.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_xyz_normal[k, :, 0:3]
        shape_normal = batch_xyz_normal[k, :, 3:6]
        batch_xyz_normal[k, :, 0:3] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
        batch_xyz_normal[k, :, 3:6] = np.dot(
            shape_normal.reshape((-1, 3)), rotation_matrix
        )
    return batch_xyz_normal

def rotate_perturbation_point_cloud_with_normal(
    batch_data, angle_sigma=0.06, angle_clip=0.18
):
    """Randomly perturb the point clouds by small rotations
    Input:
        BxNx6 array, original batch of point clouds and point normals
    Return:
        BxNx6 array, rotated batch of point clouds and point normals
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angles = np.clip(
            angle_sigma * np.random.randn(3), -angle_clip, angle_clip
        )
        Rx = np.array(
            [
                [1, 0, 0],
                [0, np.cos(angles[0]), -np.sin(angles[0])],
                [0, np.sin(angles[0]), np.cos(angles[0])],
            ]
        )
        Ry = np.array(
            [
                [np.cos(angles[1]), 0, np.sin(angles[1])],
                [0, 1, 0],
                [-np.sin(angles[1]), 0, np.cos(angles[1])],
            ]
        )
        Rz = np.array(
            [
                [np.cos(angles[2]), -np.sin(angles[2]), 0],
                [np.sin(angles[2]), np.cos(angles[2]), 0],
                [0, 0, 1],
            ]
        )
        R = np.dot(Rz, np.dot(Ry, Rx))
        shape_pc = batch_data[k, :, 0:3]
        shape_normal = batch_data[k, :, 3:6]
        rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
        rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
    return rotated_data

def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """Rotate the point cloud along up direction with certain angle.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        # rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_data[k, :, 0:3]
        rotated_data[k, :, 0:3] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data

def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
    """Rotate the point cloud along up direction with certain angle.
    Input:
        BxNx6 array, original batch of point clouds with normal
        scalar, angle of rotation
    Return:
        BxNx6 array, rotated batch of point clouds with normal
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        # rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array(
            [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
        )
        shape_pc = batch_data[k, :, 0:3]
        shape_normal = batch_data[k, :, 3:6]
        rotated_data[k, :, 0:3] = np.dot(
            shape_pc.reshape((-1, 3)), rotation_matrix
        )
        rotated_data[k, :, 3:6] = np.dot(
            shape_normal.reshape((-1, 3)), rotation_matrix
        )
    return rotated_data

def rotate_perturbation_point_cloud(
    batch_data, angle_sigma=0.06, angle_clip=0.18
):
    """Randomly perturb the point clouds by small rotations
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angles = np.clip(
            angle_sigma * np.random.randn(3), -angle_clip, angle_clip
        )
        Rx = np.array(
            [
                [1, 0, 0],
                [0, np.cos(angles[0]), -np.sin(angles[0])],
                [0, np.sin(angles[0]), np.cos(angles[0])],
            ]
        )
        Ry = np.array(
            [
                [np.cos(angles[1]), 0, np.sin(angles[1])],
                [0, 1, 0],
                [-np.sin(angles[1]), 0, np.cos(angles[1])],
            ]
        )
        Rz = np.array(
            [
                [np.cos(angles[2]), -np.sin(angles[2]), 0],
                [np.sin(angles[2]), np.cos(angles[2]), 0],
                [0, 0, 1],
            ]
        )
        R = np.dot(Rz, np.dot(Ry, Rx))
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
    return rotated_data

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """Randomly jitter points. Jittering is per point.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, jittered batch of point clouds
    """
    B, N, C = batch_data.shape
    assert clip > 0
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data

def shift_point_cloud(batch_data, shift_range=0.1):
    """Randomly shift point cloud. Shift is per point cloud.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, shifted batch of point clouds
    """
    B, N, C = batch_data.shape
    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
    for batch_index in range(B):
        batch_data[batch_index, :, :] += shifts[batch_index, :]
    return batch_data

def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """Randomly scale the point cloud. Scale is per point cloud.
    Input:
        BxNx3 array, original batch of point clouds
    Return:
        BxNx3 array, scaled batch of point clouds
    """
    B, N, C = batch_data.shape
    scales = np.random.uniform(scale_low, scale_high, B)
    for batch_index in range(B):
        batch_data[batch_index, :, :] *= scales[batch_index]
    return batch_data

def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """batch_pc: BxNx3"""
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio  # 0~0.875
        drop_idx = np.where(
            np.random.random((batch_pc.shape[1])) <= dropout_ratio
        )[0]
        if len(drop_idx) > 0:
            # dropped points are collapsed onto the first point rather
            # than removed, so the array shape is preserved
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]
    return batch_pc
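

# Composition sketch (illustrative; mirrors how the classification training
# script chains these helpers): dropout first, then scale/jitter/shift
# applied only to the xyz channels.
if __name__ == "__main__":
    _data = np.random.rand(4, 1024, 6).astype(np.float32)
    _data = random_point_dropout(_data)
    _data[:, :, 0:3] = random_scale_point_cloud(_data[:, :, 0:3])
    _data[:, :, 0:3] = jitter_point_cloud(_data[:, :, 0:3])
    _data[:, :, 0:3] = shift_point_cloud(_data[:, :, 0:3])
    print(_data.shape)  # (4, 1024, 6)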


import argparse
import os
import urllib
from functools import partial

import provider
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from ModelNetDataLoader import ModelNetDataLoader
from pointnet2 import PointNet2MSGCls, PointNet2SSGCls
from pointnet_cls import PointNetCls
from torch.utils.data import DataLoader

import dgl
from dgl.data.utils import download, get_download_dir

torch.backends.cudnn.enabled = False

# from dataset import ModelNet
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="pointnet")
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=200)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size

data_filename = "modelnet40_normal_resampled.zip"
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
    get_download_dir(), "modelnet40_normal_resampled"
)

if not os.path.exists(local_path):
    download(
        "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
        download_path,
        verify_ssl=False,
    )
    from zipfile import ZipFile

    with ZipFile(download_path) as z:
        z.extractall(path=get_download_dir())
@@ -49,7 +57,8 @@ CustomDataLoader = partial(
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev): def train(net, opt, scheduler, train_loader, dev):
...@@ -65,8 +74,7 @@ def train(net, opt, scheduler, train_loader, dev): ...@@ -65,8 +74,7 @@ def train(net, opt, scheduler, train_loader, dev):
for data, label in tq: for data, label in tq:
data = data.data.numpy() data = data.data.numpy()
data = provider.random_point_dropout(data) data = provider.random_point_dropout(data)
data[:, :, 0:3] = provider.random_scale_point_cloud( data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
data[:, :, 0:3])
data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3]) data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3]) data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
data = torch.tensor(data) data = torch.tensor(data)
...@@ -89,9 +97,12 @@ def train(net, opt, scheduler, train_loader, dev): ...@@ -89,9 +97,12 @@ def train(net, opt, scheduler, train_loader, dev):
total_loss += loss total_loss += loss
total_correct += correct total_correct += correct
tq.set_postfix({ tq.set_postfix(
'AvgLoss': '%.5f' % (total_loss / num_batches), {
'AvgAcc': '%.5f' % (total_correct / count)}) "AvgLoss": "%.5f" % (total_loss / num_batches),
"AvgAcc": "%.5f" % (total_correct / count),
}
)
scheduler.step() scheduler.step()
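
The `provider` helpers used above (dropout, scale, jitter, shift) augment each batch of point clouds before training. As a rough illustration of what one of them does, here is a minimal sketch of per-cloud random scaling; the function body below is an assumption for illustration, not the provider's exact code:

import numpy as np

def random_scale_point_cloud_sketch(batch_xyz, scale_low=0.8, scale_high=1.25):
    # Draw one scale factor per cloud and apply it to all xyz coordinates.
    B = batch_xyz.shape[0]
    scales = np.random.uniform(scale_low, scale_high, B)
    return batch_xyz * scales[:, None, None]

# e.g. a batch of 32 clouds with 1024 points each: shape (32, 1024, 3)
scaled = random_scale_point_cloud_sketch(np.random.randn(32, 1024, 3))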
@@ -114,19 +125,18 @@ def evaluate(net, test_loader, dev):
            total_correct += correct
            count += num_examples
            tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})
    return total_correct / count


dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if args.model == "pointnet":
    net = PointNetCls(40, input_dims=6)
elif args.model == "pointnet2_ssg":
    net = PointNet2SSGCls(40, batch_size, input_dims=6)
elif args.model == "pointnet2_msg":
    net = PointNet2MSGCls(40, batch_size, input_dims=6)

net = net.to(dev)
@@ -137,23 +147,32 @@ opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)

train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=True,
)

best_test_acc = 0

for epoch in range(args.num_epochs):
    train(net, opt, scheduler, train_loader, dev)
    if (epoch + 1) % 1 == 0:
        print("Epoch #%d Testing" % epoch)
        test_acc = evaluate(net, test_loader, dev)
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            if args.save_model_path:
                torch.save(net.state_dict(), args.save_model_path)
        print("Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc))
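
With `StepLR(opt, step_size=20, gamma=0.7)`, the learning rate is multiplied by 0.7 every 20 epochs, so it roughly halves after two drops. A minimal sketch of the schedule in isolation (toy parameter, hypothetical values):

import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([p], lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)
for _ in range(40):
    opt.step()
    scheduler.step()
# after 40 epochs the lr has dropped twice: 1e-3 * 0.7**2
print(opt.param_groups[0]["lr"])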
import argparse
import os
import time
import urllib
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from pointnet2_partseg import PointNet2MSGPartSeg, PointNet2SSGPartSeg
from pointnet_partseg import PartSegLoss, PointNetPartSeg
from ShapeNet import ShapeNet
from torch.utils.data import DataLoader

import dgl
from dgl.data.utils import download, get_download_dir
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="pointnet")
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=250)
parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--tensorboard", action="store_true")

args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size

def collate(samples):
    graphs, cat = map(list, zip(*samples))
    return dgl.batch(graphs), cat


CustomDataLoader = partial(
    DataLoader,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

def train(net, opt, scheduler, train_loader, dev):
    category_list = sorted(list(shapenet.seg_classes.keys()))
@@ -61,8 +65,12 @@ def train(net, opt, scheduler, train_loader, dev):
            opt.zero_grad()
            cat_ind = [category_list.index(c) for c in cat]
            # A one-hot encoding for the object category
            cat_tensor = (
                torch.tensor(eye_mat[cat_ind])
                .to(dev, dtype=torch.float)
                .repeat(1, 2048)
            )
            cat_tensor = cat_tensor.view(num_examples, -1, 16).permute(0, 2, 1)
            logits = net(data, cat_tensor)
            loss = L(logits, label)
            loss.backward()
@@ -80,20 +88,21 @@ def train(net, opt, scheduler, train_loader, dev):
            AvgLoss = total_loss / num_batches
            AvgAcc = total_correct / count
            tq.set_postfix(
                {"AvgLoss": "%.5f" % AvgLoss, "AvgAcc": "%.5f" % AvgAcc}
            )
    scheduler.step()
    end = time.time()
    return data, preds, AvgLoss, AvgAcc, end - start

def mIoU(preds, label, cat, cat_miou, seg_classes):
    for i in range(preds.shape[0]):
        shape_iou = 0
        n = len(seg_classes[cat[i]])
        for cls in seg_classes[cat[i]]:
            pred_set = set(np.where(preds[i, :] == cls)[0])
            label_set = set(np.where(label[i, :] == cls)[0])
            union = len(pred_set.union(label_set))
            inter = len(pred_set.intersection(label_set))
            if union == 0:
@@ -106,6 +115,7 @@ def mIoU(preds, label, cat, cat_miou, seg_classes):
    return cat_miou
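
To make the bookkeeping concrete: for each part class of a shape, the IoU is the intersection over the union of predicted and ground-truth point-index sets, and the shape's IoU is the mean over its part classes. A self-contained toy version of the same logic for one 6-point shape with two parts (treating an empty union as IoU 1, the usual convention for absent parts, which the elided branch above presumably follows):

import numpy as np

preds = np.array([[0, 0, 1, 1, 1, 0]])
label = np.array([[0, 1, 1, 1, 0, 0]])
shape_iou = 0.0
for cls in (0, 1):
    pred_set = set(np.where(preds[0, :] == cls)[0])
    label_set = set(np.where(label[0, :] == cls)[0])
    union = len(pred_set | label_set)
    inter = len(pred_set & label_set)
    shape_iou += 1.0 if union == 0 else inter / union
print(shape_iou / 2)  # 0.5: each part has IoU 2/4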

def evaluate(net, test_loader, dev, per_cat_verbose=False):
    category_list = sorted(list(shapenet.seg_classes.keys()))
    eye_mat = np.eye(16)
@@ -126,23 +136,36 @@ def evaluate(net, test_loader, dev, per_cat_verbose=False):
            data = data.to(dev, dtype=torch.float)
            label = label.to(dev, dtype=torch.long)
            cat_ind = [category_list.index(c) for c in cat]
            cat_tensor = (
                torch.tensor(eye_mat[cat_ind])
                .to(dev, dtype=torch.float)
                .repeat(1, 2048)
            )
            cat_tensor = cat_tensor.view(num_examples, -1, 16).permute(
                0, 2, 1
            )
            logits = net(data, cat_tensor)
            _, preds = logits.max(1)
            cat_miou = mIoU(
                preds.cpu().numpy(),
                label.view(num_examples, -1).cpu().numpy(),
                cat,
                cat_miou,
                shapenet.seg_classes,
            )
            for _, v in cat_miou.items():
                if v[1] > 0:
                    miou += v[0]
                    count += v[1]
                    per_cat_miou += v[0] / v[1]
                    per_cat_count += 1
            tq.set_postfix(
                {
                    "mIoU": "%.5f" % (miou / count),
                    "per Category mIoU": "%.5f" % (per_cat_miou / per_cat_count),
                }
            )
    if per_cat_verbose:
        print("Per-Category mIoU:")
        for k, v in cat_miou.items():
@@ -155,11 +178,11 @@ def evaluate(net, test_loader, dev, per_cat_verbose=False):
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dev = "cpu"

if args.model == "pointnet":
    net = PointNetPartSeg(50, 3, 2048)
elif args.model == "pointnet2_ssg":
    net = PointNet2SSGPartSeg(50, batch_size, input_dims=6)
elif args.model == "pointnet2_msg":
    net = PointNet2MSGPartSeg(50, batch_size, input_dims=6)

net = net.to(dev)
@@ -180,43 +203,109 @@ if args.tensorboard:
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import datasets, transforms

    writer = SummaryWriter()
# Select 50 distinct colors for different parts
color_map = torch.tensor(
    [
        [47, 79, 79],
        [139, 69, 19],
        [112, 128, 144],
        [85, 107, 47],
        [139, 0, 0],
        [128, 128, 0],
        [72, 61, 139],
        [0, 128, 0],
        [188, 143, 143],
        [60, 179, 113],
        [205, 133, 63],
        [0, 139, 139],
        [70, 130, 180],
        [205, 92, 92],
        [154, 205, 50],
        [0, 0, 139],
        [50, 205, 50],
        [250, 250, 250],
        [218, 165, 32],
        [139, 0, 139],
        [10, 10, 10],
        [176, 48, 96],
        [72, 209, 204],
        [153, 50, 204],
        [255, 69, 0],
        [255, 145, 0],
        [0, 0, 205],
        [255, 255, 0],
        [0, 255, 0],
        [233, 150, 122],
        [220, 20, 60],
        [0, 191, 255],
        [160, 32, 240],
        [192, 192, 192],
        [173, 255, 47],
        [218, 112, 214],
        [216, 191, 216],
        [255, 127, 80],
        [255, 0, 255],
        [100, 149, 237],
        [128, 128, 128],
        [221, 160, 221],
        [144, 238, 144],
        [123, 104, 238],
        [255, 160, 122],
        [175, 238, 238],
        [238, 130, 238],
        [127, 255, 212],
        [255, 218, 185],
        [255, 105, 180],
    ]
)

# paint each point according to its pred
def paint(batched_points):
    B, N = batched_points.shape
    colored = color_map[batched_points].squeeze(2)
    return colored

best_test_miou = 0
best_test_per_cat_miou = 0

for epoch in range(args.num_epochs):
    data, preds, AvgLoss, AvgAcc, training_time = train(
        net, opt, scheduler, train_loader, dev
    )
    if (epoch + 1) % 5 == 0:
        print("Epoch #%d Testing" % epoch)
        test_miou, test_per_cat_miou = evaluate(
            net, test_loader, dev, (epoch + 1) % 5 == 0
        )
        if test_miou > best_test_miou:
            best_test_miou = test_miou
            best_test_per_cat_miou = test_per_cat_miou
            if args.save_model_path:
                torch.save(net.state_dict(), args.save_model_path)
        print(
            "Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)"
            % (
                test_miou,
                best_test_miou,
                test_per_cat_miou,
                best_test_per_cat_miou,
            )
        )
    # Tensorboard
    if args.tensorboard:
        colored = paint(preds)
        writer.add_mesh(
            "data", vertices=data, colors=colored, global_step=epoch
        )
        writer.add_scalar(
            "training time for one epoch", training_time, global_step=epoch
        )
        writer.add_scalar("AvgLoss", AvgLoss, global_step=epoch)
        writer.add_scalar("AvgAcc", AvgAcc, global_step=epoch)
        if (epoch + 1) % 5 == 0:
            writer.add_scalar("test mIoU", test_miou, global_step=epoch)
            writer.add_scalar(
                "best test mIoU", best_test_miou, global_step=epoch
            )
from statistics import mean

import torch
import torch.nn as nn
import torch.nn.functional as F


class LogisticRegressionClassifier(nn.Module):
    """Define a logistic regression classifier to evaluate the quality of embedding results"""

    def __init__(self, nfeat, nclass):
        super(LogisticRegressionClassifier, self).__init__()
        self.lrc = nn.Linear(nfeat, nclass)
@@ -13,7 +15,8 @@ class LogisticRegressionClassifier(nn.Module):
    def forward(self, x):
        preds = self.lrc(x)
        return preds


def _evaluate(model, features, labels, test_mask):
    model.eval()
    with torch.no_grad():
@@ -22,10 +25,11 @@ def _evaluate(model, features, labels, test_mask):
        labels = labels[test_mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def _train_test_with_lrc(model, features, labels, train_mask, test_mask):
    """Under the pre-defined balanced train/test label setting, train an lrc to evaluate the embedding results."""
    optimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=5e-06)
    for _ in range(100):
        model.train()
@@ -34,15 +38,30 @@ def _train_test_with_lrc(model, features, labels, train_mask, test_mask):
        loss_train = F.cross_entropy(output[train_mask], labels[train_mask])
        loss_train.backward()
        optimizer.step()
    return _evaluate(
        model=model, features=features, labels=labels, test_mask=test_mask
    )


def evaluate_embeds(
    features, labels, train_mask, test_mask, n_classes, cuda, test_times=10
):
    print(
        "Training a logistic regression classifier with the pre-defined train/test split setting ..."
    )
    res_list = []
    for _ in range(test_times):
        model = LogisticRegressionClassifier(
            nfeat=features.shape[1], nclass=n_classes
        )
        if cuda:
            model.cuda()
        res = _train_test_with_lrc(
            model=model,
            features=features,
            labels=labels,
            train_mask=train_mask,
            test_mask=test_mask,
        )
        res_list.append(res)
    return mean(res_list)
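
A hedged smoke test of `evaluate_embeds` on synthetic data (every tensor below is a random stand-in for real graph embeddings, so the reported accuracy should hover near chance):

import torch

n_nodes, dim, n_classes = 100, 16, 3
feats = torch.randn(n_nodes, dim)
labels = torch.randint(0, n_classes, (n_nodes,))
train_mask = torch.zeros(n_nodes, dtype=torch.bool)
train_mask[:60] = True
acc = evaluate_embeds(
    features=feats,
    labels=labels,
    train_mask=train_mask,
    test_mask=~train_mask,
    n_classes=n_classes,
    cuda=False,
    test_times=2,
)
print("accuracy on random features: %.4f" % acc)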
from collections import defaultdict

import numpy as np
import torch


def remove_unseen_classes_from_training(train_mask, labels, removed_class):
    """Remove the unseen classes (the first three classes by default) to get the zero-shot (i.e., completely imbalanced) label setting.
    Input: train_mask, labels, removed_class
    Output: train_mask_zs: the bool list only containing seen classes
    """
    train_mask_zs = train_mask.clone()
    for i in range(train_mask_zs.numel()):
        if train_mask_zs[i] == 1 and (labels[i].item() in removed_class):
            train_mask_zs[i] = 0
    return train_mask_zs
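
For example, with classes {0, 1, 2} marked unseen, only training nodes labeled 3 or 4 keep their mask bit (a toy check):

import torch

train_mask = torch.tensor([1, 1, 1, 1, 0, 1])
labels = torch.tensor([0, 3, 1, 4, 2, 2])
mask_zs = remove_unseen_classes_from_training(train_mask, labels, [0, 1, 2])
print(mask_zs)  # tensor([0, 1, 0, 1, 0, 0])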

def get_class_set(labels):
    """Get the class set.
    Input: labels [l, [c1, c2, ..]]
    Output: the labeled class set dict_keys([k1, k2, ..])
    """
    mydict = {}
    for y in labels:
        for label in y:
            mydict[int(label)] = 1
    return mydict.keys()

def get_label_attributes(train_mask_zs, nodeids, labellist, features):
    """Get the class center (semantic knowledge) of each seen class.
    Suppose a node i is labeled as c; then attribute[c] += node_i_attribute, and finally mean(attribute[c]).
    Input: train_mask_zs, nodeids, labellist, features
    Output: label_attribute{}: label -> average_labeled_node_features (class centers)
    """
    _, feat_num = features.shape
    labels = get_class_set(labellist)
    label_attribute_nodes = defaultdict(list)
@@ -43,13 +47,14 @@ def get_label_attributes(train_mask_zs, nodeids, labellist, features):
        label_attribute[int(label)] = np.mean(selected_features, axis=0)
    return label_attribute
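
The class centers themselves are just per-class means of the labeled nodes' features; a standalone equivalent of that computation (independent of the bookkeeping elided above):

import numpy as np

features = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])
node_labels = [0, 0, 1]
centers = {
    c: features[[i for i, y in enumerate(node_labels) if y == c]].mean(axis=0)
    for c in set(node_labels)
}
# centers[0] -> [2., 0.], centers[1] -> [0., 2.]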

def get_labeled_nodes_label_attribute(train_mask_zs, labels, features, cuda):
    """Replace the original labels by their class centers.
    For each label c in the training set, the following operations are performed:
    get label_attribute{} through the function get_label_attributes, then res[i, :] = label_attribute[c].
    Input: train_mask_zs, labels, features
    Output: Y_{semantic} [l, ft]: tensor
    """
    X = torch.LongTensor(range(features.shape[0]))
    nodeids = []
    labellist = []
@@ -59,8 +64,13 @@ def get_labeled_nodes_label_attribute(train_mask_zs, labels, features, cuda):
        labellist.append([str(i)])
    # 1. get the semantic knowledge (class centers) of all seen classes
    label_attribute = get_label_attributes(
        train_mask_zs=train_mask_zs,
        nodeids=nodeids,
        labellist=labellist,
        features=features.cpu().numpy(),
    )
    # 2. replace original labels by their class centers (semantic knowledge)
    res = np.zeros([len(nodeids), features.shape[1]])
    for i, labels in enumerate(labellist):
...
import torch
import torch.nn as nn
from classify import evaluate_embeds
from label_utils import (
    get_labeled_nodes_label_attribute,
    remove_unseen_classes_from_training,
)
from model import GCN, RECT_L
from utils import load_data, process_classids, svd_feature


def main(args):
    g, features, labels, train_mask, test_mask, n_classes, cuda = load_data(
        args
    )
    # adopt any number of classes as the unseen classes (the first three classes by default)
    removed_class = args.removed_class
    if len(removed_class) > n_classes:
        raise ValueError(
            "number of unseen classes is greater than the number of classes: {}".format(
                len(removed_class)
            )
        )
    for i in removed_class:
        if i not in labels:
            raise ValueError("class out of bounds: {}".format(i))

    # remove these unseen classes from the training set, to construct the zero-shot label setting
    train_mask_zs = remove_unseen_classes_from_training(
        train_mask=train_mask, labels=labels, removed_class=removed_class
    )
    print(
        "after removing the unseen classes, seen class labeled node num:",
        sum(train_mask_zs).item(),
    )
    if args.model_opt == "RECT-L":
        model = RECT_L(
            g=g,
            in_feats=args.n_hidden,
            n_hidden=args.n_hidden,
            activation=nn.PReLU(),
        )
        if cuda:
            model.cuda()
        features = svd_feature(features=features, d=args.n_hidden)
        attribute_labels = get_labeled_nodes_label_attribute(
            train_mask_zs=train_mask_zs,
            labels=labels,
            features=features,
            cuda=cuda,
        )
        loss_fcn = nn.MSELoss(reduction="sum")
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
        for epoch in range(args.n_epochs):
            model.train()
            optimizer.zero_grad()
            logits = model(features)
            loss_train = loss_fcn(attribute_labels, logits[train_mask_zs])
            print(
                "Epoch {:d} | Train Loss {:.5f}".format(
                    epoch + 1, loss_train.item()
                )
            )
            loss_train.backward()
            optimizer.step()
        model.eval()
        embeds = model.embed(features)
    elif args.model_opt == "GCN":
        model = GCN(
            g=g,
            in_feats=features.shape[1],
            n_hidden=args.n_hidden,
            n_classes=n_classes - len(removed_class),
            activation=nn.PReLU(),
            dropout=args.dropout,
        )
        if cuda:
            model.cuda()
        loss_fcn = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
        for epoch in range(args.n_epochs):
            model.train()
            logits = model(features)
            labels_train = process_classids(labels_temp=labels[train_mask_zs])
            loss_train = loss_fcn(logits[train_mask_zs], labels_train)
            optimizer.zero_grad()
            print(
                "Epoch {:d} | Train Loss {:.5f}".format(
                    epoch + 1, loss_train.item()
                )
            )
            loss_train.backward()
            optimizer.step()
        model.eval()
        embeds = model.embed(features)
    elif args.model_opt == "NodeFeats":
        embeds = svd_feature(features)

    # evaluate the quality of embedding results with the original balanced labels, to assess the model performance (as suggested in the paper)
    res = evaluate_embeds(
        features=embeds,
        labels=labels,
        train_mask=train_mask,
        test_mask=test_mask,
        n_classes=n_classes,
        cuda=cuda,
    )
    print("Test Accuracy of {:s}: {:.4f}".format(args.model_opt, res))

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="MODEL")
    parser.add_argument(
        "--model-opt",
        type=str,
        default="RECT-L",
        choices=["RECT-L", "GCN", "NodeFeats"],
        help="model option",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        choices=["cora", "citeseer"],
        help="dataset",
    )
    parser.add_argument(
        "--dropout", type=float, default=0.0, help="dropout probability"
    )
    parser.add_argument("--gpu", type=int, default=0, help="gpu")
    parser.add_argument(
        "--removed-class",
        type=int,
        nargs="*",
        default=[0, 1, 2],
        help="remove the unseen classes",
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--n-epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=200, help="number of hidden gcn units"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=5e-4, help="Weight for L2 loss"
    )
    args = parser.parse_args()
    main(args)
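
Note that the RECT-L branch above regresses class centers rather than class ids: with `nn.MSELoss(reduction="sum")`, the loss is the summed squared distance between each labeled node's prediction and its class center. A toy illustration of that objective on hypothetical values:

import torch
import torch.nn as nn

preds = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
class_centers = torch.tensor([[0.5, 0.0], [0.0, 0.5]])
loss = nn.MSELoss(reduction="sum")(preds, class_centers)
print(loss)  # tensor(0.5000) = 0.25 + 0.25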
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv


class GCN(nn.Module):
    def __init__(self, g, in_feats, n_hidden, n_classes, activation, dropout):
        super(GCN, self).__init__()
@@ -15,7 +17,7 @@ class GCN(nn.Module):
        h = self.dropout(h)
        preds = self.gcn_2(self.g, h)
        return preds

    def embed(self, inputs):
        h_1 = self.gcn_1(self.g, inputs)
        return h_1.detach()
@@ -29,14 +31,14 @@ class RECT_L(nn.Module):
        self.fc = nn.Linear(n_hidden, in_feats)
        self.dropout = dropout
        nn.init.xavier_uniform_(self.fc.weight.data)

    def forward(self, inputs):
        h_1 = self.gcn_1(self.g, inputs)
        h_1 = F.dropout(h_1, p=self.dropout, training=self.training)
        preds = self.fc(h_1)
        return preds

    # Detach the return variables
    def embed(self, inputs):
        h_1 = self.gcn_1(self.g, inputs)
        return h_1.detach()
import torch

import dgl
from dgl.data import CiteseerGraphDataset, CoraGraphDataset


def load_data(args):
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        g = g.int().to(args.gpu)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]
    g = dgl.add_self_loop(g)
    return g, features, labels, train_mask, test_mask, data.num_classes, cuda

def svd_feature(features, d=200):
    """Get 200-dimensional node features, to avoid the curse of dimensionality."""
    if features.shape[1] <= d:
        return features
    U, S, VT = torch.svd(features)
    res = torch.mm(U[:, 0:d], torch.diag(S[0:d]))
    return res
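
Since `torch.svd` factors `features ≈ U @ diag(S) @ V.T`, the returned matrix is the top-d left singular vectors scaled by their singular values, i.e. a truncated-SVD projection. A quick hedged sanity check:

import torch

feats = torch.randn(30, 10)
reduced = svd_feature(feats, d=5)
print(reduced.shape)  # torch.Size([30, 5])
U, S, V = torch.svd(feats)
# the projection equals the top-5 singular directions scaled by their values
print(torch.allclose(reduced, U[:, :5] @ torch.diag(S[:5]), atol=1e-5))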

def process_classids(labels_temp):
    """Reorder the remaining classes with unseen classes removed.
    Input: the labels with only the unseen classes removed
    Output: the labels with reordered class ids
    """
    labeldict = {}
    num = 0
    for i in labels_temp:
        labeldict[int(i)] = 1
    labellist = sorted(labeldict)
    for label in labellist:
        labeldict[int(label)] = num
        num = num + 1
    for i in range(labels_temp.numel()):
        labels_temp[i] = labeldict[int(labels_temp[i])]
    return labels_temp
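
For example, if classes {0, 1, 2} were removed, the surviving labels {3, 5, 6} are compacted to contiguous ids {0, 1, 2} (note the remap happens in place):

import torch

labels = torch.tensor([5, 3, 6, 3])
print(process_classids(labels))  # tensor([1, 0, 2, 0])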
@@ -3,24 +3,26 @@ Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from model import EntityClassify

from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset

def main(args):
    # load graph data
    if args.dataset == "aifb":
        dataset = AIFBDataset()
    elif args.dataset == "mutag":
        dataset = MUTAGDataset()
    elif args.dataset == "bgs":
        dataset = BGSDataset()
    elif args.dataset == "am":
        dataset = AMDataset()
    else:
        raise ValueError()
@@ -28,11 +30,11 @@ def main(args):
    g = dataset[0]
    category = dataset.predict_category
    num_classes = dataset.num_classes
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
    labels = g.nodes[category].data.pop("labels")
    category_id = len(g.ntypes)
    for i, ntype in enumerate(g.ntypes):
        if ntype == category:
@@ -40,8 +42,8 @@ def main(args):
    # split dataset into train, validate, test
    if args.validation:
        val_idx = train_idx[: len(train_idx) // 5]
        train_idx = train_idx[len(train_idx) // 5 :]
    else:
        val_idx = train_idx
use_cuda = args.gpu >= 0 and th.cuda.is_available() use_cuda = args.gpu >= 0 and th.cuda.is_available()
if use_cuda: if use_cuda:
th.cuda.set_device(args.gpu) th.cuda.set_device(args.gpu)
g = g.to('cuda:%d' % args.gpu) g = g.to("cuda:%d" % args.gpu)
labels = labels.cuda() labels = labels.cuda()
train_idx = train_idx.cuda() train_idx = train_idx.cuda()
test_idx = test_idx.cuda() test_idx = test_idx.cuda()
# create model # create model
model = EntityClassify(g, model = EntityClassify(
args.n_hidden, g,
num_classes, args.n_hidden,
num_bases=args.n_bases, num_classes,
num_hidden_layers=args.n_layers - 2, num_bases=args.n_bases,
dropout=args.dropout, num_hidden_layers=args.n_layers - 2,
use_self_loop=args.use_self_loop) dropout=args.dropout,
use_self_loop=args.use_self_loop,
)
if use_cuda: if use_cuda:
model.cuda() model.cuda()
# optimizer # optimizer
optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2norm) optimizer = th.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.l2norm
)
    # training loop
    print("start training...")
@@ -85,11 +91,23 @@ def main(args):
        if epoch > 5:
            dur.append(t1 - t0)
        train_acc = th.sum(
            logits[train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)
        val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
        val_acc = th.sum(
            logits[val_idx].argmax(dim=1) == labels[val_idx]
        ).item() / len(val_idx)
        print(
            "Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format(
                epoch,
                train_acc,
                loss.item(),
                val_acc,
                val_loss.item(),
                np.average(dur),
            )
        )
    print()
    if args.model_path is not None:
        th.save(model.state_dict(), args.model_path)
@@ -97,37 +115,59 @@ def main(args):
    model.eval()
    logits = model.forward()[category]
    test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
    test_acc = th.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)
    print(
        "Test Acc: {:.4f} | Test loss: {:.4f}".format(
            test_acc, test_loss.item()
        )
    )
    print()
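
The `--n-bases` flag below corresponds to the basis decomposition from the R-GCN paper: each relation's weight matrix is a learned combination of B shared bases, W_r = sum_b a_rb * V_b, which caps parameter growth when the graph has many relations. A hedged numpy sketch of the idea (illustrative only, not DGL's actual implementation):

import numpy as np

num_rels, num_bases, in_feat, out_feat = 90, 10, 16, 16
V = np.random.randn(num_bases, in_feat, out_feat)  # shared bases
a = np.random.randn(num_rels, num_bases)  # per-relation coefficients
# every relation's weight matrix is a linear mix of the shared bases
W = np.einsum("rb,bio->rio", a, V)
print(W.shape)  # (90, 16, 16)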

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGCN")
    parser.add_argument(
        "--dropout", type=float, default=0, help="dropout probability"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden units"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-bases",
        type=int,
        default=-1,
        help="number of filter weight matrices, default: -1 [use all]",
    )
    parser.add_argument(
        "--n-layers", type=int, default=2, help="number of propagation rounds"
    )
    parser.add_argument(
        "-e",
        "--n-epochs",
        type=int,
        default=50,
        help="number of training epochs",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="dataset to use"
    )
    parser.add_argument(
        "--model_path", type=str, default=None, help="path for saving the model"
    )
    parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
    parser.add_argument(
        "--use-self-loop",
        default=False,
        action="store_true",
        help="include self feature as a special relation",
    )
    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument("--validation", dest="validation", action="store_true")
    fp.add_argument("--testing", dest="validation", action="store_false")
    parser.set_defaults(validation=True)
    args = parser.parse_args()
...