Unverified Commit be8763fa authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4679)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.geometry import farthest_point_sampler
"""
Part of the code is adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
def square_distance(src, dst):
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
B, N, _ = src.shape
_, M, _ = dst.shape
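# Pairwise squared distances via the expansion ||s - d||^2 = ||s||^2 - 2 s.d + ||d||^2,
# computed batch-wise without materializing a (B, N, M, C) difference tensor.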
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist
def index_points(points, idx):
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
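# batch_indices is shaped like idx, so points[batch_indices, idx, :] gathers the
# requested points independently for every element of the batch.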
batch_indices = (
torch.arange(B, dtype=torch.long)
.to(device)
.view(view_shape)
.repeat(repeat_shape)
)
new_points = points[batch_indices, idx, :]
return new_points
class KNearNeighbors(nn.Module):
"""
Find the k nearest neighbors
"""
def __init__(self, n_neighbor):
super(KNearNeighbors, self).__init__()
self.n_neighbor = n_neighbor
def forward(self, pos, centroids):
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
center_pos = index_points(pos, centroids)
sqrdists = square_distance(center_pos, pos)
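# Sort each centroid's distances to all points and keep the k smallest as its neighborhood.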
group_idx = sqrdists.argsort(dim=-1)[:, :, : self.n_neighbor]
return group_idx
class KNNGraphBuilder(nn.Module):
"""
Build NN graph
"""
def __init__(self, n_neighbor):
super(KNNGraphBuilder, self).__init__()
@@ -76,46 +81,52 @@ class KNNGraphBuilder(nn.Module):
center = torch.zeros((N)).to(dev)
center[centroids[i]] = 1
src = group_idx[i].contiguous().view(-1)
dst = (
centroids[i]
.view(-1, 1)
.repeat(
1, min(self.n_neighbor, src.shape[0] // centroids.shape[1])
)
.view(-1)
)
unified = torch.cat([src, dst])
uniq, inv_idx = torch.unique(unified, return_inverse=True)
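# torch.unique with return_inverse relabels the global point indices as compact
# node ids, so the per-sample graph only stores the points it actually uses.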
src_idx = inv_idx[: src.shape[0]]
dst_idx = inv_idx[src.shape[0] :]
g = dgl.graph((src_idx, dst_idx))
g.ndata["pos"] = pos[i][uniq]
g.ndata["center"] = center[uniq]
if feat is not None:
g.ndata["feat"] = feat[i][uniq]
glist.append(g)
bg = dgl.batch(glist)
return bg
class RelativePositionMessage(nn.Module):
"""
Compute the input feature from neighbors
"""
def __init__(self, n_neighbor):
super(RelativePositionMessage, self).__init__()
self.n_neighbor = n_neighbor
def forward(self, edges):
pos = edges.src["pos"] - edges.dst["pos"]
if "feat" in edges.src:
res = torch.cat([pos, edges.src["feat"]], 1)
else:
res = pos
return {"agg_feat": res}
class KNNConv(nn.Module):
"""
Feature aggregation
"""
def __init__(self, sizes, batch_size):
super(KNNConv, self).__init__()
@@ -123,13 +134,16 @@ class KNNConv(nn.Module):
self.conv = nn.ModuleList()
self.bn = nn.ModuleList()
for i in range(1, len(sizes)):
self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
self.bn.append(nn.BatchNorm2d(sizes[i]))
def forward(self, nodes):
shape = nodes.mailbox["agg_feat"].shape
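# Messages arrive as (num_nodes, k, feat); reshape to (batch, feat, k, nodes_per_sample)
# so the 1x1 Conv2d/BatchNorm2d stacks operate on the feature channels.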
h = (
nodes.mailbox["agg_feat"]
.view(self.batch_size, -1, shape[1], shape[2])
.permute(0, 3, 2, 1)
)
for conv, bn in zip(self.conv, self.bn):
h = conv(h)
h = bn(h)
@@ -137,12 +151,12 @@ class KNNConv(nn.Module):
h = torch.max(h, 2)[0]
feat_dim = h.shape[1]
h = h.permute(0, 2, 1).reshape(-1, feat_dim)
return {"new_feat": h}
def group_all(self, pos, feat):
"""
Feature aggregation and pooling for the non-sampling layer
"""
if feat is not None:
h = torch.cat([pos, feat], 2)
else:
@@ -177,12 +191,11 @@ class TransitionDown(nn.Module):
g = self.frnn_graph(pos, centroids, feat)
g.update_all(self.message, self.conv)
mask = g.ndata["center"] == 1
pos_dim = g.ndata["pos"].shape[-1]
feat_dim = g.ndata["new_feat"].shape[-1]
pos_res = g.ndata["pos"][mask].view(self.batch_size, -1, pos_dim)
feat_res = g.ndata["new_feat"][mask].view(self.batch_size, -1, feat_dim)
return pos_res, feat_res
@@ -198,7 +211,7 @@ class FeaturePropagation(nn.Module):
sizes = [input_dims] + sizes
for i in range(1, len(sizes)):
self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))
self.bns.append(nn.BatchNorm1d(sizes[i]))
def forward(self, x1, x2, feat1, feat2):
@@ -225,8 +238,9 @@ class FeaturePropagation(nn.Module):
dist_recip = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
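# Inverse-distance weighting: the normalized weights over the three nearest coarse
# points (see the (B, N, 3, 1) reshape below) interpolate feat2 back onto the denser set.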
interpolated_feat = torch.sum(
index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2
)
if feat1 is not None:
new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
......
import numpy as np
import torch
from helper import TransitionDown, TransitionUp, index_points, square_distance
from torch import nn
"""
Part of the code is adapted from
https://github.com/qq456cvb/Point-Transformers
"""
class PointTransformerBlock(nn.Module):
@@ -21,12 +19,12 @@ class PointTransformerBlock(nn.Module):
self.fc_delta = nn.Sequential(
nn.Linear(3, transformer_dim),
nn.ReLU(),
nn.Linear(transformer_dim, transformer_dim),
)
self.fc_gamma = nn.Sequential(
nn.Linear(transformer_dim, transformer_dim),
nn.ReLU(),
nn.Linear(transformer_dim, transformer_dim),
)
self.w_qs = nn.Linear(transformer_dim, transformer_dim, bias=False)
self.w_ks = nn.Linear(transformer_dim, transformer_dim, bias=False)
@@ -35,43 +33,71 @@ class PointTransformerBlock(nn.Module):
def forward(self, x, pos):
dists = square_distance(pos, pos)
knn_idx = dists.argsort()[:, :, : self.n_neighbors] # b x n x k
knn_pos = index_points(pos, knn_idx)
h = self.fc1(x)
q, k, v = (
self.w_qs(h),
index_points(self.w_ks(h), knn_idx),
index_points(self.w_vs(h), knn_idx),
)
pos_enc = self.fc_delta(pos[:, :, None] - knn_pos) # b x n x k x f
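# Vector attention: per-channel weights from fc_gamma(q - k + pos_enc), softmax-normalized
# over the k neighbors (dim=-2), then used to aggregate v + pos_enc.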
attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
attn = torch.softmax(
attn / np.sqrt(k.size(-1)), dim=-2
) # b x n x k x f
res = torch.einsum("bmnf,bmnf->bmf", attn, v + pos_enc)
res = self.fc2(res) + x
return res, attn
class PointTransformer(nn.Module):
def __init__(
self,
n_points,
batch_size,
feature_dim=3,
n_blocks=4,
downsampling_rate=4,
hidden_dim=32,
transformer_dim=None,
n_neighbors=16,
):
super(PointTransformer, self).__init__()
self.fc = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
self.ptb = PointTransformerBlock(
hidden_dim, n_neighbors, transformer_dim
)
self.transition_downs = nn.ModuleList()
self.transformers = nn.ModuleList()
for i in range(n_blocks):
block_hidden_dim = hidden_dim * 2 ** (i + 1)
block_n_points = n_points // (downsampling_rate ** (i + 1))
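# Each block doubles the feature width and keeps 1/downsampling_rate of the
# points of the previous stage.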
self.transition_downs.append(
TransitionDown(
block_n_points,
batch_size,
[
block_hidden_dim // 2 + 3,
block_hidden_dim,
block_hidden_dim,
],
n_neighbors=n_neighbors,
)
)
self.transformers.append(
PointTransformerBlock(
block_hidden_dim, n_neighbors, transformer_dim
)
)
def forward(self, x):
if x.shape[-1] > 3:
@@ -93,16 +119,35 @@ class PointTransformer(nn.Module):
class PointTransformerCLS(nn.Module):
def __init__(
self,
out_classes,
batch_size,
n_points=1024,
feature_dim=3,
n_blocks=4,
downsampling_rate=4,
hidden_dim=32,
transformer_dim=None,
n_neighbors=16,
):
super(PointTransformerCLS, self).__init__()
self.backbone = PointTransformer(
n_points,
batch_size,
feature_dim,
n_blocks,
downsampling_rate,
hidden_dim,
transformer_dim,
n_neighbors,
)
self.out = self.fc2 = nn.Sequential(
nn.Linear(hidden_dim * 2 ** (n_blocks), 256),
nn.ReLU(),
nn.Linear(256, 64),
nn.ReLU(),
nn.Linear(64, out_classes),
)
def forward(self, x):
@@ -112,37 +157,63 @@ class PointTransformerCLS(nn.Module):
class PointTransformerSeg(nn.Module):
def __init__(
self,
out_classes,
batch_size,
n_points=2048,
feature_dim=3,
n_blocks=4,
downsampling_rate=4,
hidden_dim=32,
transformer_dim=None,
n_neighbors=16,
):
super().__init__()
self.backbone = PointTransformer(
n_points,
batch_size,
feature_dim,
n_blocks,
downsampling_rate,
hidden_dim,
transformer_dim,
n_neighbors,
)
self.fc = nn.Sequential(
nn.Linear(32 * 2**n_blocks, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 32 * 2**n_blocks),
)
self.ptb = PointTransformerBlock(
32 * 2**n_blocks, n_neighbors, transformer_dim
)
self.n_blocks = n_blocks
self.transition_ups = nn.ModuleList()
self.transformers = nn.ModuleList()
for i in reversed(range(n_blocks)):
block_hidden_dim = 32 * 2**i
self.transition_ups.append(
TransitionUp(
block_hidden_dim * 2, block_hidden_dim, block_hidden_dim
)
)
self.transformers.append(
PointTransformerBlock(
block_hidden_dim, n_neighbors, transformer_dim
)
)
self.out = nn.Sequential(
nn.Linear(32 + 16, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, out_classes),
)
def forward(self, x, cat_vec=None):
@@ -152,8 +223,9 @@ class PointTransformerSeg(nn.Module):
for i in range(self.n_blocks):
h = self.transition_ups[i](
pos, h, hidden_state[-i - 2][0], hidden_state[-i - 2][1]
)
pos = hidden_state[-i - 2][0]
h, _ = self.transformers[i](h, pos)
return self.out(torch.cat([h, cat_vec], dim=-1))
......
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
"""
import numpy as np
def normalize_data(batch_data):
"""Normalize the batch data, use coordinates of the block centered at origin,
Input:
BxNxC array
Output:
@@ -16,14 +17,14 @@ def normalize_data(batch_data):
pc = batch_data[b]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
normal_data[b] = pc
return normal_data
def shuffle_data(data, labels):
"""Shuffle data and labels.
Input:
data: B,N,... numpy array
label: B,... numpy array
@@ -34,8 +35,9 @@ def shuffle_data(data, labels):
np.random.shuffle(idx)
return data[idx, ...], labels[idx], idx
def shuffle_points(batch_data):
"""Shuffle orders of points in each point cloud -- changes FPS behavior.
Use the same shuffling idx for the entire batch.
Input:
BxNxC array
@@ -44,10 +46,11 @@ def shuffle_points(batch_data):
"""
idx = np.arange(batch_data.shape[1])
np.random.shuffle(idx)
return batch_data[:, idx, :]
def rotate_point_cloud(batch_data):
"""Randomly rotate the point clouds to augment the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
@@ -59,15 +62,18 @@ def rotate_point_cloud(batch_data):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
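# Rotation about the Y (up) axis by a random angle shared by all points of the shape.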
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_point_cloud_z(batch_data):
"""Randomly rotate the point clouds to augment the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
@@ -79,35 +85,45 @@ def rotate_point_cloud_z(batch_data):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array(
[[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]
)
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_point_cloud_with_normal(batch_xyz_normal):
"""Randomly rotate XYZ, normal point cloud.
Input:
batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
Output:
B,N,6, rotated XYZ, normal point cloud
"""
for k in range(batch_xyz_normal.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_xyz_normal[k, :, 0:3]
shape_normal = batch_xyz_normal[k, :, 3:6]
batch_xyz_normal[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
batch_xyz_normal[k, :, 3:6] = np.dot(
shape_normal.reshape((-1, 3)), rotation_matrix
)
return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(
batch_data, angle_sigma=0.06, angle_clip=0.18
):
"""Randomly perturb the point clouds by small rotations
Input:
BxNx6 array, original batch of point clouds and point normals
Return:
@@ -115,26 +131,40 @@ def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, an
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(
angle_sigma * np.random.randn(3), -angle_clip, angle_clip
)
Rx = np.array(
[
[1, 0, 0],
[0, np.cos(angles[0]), -np.sin(angles[0])],
[0, np.sin(angles[0]), np.cos(angles[0])],
]
)
Ry = np.array(
[
[np.cos(angles[1]), 0, np.sin(angles[1])],
[0, 1, 0],
[-np.sin(angles[1]), 0, np.cos(angles[1])],
]
)
Rz = np.array(
[
[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1],
]
)
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k, :, 3:6]
rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
return rotated_data
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
"""Rotate the point cloud along up direction with certain angle.
Input:
BxNx3 array, original batch of point clouds
Return:
@@ -142,18 +172,21 @@ def rotate_point_cloud_by_angle(batch_data, rotation_angle):
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
# rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_data[k, :, 0:3]
rotated_data[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
"""Rotate the point cloud along up direction with certain angle.
Input:
BxNx6 array, original batch of point clouds with normal
scalar, angle of rotation
@@ -162,22 +195,27 @@ def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
# rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k, :, 3:6]
rotated_data[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
rotated_data[k, :, 3:6] = np.dot(
shape_normal.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_perturbation_point_cloud(
batch_data, angle_sigma=0.06, angle_clip=0.18
):
"""Randomly perturb the point clouds by small rotations
Input:
BxNx3 array, original batch of point clouds
Return:
@@ -185,51 +223,66 @@ def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.1
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(
angle_sigma * np.random.randn(3), -angle_clip, angle_clip
)
Rx = np.array(
[
[1, 0, 0],
[0, np.cos(angles[0]), -np.sin(angles[0])],
[0, np.sin(angles[0]), np.cos(angles[0])],
]
)
Ry = np.array(
[
[np.cos(angles[1]), 0, np.sin(angles[1])],
[0, 1, 0],
[-np.sin(angles[1]), 0, np.cos(angles[1])],
]
)
Rz = np.array(
[
[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1],
]
)
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
return rotated_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
"""Randomly jitter points. jittering is per point.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, jittered batch of point clouds
"""
B, N, C = batch_data.shape
assert clip > 0
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
jittered_data += batch_data
return jittered_data
def shift_point_cloud(batch_data, shift_range=0.1):
"""Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""
B, N, C = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
for batch_index in range(B):
batch_data[batch_index, :, :] += shifts[batch_index, :]
return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
"""Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
@@ -238,15 +291,22 @@ def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
B, N, C = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B):
batch_data[batch_index, :, :] *= scales[batch_index]
return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
"""batch_pc: BxNx3"""
for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
drop_idx = np.where(
np.random.random((batch_pc.shape[1])) <= dropout_ratio
)[0]
if len(drop_idx) > 0:
dropout_ratio = (
np.random.random() * max_dropout_ratio
) # 0~0.875 # not need
batch_pc[b, drop_idx, :] = batch_pc[
b, 0, :
] # set to the first point
return batch_pc
import argparse
import os
import time
from functools import partial
import provider
import torch
import torch.nn as nn
import tqdm
from ModelNetDataLoader import ModelNetDataLoader
from point_transformer import PointTransformerCLS
from torch.utils.data import DataLoader
from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=200)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--opt", type=str, default="adam")
args = parser.parse_args()
num_workers = args.num_workers
batch_size = args.batch_size
data_filename = "modelnet40_normal_resampled.zip"
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
get_download_dir(), "modelnet40_normal_resampled"
)
if not os.path.exists(local_path):
download(
"https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
download_path,
verify_ssl=False,
)
from zipfile import ZipFile
with ZipFile(download_path) as z:
z.extractall(path=get_download_dir())
@@ -43,7 +50,8 @@ CustomDataLoader = partial(
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev):
@@ -60,8 +68,7 @@ def train(net, opt, scheduler, train_loader, dev):
for data, label in tq:
data = data.data.numpy()
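# Augmentation: random point dropout, then per-cloud scaling, per-point jitter,
# and per-cloud shift applied to the xyz channels.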
data = provider.random_point_dropout(data)
data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
data = torch.tensor(data)
@@ -84,11 +91,19 @@ def train(net, opt, scheduler, train_loader, dev):
total_loss += loss
total_correct += correct
tq.set_postfix(
{
"AvgLoss": "%.5f" % (total_loss / num_batches),
"AvgAcc": "%.5f" % (total_correct / count),
}
)
print(
"[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
total_loss / num_batches,
total_correct / count,
time.time() - start_time,
)
)
scheduler.step()
@@ -111,10 +126,12 @@ def evaluate(net, test_loader, dev):
total_correct += correct
count += num_examples
tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})
print(
"[Test] AvgAcc: {:.5}, Time: {:.5}s".format(
total_correct / count, time.time() - start_time
)
)
return total_correct / count
@@ -125,13 +142,15 @@ net = net.to(dev)
if args.load_model_path:
net.load_state_dict(torch.load(args.load_model_path, map_location=dev))
if args.opt == "sgd":
# The optimizer strategy described in paper:
opt = torch.optim.SGD(
net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
opt, milestones=[120, 160], gamma=0.1
)
elif args.opt == "adam":
# The optimizer strategy proposed by
# https://github.com/qq456cvb/Point-Transformers:
opt = torch.optim.Adam(
@@ -139,16 +158,26 @@ elif args.opt == 'adam':
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.3)
train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
drop_last=True,
)
best_test_acc = 0
@@ -161,6 +190,5 @@ for epoch in range(args.num_epochs):
best_test_acc = test_acc
if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path)
print("Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc))
print()
import argparse
import time
from functools import partial
import numpy as np
import torch
import torch.optim as optim
import tqdm
from point_transformer import PartSegLoss, PointTransformerSeg
from ShapeNet import ShapeNet
from torch.utils.data import DataLoader
import dgl
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=250)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--tensorboard", action="store_true")
parser.add_argument("--opt", type=str, default="adam")
args = parser.parse_args()
num_workers = args.num_workers
@@ -37,7 +37,8 @@ CustomDataLoader = partial(
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev):
@@ -58,8 +59,11 @@ def train(net, opt, scheduler, train_loader, dev):
opt.zero_grad()
cat_ind = [category_list.index(c) for c in cat]
# A one-hot encoding of the object category
cat_tensor = (
torch.tensor(eye_mat[cat_ind])
.to(dev, dtype=torch.float)
.repeat(1, 2048)
)
cat_tensor = cat_tensor.view(num_examples, -1, 16)
logits = net(data, cat_tensor).permute(0, 2, 1)
loss = L(logits, label)
@@ -78,14 +82,17 @@ def train(net, opt, scheduler, train_loader, dev):
AvgLoss = total_loss / num_batches
AvgAcc = total_correct / count
tq.set_postfix(
{"AvgLoss": "%.5f" % AvgLoss, "AvgAcc": "%.5f" % AvgAcc}
)
scheduler.step()
end = time.time()
print(
"[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
total_loss / num_batches, total_correct / count, end - start
)
)
return data, preds, AvgLoss, AvgAcc, end - start
def mIoU(preds, label, cat, cat_miou, seg_classes):
@@ -128,27 +135,39 @@ def evaluate(net, test_loader, dev, per_cat_verbose=False):
data = data.to(dev, dtype=torch.float)
label = label.to(dev, dtype=torch.long)
cat_ind = [category_list.index(c) for c in cat]
cat_tensor = (
torch.tensor(eye_mat[cat_ind])
.to(dev, dtype=torch.float)
.repeat(1, 2048)
)
cat_tensor = cat_tensor.view(num_examples, -1, 16)
logits = net(data, cat_tensor).permute(0, 2, 1)
_, preds = logits.max(1)
cat_miou = mIoU(
preds.cpu().numpy(),
label.view(num_examples, -1).cpu().numpy(),
cat,
cat_miou,
shapenet.seg_classes,
)
for _, v in cat_miou.items():
if v[1] > 0:
miou += v[0]
count += v[1]
per_cat_miou += v[0] / v[1]
per_cat_count += 1
tq.set_postfix(
{
"mIoU": "%.5f" % (miou / count),
"per Category mIoU": "%.5f"
% (per_cat_miou / per_cat_count),
}
)
print(
"[Test] mIoU: %.5f, per Category mIoU: %.5f"
% (miou / count, per_cat_miou / per_cat_count)
)
if per_cat_verbose:
print("-" * 60)
print("Per-Category mIoU:")
@@ -168,13 +187,15 @@ net = net.to(dev)
if args.load_model_path:
net.load_state_dict(torch.load(args.load_model_path, map_location=dev))
if args.opt == "sgd":
# The optimizer strategy described in paper:
opt = torch.optim.SGD(
net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
opt, milestones=[120, 160], gamma=0.1
)
elif args.opt == "adam":
# The optimizer strategy proposed by
# https://github.com/qq456cvb/Point-Transformers:
opt = torch.optim.Adam(
@@ -182,7 +203,7 @@ elif args.opt == 'adam':
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.3)
@@ -198,20 +219,63 @@ if args.tensorboard:
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
writer = SummaryWriter()
# Select 50 distinct colors for different parts
color_map = torch.tensor(
[
[47, 79, 79],
[139, 69, 19],
[112, 128, 144],
[85, 107, 47],
[139, 0, 0],
[128, 128, 0],
[72, 61, 139],
[0, 128, 0],
[188, 143, 143],
[60, 179, 113],
[205, 133, 63],
[0, 139, 139],
[70, 130, 180],
[205, 92, 92],
[154, 205, 50],
[0, 0, 139],
[50, 205, 50],
[250, 250, 250],
[218, 165, 32],
[139, 0, 139],
[10, 10, 10],
[176, 48, 96],
[72, 209, 204],
[153, 50, 204],
[255, 69, 0],
[255, 145, 0],
[0, 0, 205],
[255, 255, 0],
[0, 255, 0],
[233, 150, 122],
[220, 20, 60],
[0, 191, 255],
[160, 32, 240],
[192, 192, 192],
[173, 255, 47],
[218, 112, 214],
[216, 191, 216],
[255, 127, 80],
[255, 0, 255],
[100, 149, 237],
[128, 128, 128],
[221, 160, 221],
[144, 238, 144],
[123, 104, 238],
[255, 160, 122],
[175, 238, 238],
[238, 130, 238],
[127, 255, 212],
[255, 218, 185],
[255, 105, 180],
]
)
# paint each point according to its pred
@@ -227,28 +291,38 @@ best_test_per_cat_miou = 0
for epoch in range(args.num_epochs):
print("Epoch #{}: ".format(epoch))
data, preds, AvgLoss, AvgAcc, training_time = train(
net, opt, scheduler, train_loader, dev
)
if (epoch + 1) % 5 == 0 or epoch == 0:
test_miou, test_per_cat_miou = evaluate(net, test_loader, dev, True)
if test_miou > best_test_miou:
best_test_miou = test_miou
best_test_per_cat_miou = test_per_cat_miou
if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path)
print(
"Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)"
% (
test_miou,
best_test_miou,
test_per_cat_miou,
best_test_per_cat_miou,
)
)
# Tensorboard
if args.tensorboard:
colored = paint(preds)
writer.add_mesh(
"data", vertices=data, colors=colored, global_step=epoch
)
writer.add_scalar(
"training time for one epoch", training_time, global_step=epoch
)
writer.add_scalar("AvgLoss", AvgLoss, global_step=epoch)
writer.add_scalar("AvgAcc", AvgAcc, global_step=epoch)
if (epoch + 1) % 5 == 0:
writer.add_scalar("test mIoU", test_miou, global_step=epoch)
writer.add_scalar(
"best test mIoU", best_test_miou, global_step=epoch
)
print()
import os
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings("ignore")
def pc_normalize(pc):
centroid = np.mean(pc, axis=0)
@@ -11,6 +14,7 @@ def pc_normalize(pc):
pc = pc / m
return pc
def farthest_point_sample(point, npoint):
"""
Farthest point sampler works as follows:
@@ -25,7 +29,7 @@ def farthest_point_sample(point, npoint):
centroids: sampled pointcloud index, [npoint, D]
"""
N, D = point.shape
xyz = point[:, :3]
centroids = np.zeros((npoint,))
distance = np.ones((N,)) * 1e10
farthest = np.random.randint(0, N)
@@ -39,9 +43,17 @@ def farthest_point_sample(point, npoint):
point = point[centroids.astype(np.int32)]
return point
class ModelNetDataLoader(Dataset):
def __init__(
self,
root,
npoint=1024,
split="train",
fps=False,
normal_channel=True,
cache_size=15000,
):
"""
Input:
root: the root path to the local data files
@@ -54,22 +66,34 @@ class ModelNetDataLoader(Dataset):
self.root = root
self.npoints = npoint
self.fps = fps
self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel
shape_ids = {}
shape_ids["train"] = [
line.rstrip()
for line in open(os.path.join(self.root, "modelnet40_train.txt"))
]
shape_ids["test"] = [
line.rstrip()
for line in open(os.path.join(self.root, "modelnet40_test.txt"))
]
assert split == "train" or split == "test"
shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
# list of (shape_name, shape_txt_file_path) tuple
self.datapath = [
(
shape_names[i],
os.path.join(self.root, shape_names[i], shape_ids[split][i])
+ ".txt",
)
for i in range(len(shape_ids[split]))
]
print("The size of %s data is %d" % (split, len(self.datapath)))
self.cache_size = cache_size
self.cache = {}
@@ -84,11 +108,11 @@ class ModelNetDataLoader(Dataset):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
if self.fps:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0 : self.npoints, :]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
......
import json
import os
from zipfile import ZipFile
import numpy as np
import tqdm
from scipy.sparse import csr_matrix
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir
class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True):
self.num_points = num_points
@@ -13,8 +18,13 @@ class ShapeNet(object):
SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
download_path = get_download_dir()
data_filename = (
"shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
)
data_path = os.path.join(
download_path,
"shapenetcore_partanno_segmentation_benchmark_v0_normal",
)
if not os.path.exists(data_path):
local_path = os.path.join(download_path, data_filename)
if not os.path.exists(local_path):
@@ -24,52 +34,72 @@ class ShapeNet(object):
synset_file = "synsetoffset2category.txt"
with open(os.path.join(data_path, synset_file)) as f:
synset = [t.split("\n")[0].split("\t") for t in f.readlines()]
self.synset_dict = {}
for syn in synset:
self.synset_dict[syn[1]] = syn[0]
self.seg_classes = {
"Airplane": [0, 1, 2, 3],
"Bag": [4, 5],
"Cap": [6, 7],
"Car": [8, 9, 10, 11],
"Chair": [12, 13, 14, 15],
"Earphone": [16, 17, 18],
"Guitar": [19, 20, 21],
"Knife": [22, 23],
"Lamp": [24, 25, 26, 27],
"Laptop": [28, 29],
"Motorbike": [30, 31, 32, 33, 34, 35],
"Mug": [36, 37],
"Pistol": [38, 39, 40],
"Rocket": [41, 42, 43],
"Skateboard": [44, 45, 46],
"Table": [47, 48, 49],
}
train_split_json = "shuffled_train_file_list.json"
val_split_json = "shuffled_val_file_list.json"
test_split_json = "shuffled_test_file_list.json"
split_path = os.path.join(data_path, "train_test_split")
with open(os.path.join(split_path, train_split_json)) as f:
tmp = f.read()
self.train_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
with open(os.path.join(split_path, val_split_json)) as f:
tmp = f.read()
self.val_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
with open(os.path.join(split_path, test_split_json)) as f:
tmp = f.read()
self.test_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
def train(self):
return ShapeNetDataset(
self, "train", self.num_points, self.normal_channel
)
def valid(self):
return ShapeNetDataset(
self, "valid", self.num_points, self.normal_channel
)
def trainval(self):
return ShapeNetDataset(
self, "trainval", self.num_points, self.normal_channel
)
def test(self):
return ShapeNetDataset(
self, "test", self.num_points, self.normal_channel
)
class ShapeNetDataset(Dataset):
def __init__(self, shapenet, mode, num_points, normal_channel=True):
@@ -81,13 +111,13 @@ class ShapeNetDataset(Dataset):
else:
self.dim = 6
if mode == "train":
self.file_list = shapenet.train_file_list
elif mode == "valid":
self.file_list = shapenet.val_file_list
elif mode == "test":
self.file_list = shapenet.test_file_list
elif mode == "trainval":
self.file_list = shapenet.train_file_list + shapenet.val_file_list
else:
raise ValueError("Not supported `mode`")
@@ -95,32 +125,36 @@ class ShapeNetDataset(Dataset):
data_list = []
label_list = []
category_list = []
print("Loading data from split " + self.mode)
for fn in tqdm.tqdm(self.file_list, ascii=True):
with open(fn) as f:
data = np.array(
[t.split("\n")[0].split(" ") for t in f.readlines()]
).astype(np.float)
data_list.append(data[:, 0 : self.dim])
label_list.append(data[:, 6].astype(np.int))
category_list.append(shapenet.synset_dict[fn.split("/")[-2]])
self.data = data_list
self.label = label_list
self.category = category_list
def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2), size=3):
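# Train-time augmentation: random per-axis scaling in [2/3, 3/2] plus a random
# per-axis shift in [-0.2, 0.2].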
xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])
xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])
x = np.add(np.multiply(x, xyz1), xyz2).astype("float32")
return x
def __len__(self):
return len(self.data)
def __getitem__(self, i):
inds = np.random.choice(
self.data[i].shape[0], self.num_points, replace=True
)
x = self.data[i][inds, : self.dim]
y = self.label[i][inds]
cat = self.category[i]
if self.mode == "train":
x = self.translate(x, size=self.dim)
x = x.astype(np.float)
y = y.astype(np.int)
......
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import dgl
import dgl.function as fn
from dgl.geometry import (
farthest_point_sampler,
) # dgl.geometry.pytorch -> dgl.geometry
"""
Part of the code is adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
def square_distance(src, dst):
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist
def index_points(points, idx):
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = (
torch.arange(B, dtype=torch.long)
.to(device)
.view(view_shape)
.repeat(repeat_shape)
)
new_points = points[batch_indices, idx, :]
return new_points
class FixedRadiusNearNeighbors(nn.Module):
"""
Ball Query - Find the neighbors within a fixed radius
"""
def __init__(self, radius, n_neighbor):
super(FixedRadiusNearNeighbors, self).__init__()
self.radius = radius
self.n_neighbor = n_neighbor
def forward(self, pos, centroids):
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
device = pos.device
B, N, _ = pos.shape
center_pos = index_points(pos, centroids)
_, S, _ = center_pos.shape
group_idx = (
torch.arange(N, dtype=torch.long)
.to(device)
.view(1, 1, N)
.repeat([B, S, 1])
)
sqrdists = square_distance(center_pos, pos)
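# Ball query: indices outside `radius` are set to N (an out-of-range id); after
# sorting, slots still equal to N are filled with the centroid's first in-radius
# point so every row has exactly n_neighbor entries.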
group_idx[sqrdists > self.radius**2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, : self.n_neighbor]
group_first = (
group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, self.n_neighbor])
)
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
class FixedRadiusNNGraph(nn.Module):
"""
Build NN graph
"""
def __init__(self, radius, n_neighbor):
super(FixedRadiusNNGraph, self).__init__()
self.radius = radius
@@ -86,50 +107,58 @@ class FixedRadiusNNGraph(nn.Module):
unified = torch.cat([src, dst])
uniq, inv_idx = torch.unique(unified, return_inverse=True)
src_idx = inv_idx[: src.shape[0]]
dst_idx = inv_idx[src.shape[0] :]
g = dgl.graph((src_idx, dst_idx))
g.ndata["pos"] = pos[i][uniq]
g.ndata["center"] = center[uniq]
if feat is not None:
g.ndata["feat"] = feat[i][uniq]
glist.append(g)
bg = dgl.batch(glist)
return bg
class RelativePositionMessage(nn.Module):
"""
Compute the input feature from neighbors
"""
def __init__(self, n_neighbor):
super(RelativePositionMessage, self).__init__()
self.n_neighbor = n_neighbor
def forward(self, edges):
pos = edges.src["pos"] - edges.dst["pos"]
if "feat" in edges.src:
res = torch.cat([pos, edges.src["feat"]], 1)
else:
res = pos
return {"agg_feat": res}
class PointNetConv(nn.Module):
"""
Feature aggregation
"""
def __init__(self, sizes, batch_size):
super(PointNetConv, self).__init__()
self.batch_size = batch_size
self.conv = nn.ModuleList()
self.bn = nn.ModuleList()
for i in range(1, len(sizes)):
self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
self.bn.append(nn.BatchNorm2d(sizes[i]))
def forward(self, nodes):
shape = nodes.mailbox["agg_feat"].shape
h = (
nodes.mailbox["agg_feat"]
.view(self.batch_size, -1, shape[1], shape[2])
.permute(0, 3, 2, 1)
)
for conv, bn in zip(self.conv, self.bn):
h = conv(h)
h = bn(h)
@@ -137,12 +166,12 @@ class PointNetConv(nn.Module):
h = torch.max(h, 2)[0]
feat_dim = h.shape[1]
h = h.permute(0, 2, 1).reshape(-1, feat_dim)
return {"new_feat": h}
def group_all(self, pos, feat):
"""
Feature aggregation and pooling for the non-sampling layer
"""
if feat is not None:
h = torch.cat([pos, feat], 2)
else:
@@ -158,12 +187,21 @@ class PointNetConv(nn.Module):
h = torch.max(h[:, :, :, 0], 2)[0] # [B,D]
return new_pos, h
class SAModule(nn.Module):
"""
The Set Abstraction Layer
"""
def __init__(
self,
npoints,
batch_size,
radius,
mlp_sizes,
n_neighbor=64,
group_all=False,
):
super(SAModule, self).__init__()
self.group_all = group_all
if not group_all:
@@ -181,18 +219,22 @@ class SAModule(nn.Module):
g = self.frnn_graph(pos, centroids, feat)
g.update_all(self.message, self.conv)
mask = g.ndata["center"] == 1
pos_dim = g.ndata["pos"].shape[-1]
feat_dim = g.ndata["new_feat"].shape[-1]
pos_res = g.ndata["pos"][mask].view(self.batch_size, -1, pos_dim)
feat_res = g.ndata["new_feat"][mask].view(self.batch_size, -1, feat_dim)
return pos_res, feat_res
class SAMSGModule(nn.Module):
"""
The Set Abstraction Multi-Scale grouping Layer
"""
def __init__(
self, npoints, batch_size, radius_list, n_neighbor_list, mlp_sizes_list
):
super(SAMSGModule, self).__init__()
self.batch_size = batch_size
self.group_size = len(radius_list)
@@ -202,9 +244,12 @@ class SAMSGModule(nn.Module):
self.message_list = nn.ModuleList()
self.conv_list = nn.ModuleList()
for i in range(self.group_size):
self.frnn_graph_list.append(
FixedRadiusNNGraph(radius_list[i], n_neighbor_list[i])
)
self.message_list.append(
RelativePositionMessage(n_neighbor_list[i])
)
self.conv_list.append(PointNetConv(mlp_sizes_list[i], batch_size))
def forward(self, pos, feat):
@@ -214,21 +259,27 @@ class SAMSGModule(nn.Module):
for i in range(self.group_size):
g = self.frnn_graph_list[i](pos, centroids, feat)
g.update_all(self.message_list[i], self.conv_list[i])
mask = g.ndata["center"] == 1
pos_dim = g.ndata["pos"].shape[-1]
feat_dim = g.ndata["new_feat"].shape[-1]
if i == 0:
pos_res = g.ndata["pos"][mask].view(
self.batch_size, -1, pos_dim
)
feat_res = g.ndata["new_feat"][mask].view(
self.batch_size, -1, feat_dim
)
feat_res_list.append(feat_res)
feat_res = torch.cat(feat_res_list, 2)
return pos_res, feat_res
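# A hedged sketch of the multi-scale grouping variant: the per-scale features are
# concatenated, so the output feature width is the sum of the last MLP channel of
# every scale (64 + 128 + 128 = 320 for the configuration below, which mirrors
# sa_msg_module1). Shapes are illustrative and assume the elided sampling helpers
# are present in this module.
def _sa_msg_module_shape_check():
    import torch

    msg = SAMSGModule(
        npoints=512,
        batch_size=2,
        radius_list=[0.1, 0.2, 0.4],
        n_neighbor_list=[16, 32, 128],
        mlp_sizes_list=[[3, 32, 32, 64], [3, 64, 64, 128], [3, 64, 96, 128]],
    )
    pos = torch.rand(2, 1024, 3)
    new_pos, new_feat = msg(pos, None)
    print(new_pos.shape, new_feat.shape)  # expected: (2, 512, 3) and (2, 512, 320)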
class PointNet2FP(nn.Module):
"""
The Feature Propagation Layer
"""
def __init__(self, input_dims, sizes):
super(PointNet2FP, self).__init__()
self.convs = nn.ModuleList()
......@@ -236,7 +287,7 @@ class PointNet2FP(nn.Module):
sizes = [input_dims] + sizes
for i in range(1, len(sizes)):
self.convs.append(nn.Conv1d(sizes[i-1], sizes[i], 1))
self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))
self.bns.append(nn.BatchNorm1d(sizes[i]))
def forward(self, x1, x2, feat1, feat2):
......@@ -263,7 +314,9 @@ class PointNet2FP(nn.Module):
dist_recip = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_feat = torch.sum(index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2)
interpolated_feat = torch.sum(
index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2
)
if feat1 is not None:
new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
......@@ -278,14 +331,21 @@ class PointNet2FP(nn.Module):
class PointNet2SSGCls(nn.Module):
def __init__(self, output_classes, batch_size, input_dims=3, dropout_prob=0.4):
def __init__(
self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
):
super(PointNet2SSGCls, self).__init__()
self.input_dims = input_dims
self.sa_module1 = SAModule(512, batch_size, 0.2, [input_dims, 64, 64, 128])
self.sa_module2 = SAModule(128, batch_size, 0.4, [128 + 3, 128, 128, 256])
self.sa_module3 = SAModule(None, batch_size, None, [256 + 3, 256, 512, 1024],
group_all=True)
self.sa_module1 = SAModule(
512, batch_size, 0.2, [input_dims, 64, 64, 128]
)
self.sa_module2 = SAModule(
128, batch_size, 0.4, [128 + 3, 128, 128, 256]
)
self.sa_module3 = SAModule(
None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True
)
self.mlp1 = nn.Linear(1024, 512)
self.bn1 = nn.BatchNorm1d(512)
......@@ -320,19 +380,39 @@ class PointNet2SSGCls(nn.Module):
out = self.mlp_out(h)
return out
class PointNet2MSGCls(nn.Module):
def __init__(self, output_classes, batch_size, input_dims=3, dropout_prob=0.4):
def __init__(
self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
):
super(PointNet2MSGCls, self).__init__()
self.input_dims = input_dims
self.sa_msg_module1 = SAMSGModule(512, batch_size, [0.1, 0.2, 0.4], [16, 32, 128],
[[input_dims, 32, 32, 64], [input_dims, 64, 64, 128],
[input_dims, 64, 96, 128]])
self.sa_msg_module2 = SAMSGModule(128, batch_size, [0.2, 0.4, 0.8], [32, 64, 128],
[[320 + 3, 64, 64, 128], [320 + 3, 128, 128, 256],
[320 + 3, 128, 128, 256]])
self.sa_module3 = SAModule(None, batch_size, None, [640 + 3, 256, 512, 1024],
group_all=True)
self.sa_msg_module1 = SAMSGModule(
512,
batch_size,
[0.1, 0.2, 0.4],
[16, 32, 128],
[
[input_dims, 32, 32, 64],
[input_dims, 64, 64, 128],
[input_dims, 64, 96, 128],
],
)
self.sa_msg_module2 = SAMSGModule(
128,
batch_size,
[0.2, 0.4, 0.8],
[32, 64, 128],
[
[320 + 3, 64, 64, 128],
[320 + 3, 128, 128, 256],
[320 + 3, 128, 128, 256],
],
)
self.sa_module3 = SAModule(
None, batch_size, None, [640 + 3, 256, 512, 1024], group_all=True
)
self.mlp1 = nn.Linear(1024, 512)
self.bn1 = nn.BatchNorm1d(512)
......
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pointnet2 import PointNet2FP, SAModule, SAMSGModule
from torch.autograd import Variable
import numpy as np
from pointnet2 import SAModule, SAMSGModule, PointNet2FP
class PointNet2SSGPartSeg(nn.Module):
def __init__(self, output_classes, batch_size, input_dims=6):
super(PointNet2SSGPartSeg, self).__init__()
#if normal_channel == true, input_dims = 6+3
# if normal_channel == true, input_dims = 6+3
self.input_dims = input_dims
self.sa_module1 = SAModule(512, batch_size, 0.2, [input_dims, 64, 64, 128], n_neighbor=32)
self.sa_module2 = SAModule(128, batch_size, 0.4, [128 + 3, 128, 128, 256])
self.sa_module3 = SAModule(None, batch_size, None, [256 + 3, 256, 512, 1024],
group_all=True)
self.sa_module1 = SAModule(
512, batch_size, 0.2, [input_dims, 64, 64, 128], n_neighbor=32
)
self.sa_module2 = SAModule(
128, batch_size, 0.4, [128 + 3, 128, 128, 256]
)
self.sa_module3 = SAModule(
None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True
)
self.fp3 = PointNet2FP(1280, [256, 256])
self.fp2 = PointNet2FP(384, [256, 128])
# if normal_channel == true, 128+16+6+3
self.fp1 = PointNet2FP(128+16+6, [128, 128, 128])
self.fp1 = PointNet2FP(128 + 16 + 6, [128, 128, 128])
self.conv1 = nn.Conv1d(128, 128, 1)
self.bn1 = nn.BatchNorm1d(128)
......@@ -38,7 +44,9 @@ class PointNet2SSGPartSeg(nn.Module):
l2_pos, l2_feat = self.sa_module2(l1_pos, l1_feat)
l3_pos, l3_feat = self.sa_module3(l2_pos, l2_feat) # [B, N, C], [B, D]
# Feature Propagation layers
l2_feat = self.fp3(l2_pos, l3_pos, l2_feat, l3_feat.unsqueeze(1)) # l2_feat: [B, D, N]
l2_feat = self.fp3(
l2_pos, l3_pos, l2_feat, l3_feat.unsqueeze(1)
) # l2_feat: [B, D, N]
l1_feat = self.fp2(l1_pos, l2_pos, l1_feat, l2_feat.permute(0, 2, 1))
l0_feat = torch.cat([cat_vec.permute(0, 2, 1), l0_pos, l0_feat], 2)
l0_feat = self.fp1(l0_pos, l1_pos, l0_feat, l1_feat.permute(0, 2, 1))
......@@ -53,13 +61,30 @@ class PointNet2MSGPartSeg(nn.Module):
def __init__(self, output_classes, batch_size, input_dims=6):
super(PointNet2MSGPartSeg, self).__init__()
self.sa_msg_module1 = SAMSGModule(512, batch_size, [0.1, 0.2, 0.4], [32, 64, 128],
[[input_dims, 32, 32, 64], [input_dims, 64, 64, 128],
[input_dims, 64, 96, 128]])
self.sa_msg_module2 = SAMSGModule(128, batch_size, [0.4, 0.8], [64, 128],
[[128+128+64 +3, 128, 128, 256], [128+128+64 +3, 128, 196, 256]])
self.sa_module3 = SAModule(None, batch_size, None, [512 + 3, 256, 512, 1024],
group_all=True)
self.sa_msg_module1 = SAMSGModule(
512,
batch_size,
[0.1, 0.2, 0.4],
[32, 64, 128],
[
[input_dims, 32, 32, 64],
[input_dims, 64, 64, 128],
[input_dims, 64, 96, 128],
],
)
self.sa_msg_module2 = SAMSGModule(
128,
batch_size,
[0.4, 0.8],
[64, 128],
[
[128 + 128 + 64 + 3, 128, 128, 256],
[128 + 128 + 64 + 3, 128, 196, 256],
],
)
self.sa_module3 = SAModule(
None, batch_size, None, [512 + 3, 256, 512, 1024], group_all=True
)
self.fp3 = PointNet2FP(1536, [256, 256])
self.fp2 = PointNet2FP(576, [256, 128])
......
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
class PointNetCls(nn.Module):
def __init__(self, output_classes, input_dims=3, conv1_dim=64,
dropout_prob=0.5, use_transform=True):
def __init__(
self,
output_classes,
input_dims=3,
conv1_dim=64,
dropout_prob=0.5,
use_transform=True,
):
super(PointNetCls, self).__init__()
self.input_dims = input_dims
self.conv1 = nn.ModuleList()
......@@ -85,6 +92,7 @@ class PointNetCls(nn.Module):
out = self.mlp_out(h)
return out
class TransformNet(nn.Module):
def __init__(self, input_dims=3, conv1_dim=64):
super(TransformNet, self).__init__()
......@@ -127,8 +135,14 @@ class TransformNet(nn.Module):
out = self.mlp_out(h)
iden = Variable(torch.from_numpy(np.eye(self.input_dims).flatten().astype(np.float32)))
iden = iden.view(1, self.input_dims * self.input_dims).repeat(batch_size, 1)
iden = Variable(
torch.from_numpy(
np.eye(self.input_dims).flatten().astype(np.float32)
)
)
iden = iden.view(1, self.input_dims * self.input_dims).repeat(
batch_size, 1
)
if out.is_cuda:
iden = iden.cuda()
out = out + iden
......
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
class PointNetPartSeg(nn.Module):
def __init__(self, output_classes, input_dims=3, num_points=2048,
use_transform=True):
def __init__(
self, output_classes, input_dims=3, num_points=2048, use_transform=True
):
super(PointNetPartSeg, self).__init__()
self.input_dims = input_dims
......@@ -33,7 +35,7 @@ class PointNetPartSeg(nn.Module):
self.pool_feat_len = 2048
self.conv3 = nn.ModuleList()
self.conv3.append(nn.Conv1d(2048 + 64 + 128*3 + 512 + 16, 256, 1))
self.conv3.append(nn.Conv1d(2048 + 64 + 128 * 3 + 512 + 16, 256, 1))
self.conv3.append(nn.Conv1d(256, 256, 1))
self.conv3.append(nn.Conv1d(256, 128, 1))
......@@ -98,6 +100,7 @@ class PointNetPartSeg(nn.Module):
out = self.conv_out(h)
return out
class TransformNet(nn.Module):
def __init__(self, input_dims=3, num_points=2048):
super(TransformNet, self).__init__()
......@@ -140,14 +143,21 @@ class TransformNet(nn.Module):
out = self.mlp_out(h)
iden = Variable(torch.from_numpy(np.eye(self.input_dims).flatten().astype(np.float32)))
iden = iden.view(1, self.input_dims * self.input_dims).repeat(batch_size, 1)
iden = Variable(
torch.from_numpy(
np.eye(self.input_dims).flatten().astype(np.float32)
)
)
iden = iden.view(1, self.input_dims * self.input_dims).repeat(
batch_size, 1
)
if out.is_cuda:
iden = iden.cuda()
out = out + iden
out = out.view(-1, self.input_dims, self.input_dims)
return out
class PartSegLoss(nn.Module):
def __init__(self, eps=0.2):
super(PartSegLoss, self).__init__()
......
'''
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
'''
"""
import numpy as np
def normalize_data(batch_data):
""" Normalize the batch data, use coordinates of the block centered at origin,
"""Normalize the batch data, use coordinates of the block centered at origin,
Input:
BxNxC array
Output:
......@@ -16,14 +17,14 @@ def normalize_data(batch_data):
pc = batch_data[b]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
normal_data[b] = pc
return normal_data
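# A small sanity-check sketch (sizes are illustrative): after normalize_data every
# cloud is centred on its centroid and scaled so its farthest point lies on the
# unit sphere.
def _normalize_data_check():
    batch = np.random.rand(4, 128, 3).astype(np.float32)
    normed = normalize_data(batch)
    # the largest per-point norm in each cloud should be ~1.0
    print(np.sqrt((normed**2).sum(-1)).max(axis=1))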
def shuffle_data(data, labels):
""" Shuffle data and labels.
"""Shuffle data and labels.
Input:
data: B,N,... numpy array
label: B,... numpy array
......@@ -34,8 +35,9 @@ def shuffle_data(data, labels):
np.random.shuffle(idx)
return data[idx, ...], labels[idx], idx
def shuffle_points(batch_data):
""" Shuffle orders of points in each point cloud -- changes FPS behavior.
"""Shuffle orders of points in each point cloud -- changes FPS behavior.
Use the same shuffling idx for the entire batch.
Input:
BxNxC array
......@@ -44,10 +46,11 @@ def shuffle_points(batch_data):
"""
idx = np.arange(batch_data.shape[1])
np.random.shuffle(idx)
return batch_data[:,idx,:]
return batch_data[:, idx, :]
def rotate_point_cloud(batch_data):
""" Randomly rotate the point clouds to augument the dataset
"""Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
......@@ -59,15 +62,18 @@ def rotate_point_cloud(batch_data):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotated_data[k, ...] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_point_cloud_z(batch_data):
""" Randomly rotate the point clouds to augument the dataset
"""Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
......@@ -79,35 +85,45 @@ def rotate_point_cloud_z(batch_data):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, sinval, 0],
[-sinval, cosval, 0],
[0, 0, 1]])
rotation_matrix = np.array(
[[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]
)
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotated_data[k, ...] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_point_cloud_with_normal(batch_xyz_normal):
''' Randomly rotate XYZ, normal point cloud.
"""Randomly rotate XYZ, normal point cloud.
Input:
batch_xyz_normal: B,N,6, first three channels are XYZ, last three are the normals
Output:
B,N,6, rotated XYZ, normal point cloud
'''
"""
for k in range(batch_xyz_normal.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_xyz_normal[k,:,0:3]
shape_normal = batch_xyz_normal[k,:,3:6]
batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_xyz_normal[k, :, 0:3]
shape_normal = batch_xyz_normal[k, :, 3:6]
batch_xyz_normal[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
batch_xyz_normal[k, :, 3:6] = np.dot(
shape_normal.reshape((-1, 3)), rotation_matrix
)
return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
def rotate_perturbation_point_cloud_with_normal(
batch_data, angle_sigma=0.06, angle_clip=0.18
):
"""Randomly perturb the point clouds by small rotations
Input:
BxNx6 array, original batch of point clouds and point normals
Return:
......@@ -115,26 +131,40 @@ def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, an
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
shape_pc = batch_data[k,:,0:3]
shape_normal = batch_data[k,:,3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
angles = np.clip(
angle_sigma * np.random.randn(3), -angle_clip, angle_clip
)
Rx = np.array(
[
[1, 0, 0],
[0, np.cos(angles[0]), -np.sin(angles[0])],
[0, np.sin(angles[0]), np.cos(angles[0])],
]
)
Ry = np.array(
[
[np.cos(angles[1]), 0, np.sin(angles[1])],
[0, 1, 0],
[-np.sin(angles[1]), 0, np.cos(angles[1])],
]
)
Rz = np.array(
[
[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1],
]
)
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k, :, 3:6]
rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
return rotated_data
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
"""Rotate the point cloud along up direction with certain angle.
Input:
BxNx3 array, original batch of point clouds
Return:
......@@ -142,18 +172,21 @@ def rotate_point_cloud_by_angle(batch_data, rotation_angle):
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
# rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k,:,0:3]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_data[k, :, 0:3]
rotated_data[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
"""Rotate the point cloud along up direction with certain angle.
Input:
BxNx6 array, original batch of point clouds with normal
scalar, angle of rotation
......@@ -162,22 +195,27 @@ def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
# rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k,:,0:3]
shape_normal = batch_data[k,:,3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix)
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k, :, 3:6]
rotated_data[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
rotated_data[k, :, 3:6] = np.dot(
shape_normal.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
def rotate_perturbation_point_cloud(
batch_data, angle_sigma=0.06, angle_clip=0.18
):
"""Randomly perturb the point clouds by small rotations
Input:
BxNx3 array, original batch of point clouds
Return:
......@@ -185,51 +223,66 @@ def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.1
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
angles = np.clip(
angle_sigma * np.random.randn(3), -angle_clip, angle_clip
)
Rx = np.array(
[
[1, 0, 0],
[0, np.cos(angles[0]), -np.sin(angles[0])],
[0, np.sin(angles[0]), np.cos(angles[0])],
]
)
Ry = np.array(
[
[np.cos(angles[1]), 0, np.sin(angles[1])],
[0, 1, 0],
[-np.sin(angles[1]), 0, np.cos(angles[1])],
]
)
Rz = np.array(
[
[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1],
]
)
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
return rotated_data
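# A quick property check (illustrative sizes): the perturbation matrix R is a
# product of rotations, so it preserves each point's distance to the origin.
def _rotation_preserves_norm_check():
    pc = np.random.rand(2, 64, 3).astype(np.float32)
    out = rotate_perturbation_point_cloud(pc, angle_sigma=0.06, angle_clip=0.18)
    print(
        np.allclose(
            np.linalg.norm(pc, axis=-1), np.linalg.norm(out, axis=-1), atol=1e-5
        )
    )  # True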
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
""" Randomly jitter points. jittering is per point.
"""Randomly jitter points. jittering is per point.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, jittered batch of point clouds
"""
B, N, C = batch_data.shape
assert(clip > 0)
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
assert clip > 0
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
jittered_data += batch_data
return jittered_data
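# Illustrative check: the per-point Gaussian noise is clipped to +/- `clip`, so
# jittering an all-zero cloud never moves a coordinate by more than 0.05 with the
# default arguments.
def _jitter_clip_check():
    pc = np.zeros((1, 1024, 3), dtype=np.float32)
    out = jitter_point_cloud(pc, sigma=0.01, clip=0.05)
    print(float(np.abs(out).max()) <= 0.05)  # True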
def shift_point_cloud(batch_data, shift_range=0.1):
""" Randomly shift point cloud. Shift is per point cloud.
"""Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""
B, N, C = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B,3))
shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
for batch_index in range(B):
batch_data[batch_index,:,:] += shifts[batch_index,:]
batch_data[batch_index, :, :] += shifts[batch_index, :]
return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
""" Randomly scale the point cloud. Scale is per point cloud.
"""Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
......@@ -238,15 +291,22 @@ def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
B, N, C = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B):
batch_data[batch_index,:,:] *= scales[batch_index]
batch_data[batch_index, :, :] *= scales[batch_index]
return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
''' batch_pc: BxNx3 '''
"""batch_pc: BxNx3"""
for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
if len(drop_idx)>0:
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 # not needed
batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
drop_idx = np.where(
np.random.random((batch_pc.shape[1])) <= dropout_ratio
)[0]
if len(drop_idx) > 0:
dropout_ratio = (
np.random.random() * max_dropout_ratio
) # 0~0.875 # not needed
batch_pc[b, drop_idx, :] = batch_pc[
b, 0, :
] # set to the first point
return batch_pc
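# Illustrative check: dropped points are overwritten with the first point of the
# cloud rather than removed, so the array shape is unchanged and copies of point 0
# appear wherever the dropout mask fired.
def _dropout_keeps_shape_check():
    pc = np.random.rand(2, 1024, 3).astype(np.float32)
    out = random_point_dropout(pc.copy(), max_dropout_ratio=0.875)
    print(out.shape)  # (2, 1024, 3)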
from pointnet2 import PointNet2SSGCls, PointNet2MSGCls
from pointnet_cls import PointNetCls
from ModelNetDataLoader import ModelNetDataLoader
import provider
import argparse
import os
import urllib
import tqdm
from functools import partial
from dgl.data.utils import download, get_download_dir
import dgl
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import provider
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from ModelNetDataLoader import ModelNetDataLoader
from pointnet2 import PointNet2MSGCls, PointNet2SSGCls
from pointnet_cls import PointNetCls
from torch.utils.data import DataLoader
import dgl
from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False
# from dataset import ModelNet
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='pointnet')
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=200)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument("--model", type=str, default="pointnet")
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=200)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args()
num_workers = args.num_workers
batch_size = args.batch_size
data_filename = 'modelnet40_normal_resampled.zip'
data_filename = "modelnet40_normal_resampled.zip"
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
get_download_dir(), 'modelnet40_normal_resampled')
get_download_dir(), "modelnet40_normal_resampled"
)
if not os.path.exists(local_path):
download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip',
download_path, verify_ssl=False)
download(
"https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
download_path,
verify_ssl=False,
)
from zipfile import ZipFile
with ZipFile(download_path) as z:
z.extractall(path=get_download_dir())
......@@ -49,7 +57,8 @@ CustomDataLoader = partial(
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True)
drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev):
......@@ -65,8 +74,7 @@ def train(net, opt, scheduler, train_loader, dev):
for data, label in tq:
data = data.data.numpy()
data = provider.random_point_dropout(data)
data[:, :, 0:3] = provider.random_scale_point_cloud(
data[:, :, 0:3])
data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
data = torch.tensor(data)
......@@ -89,9 +97,12 @@ def train(net, opt, scheduler, train_loader, dev):
total_loss += loss
total_correct += correct
tq.set_postfix({
'AvgLoss': '%.5f' % (total_loss / num_batches),
'AvgAcc': '%.5f' % (total_correct / count)})
tq.set_postfix(
{
"AvgLoss": "%.5f" % (total_loss / num_batches),
"AvgAcc": "%.5f" % (total_correct / count),
}
)
scheduler.step()
......@@ -114,19 +125,18 @@ def evaluate(net, test_loader, dev):
total_correct += correct
count += num_examples
tq.set_postfix({
'AvgAcc': '%.5f' % (total_correct / count)})
tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})
return total_correct / count
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.model == 'pointnet':
if args.model == "pointnet":
net = PointNetCls(40, input_dims=6)
elif args.model == 'pointnet2_ssg':
elif args.model == "pointnet2_ssg":
net = PointNet2SSGCls(40, batch_size, input_dims=6)
elif args.model == 'pointnet2_msg':
elif args.model == "pointnet2_msg":
net = PointNet2MSGCls(40, batch_size, input_dims=6)
net = net.to(dev)
......@@ -137,23 +147,32 @@ opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)
train_dataset = ModelNetDataLoader(local_path, 1024, split='train')
test_dataset = ModelNetDataLoader(local_path, 1024, split='test')
train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
drop_last=True,
)
best_test_acc = 0
for epoch in range(args.num_epochs):
train(net, opt, scheduler, train_loader, dev)
if (epoch + 1) % 1 == 0:
print('Epoch #%d Testing' % epoch)
print("Epoch #%d Testing" % epoch)
test_acc = evaluate(net, test_loader, dev)
if test_acc > best_test_acc:
best_test_acc = test_acc
if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path)
print('Current test acc: %.5f (best: %.5f)' % (
test_acc, best_test_acc))
print("Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc))
import argparse
import os
import time
import urllib
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from pointnet2_partseg import PointNet2MSGPartSeg, PointNet2SSGPartSeg
from pointnet_partseg import PartSegLoss, PointNetPartSeg
from ShapeNet import ShapeNet
from torch.utils.data import DataLoader
import numpy as np
import dgl
from dgl.data.utils import download, get_download_dir
from functools import partial
import tqdm
import urllib
import os
import argparse
import time
from ShapeNet import ShapeNet
from pointnet_partseg import PointNetPartSeg, PartSegLoss
from pointnet2_partseg import PointNet2MSGPartSeg, PointNet2SSGPartSeg
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='pointnet')
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=250)
parser.add_argument('--num-workers', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--tensorboard', action='store_true')
parser.add_argument("--model", type=str, default="pointnet")
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=250)
parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--tensorboard", action="store_true")
args = parser.parse_args()
num_workers = args.num_workers
batch_size = args.batch_size
def collate(samples):
graphs, cat = map(list, zip(*samples))
return dgl.batch(graphs), cat
CustomDataLoader = partial(
DataLoader,
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True)
drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev):
category_list = sorted(list(shapenet.seg_classes.keys()))
......@@ -61,8 +65,12 @@ def train(net, opt, scheduler, train_loader, dev):
opt.zero_grad()
cat_ind = [category_list.index(c) for c in cat]
# A one-hot encoding for the object category
cat_tensor = torch.tensor(eye_mat[cat_ind]).to(dev, dtype=torch.float).repeat(1, 2048)
cat_tensor = cat_tensor.view(num_examples, -1, 16).permute(0,2,1)
cat_tensor = (
torch.tensor(eye_mat[cat_ind])
.to(dev, dtype=torch.float)
.repeat(1, 2048)
)
cat_tensor = cat_tensor.view(num_examples, -1, 16).permute(0, 2, 1)
logits = net(data, cat_tensor)
loss = L(logits, label)
loss.backward()
......@@ -80,20 +88,21 @@ def train(net, opt, scheduler, train_loader, dev):
AvgLoss = total_loss / num_batches
AvgAcc = total_correct / count
tq.set_postfix({
'AvgLoss': '%.5f' % AvgLoss,
'AvgAcc': '%.5f' % AvgAcc})
tq.set_postfix(
{"AvgLoss": "%.5f" % AvgLoss, "AvgAcc": "%.5f" % AvgAcc}
)
scheduler.step()
end = time.time()
return data, preds, AvgLoss, AvgAcc, end-start
return data, preds, AvgLoss, AvgAcc, end - start
def mIoU(preds, label, cat, cat_miou, seg_classes):
for i in range(preds.shape[0]):
shape_iou = 0
n = len(seg_classes[cat[i]])
for cls in seg_classes[cat[i]]:
pred_set = set(np.where(preds[i,:] == cls)[0])
label_set = set(np.where(label[i,:] == cls)[0])
pred_set = set(np.where(preds[i, :] == cls)[0])
label_set = set(np.where(label[i, :] == cls)[0])
union = len(pred_set.union(label_set))
inter = len(pred_set.intersection(label_set))
if union == 0:
......@@ -106,6 +115,7 @@ def mIoU(preds, label, cat, cat_miou, seg_classes):
return cat_miou
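# A toy walk-through of the IoU bookkeeping, assuming the lines elided from this
# hunk average the per-part IoUs of each shape and add the result to
# cat_miou[category] together with a shape count. The "Mug" entry and part ids
# below are made up for illustration.
def _miou_toy_example():
    seg_classes_toy = {"Mug": [36, 37]}
    preds = np.array([[36, 36, 37, 37]])  # one shape with four points
    label = np.array([[36, 37, 37, 37]])
    # part 36: inter 1 / union 2 = 0.5 ; part 37: inter 2 / union 3 ~= 0.667
    out = mIoU(preds, label, ["Mug"], {"Mug": [0.0, 0]}, seg_classes_toy)
    print(out["Mug"])  # roughly [0.583, 1] under the assumption above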
def evaluate(net, test_loader, dev, per_cat_verbose=False):
category_list = sorted(list(shapenet.seg_classes.keys()))
eye_mat = np.eye(16)
......@@ -126,23 +136,36 @@ def evaluate(net, test_loader, dev, per_cat_verbose=False):
data = data.to(dev, dtype=torch.float)
label = label.to(dev, dtype=torch.long)
cat_ind = [category_list.index(c) for c in cat]
cat_tensor = torch.tensor(eye_mat[cat_ind]).to(dev, dtype=torch.float).repeat(1, 2048)
cat_tensor = cat_tensor.view(num_examples, -1, 16).permute(0,2,1)
cat_tensor = (
torch.tensor(eye_mat[cat_ind])
.to(dev, dtype=torch.float)
.repeat(1, 2048)
)
cat_tensor = cat_tensor.view(num_examples, -1, 16).permute(
0, 2, 1
)
logits = net(data, cat_tensor)
_, preds = logits.max(1)
cat_miou = mIoU(preds.cpu().numpy(),
cat_miou = mIoU(
preds.cpu().numpy(),
label.view(num_examples, -1).cpu().numpy(),
cat, cat_miou, shapenet.seg_classes)
cat,
cat_miou,
shapenet.seg_classes,
)
for _, v in cat_miou.items():
if v[1] > 0:
miou += v[0]
count += v[1]
per_cat_miou += v[0] / v[1]
per_cat_count += 1
tq.set_postfix({
'mIoU': '%.5f' % (miou / count),
'per Category mIoU': '%.5f' % (miou / count)})
tq.set_postfix(
{
"mIoU": "%.5f" % (miou / count),
"per Category mIoU": "%.5f" % (miou / count),
}
)
if per_cat_verbose:
print("Per-Category mIoU:")
for k, v in cat_miou.items():
......@@ -155,11 +178,11 @@ def evaluate(net, test_loader, dev, per_cat_verbose=False):
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dev = "cpu"
if args.model == 'pointnet':
if args.model == "pointnet":
net = PointNetPartSeg(50, 3, 2048)
elif args.model == 'pointnet2_ssg':
elif args.model == "pointnet2_ssg":
net = PointNet2SSGPartSeg(50, batch_size, input_dims=6)
elif args.model == 'pointnet2_msg':
elif args.model == "pointnet2_msg":
net = PointNet2MSGPartSeg(50, batch_size, input_dims=6)
net = net.to(dev)
......@@ -180,43 +203,109 @@ if args.tensorboard:
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
writer = SummaryWriter()
# Select 50 distinct colors for different parts
color_map = torch.tensor([
[47, 79, 79],[139, 69, 19],[112, 128, 144],[85, 107, 47],[139, 0, 0],[128, 128, 0],[72, 61, 139],[0, 128, 0],[188, 143, 143],[60, 179, 113],
[205, 133, 63],[0, 139, 139],[70, 130, 180],[205, 92, 92],[154, 205, 50],[0, 0, 139],[50, 205, 50],[250, 250, 250],[218, 165, 32],[139, 0, 139],
[10, 10, 10],[176, 48, 96],[72, 209, 204],[153, 50, 204],[255, 69, 0],[255, 145, 0],[0, 0, 205],[255, 255, 0],[0, 255, 0],[233, 150, 122],
[220, 20, 60],[0, 191, 255],[160, 32, 240],[192,192,192],[173, 255, 47],[218, 112, 214],[216, 191, 216],[255, 127, 80],[255, 0, 255],[100, 149, 237],
[128,128,128],[221, 160, 221],[144, 238, 144],[123, 104, 238],[255, 160, 122],[175, 238, 238],[238, 130, 238],[127, 255, 212],[255, 218, 185],[255, 105, 180],
])
color_map = torch.tensor(
[
[47, 79, 79],
[139, 69, 19],
[112, 128, 144],
[85, 107, 47],
[139, 0, 0],
[128, 128, 0],
[72, 61, 139],
[0, 128, 0],
[188, 143, 143],
[60, 179, 113],
[205, 133, 63],
[0, 139, 139],
[70, 130, 180],
[205, 92, 92],
[154, 205, 50],
[0, 0, 139],
[50, 205, 50],
[250, 250, 250],
[218, 165, 32],
[139, 0, 139],
[10, 10, 10],
[176, 48, 96],
[72, 209, 204],
[153, 50, 204],
[255, 69, 0],
[255, 145, 0],
[0, 0, 205],
[255, 255, 0],
[0, 255, 0],
[233, 150, 122],
[220, 20, 60],
[0, 191, 255],
[160, 32, 240],
[192, 192, 192],
[173, 255, 47],
[218, 112, 214],
[216, 191, 216],
[255, 127, 80],
[255, 0, 255],
[100, 149, 237],
[128, 128, 128],
[221, 160, 221],
[144, 238, 144],
[123, 104, 238],
[255, 160, 122],
[175, 238, 238],
[238, 130, 238],
[127, 255, 212],
[255, 218, 185],
[255, 105, 180],
]
)
# paint each point according to its pred
def paint(batched_points):
B, N = batched_points.shape
colored = color_map[batched_points].squeeze(2)
return colored
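# Illustrative use of the colouring helper: each predicted part id (0-49) indexes a
# row of color_map, giving one RGB triple per point for TensorBoard's add_mesh.
# The ids below are made up; real predictions come from the segmentation logits.
def _paint_toy_example():
    toy_preds = torch.tensor([[0, 1, 49]])  # (B=1, N=3) predicted part ids
    colored = paint(toy_preds)
    print(colored.shape)  # torch.Size([1, 3, 3])
    print(colored[0, 0])  # tensor([47, 79, 79]) -> colour assigned to part 0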
best_test_miou = 0
best_test_per_cat_miou = 0
for epoch in range(args.num_epochs):
data, preds, AvgLoss, AvgAcc, training_time = train(net, opt, scheduler, train_loader, dev)
data, preds, AvgLoss, AvgAcc, training_time = train(
net, opt, scheduler, train_loader, dev
)
if (epoch + 1) % 5 == 0:
print('Epoch #%d Testing' % epoch)
test_miou, test_per_cat_miou = evaluate(net, test_loader, dev, (epoch + 1) % 5 ==0)
print("Epoch #%d Testing" % epoch)
test_miou, test_per_cat_miou = evaluate(
net, test_loader, dev, (epoch + 1) % 5 == 0
)
if test_miou > best_test_miou:
best_test_miou = test_miou
best_test_per_cat_miou = test_per_cat_miou
if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path)
print('Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)' % (
test_miou, best_test_miou, test_per_cat_miou, best_test_per_cat_miou))
print(
"Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)"
% (
test_miou,
best_test_miou,
test_per_cat_miou,
best_test_per_cat_miou,
)
)
# Tensorboard
if args.tensorboard:
colored = paint(preds)
writer.add_mesh('data', vertices=data, colors=colored, global_step=epoch)
writer.add_scalar('training time for one epoch', training_time, global_step=epoch)
writer.add_scalar('AvgLoss', AvgLoss, global_step=epoch)
writer.add_scalar('AvgAcc', AvgAcc, global_step=epoch)
writer.add_mesh(
"data", vertices=data, colors=colored, global_step=epoch
)
writer.add_scalar(
"training time for one epoch", training_time, global_step=epoch
)
writer.add_scalar("AvgLoss", AvgLoss, global_step=epoch)
writer.add_scalar("AvgAcc", AvgAcc, global_step=epoch)
if (epoch + 1) % 5 == 0:
writer.add_scalar('test mIoU', test_miou, global_step=epoch)
writer.add_scalar('best test mIoU', best_test_miou, global_step=epoch)
writer.add_scalar("test mIoU", test_miou, global_step=epoch)
writer.add_scalar(
"best test mIoU", best_test_miou, global_step=epoch
)
from statistics import mean
import torch
import torch.nn as nn
import torch.nn.functional as F
from statistics import mean
class LogisticRegressionClassifier(nn.Module):
''' Define a logistic regression classifier to evaluate the quality of embedding results
'''
"""Define a logistic regression classifier to evaluate the quality of embedding results"""
def __init__(self, nfeat, nclass):
super(LogisticRegressionClassifier, self).__init__()
self.lrc = nn.Linear(nfeat, nclass)
......@@ -14,6 +16,7 @@ class LogisticRegressionClassifier(nn.Module):
preds = self.lrc(x)
return preds
def _evaluate(model, features, labels, test_mask):
model.eval()
with torch.no_grad():
......@@ -24,8 +27,9 @@ def _evaluate(model, features, labels, test_mask):
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
def _train_test_with_lrc(model, features, labels, train_mask, test_mask):
''' Under the pre-defined balanced train/test label setting, train a lrc to evaluate the embedding results. '''
"""Under the pre-defined balanced train/test label setting, train a lrc to evaluate the embedding results."""
optimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=5e-06)
for _ in range(100):
model.train()
......@@ -34,15 +38,30 @@ def _train_test_with_lrc(model, features, labels, train_mask, test_mask):
loss_train = F.cross_entropy(output[train_mask], labels[train_mask])
loss_train.backward()
optimizer.step()
return _evaluate(model=model, features=features, labels=labels, test_mask=test_mask)
return _evaluate(
model=model, features=features, labels=labels, test_mask=test_mask
)
def evaluate_embeds(features, labels, train_mask, test_mask, n_classes, cuda, test_times=10):
print("Training a logistic regression classifier with the pre-defined train/test split setting ...")
def evaluate_embeds(
features, labels, train_mask, test_mask, n_classes, cuda, test_times=10
):
print(
"Training a logistic regression classifier with the pre-defined train/test split setting ..."
)
res_list = []
for _ in range(test_times):
model = LogisticRegressionClassifier(nfeat=features.shape[1], nclass=n_classes)
model = LogisticRegressionClassifier(
nfeat=features.shape[1], nclass=n_classes
)
if cuda:
model.cuda()
res = _train_test_with_lrc(model=model, features=features, labels=labels, train_mask=train_mask, test_mask=test_mask)
res = _train_test_with_lrc(
model=model,
features=features,
labels=labels,
train_mask=train_mask,
test_mask=test_mask,
)
res_list.append(res)
return mean(res_list)
import torch
import numpy as np
from collections import defaultdict
import numpy as np
import torch
def remove_unseen_classes_from_training(train_mask, labels, removed_class):
''' Remove the unseen classes (the first three classes by default) to get the zero-shot (i.e., completely imbalanced) label setting
"""Remove the unseen classes (the first three classes by default) to get the zero-shot (i.e., completely imbalanced) label setting
Input: train_mask, labels, removed_class
Output: train_mask_zs: the bool list only containing seen classes
'''
"""
train_mask_zs = train_mask.clone()
for i in range(train_mask_zs.numel()):
if train_mask_zs[i]==1 and (labels[i].item() in removed_class):
train_mask_zs[i]=0
if train_mask_zs[i] == 1 and (labels[i].item() in removed_class):
train_mask_zs[i] = 0
return train_mask_zs
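# A toy illustration of the zero-shot mask construction: nodes labeled with one of
# the removed (unseen) classes lose their training flag, all other nodes keep it.
# The labels and mask below are made up for the example.
def _zero_shot_mask_example():
    labels = torch.tensor([0, 1, 2, 3, 4])
    train_mask = torch.tensor([True, True, True, True, False])
    print(
        remove_unseen_classes_from_training(
            train_mask, labels, removed_class=[0, 1, 2]
        )
    )  # tensor([False, False, False,  True, False])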
def get_class_set(labels):
''' Get the class set.
"""Get the class set.
Input: labels [l, [c1, c2, ..]]
Output:the labeled class set dict_keys([k1, k2, ..])
'''
"""
mydict = {}
for y in labels:
for label in y:
mydict[int(label)] = 1
return mydict.keys()
def get_label_attributes(train_mask_zs, nodeids, labellist, features):
''' Get the class-center (semantic knowledge) of each seen class.
"""Get the class-center (semantic knowledge) of each seen class.
Suppose a node i is labeled as c, then attribute[c] += node_i_attribute, finally mean(attribute[c])
Input: train_mask_zs, nodeids, labellist, features
Output: label_attribute{}: label -> average_labeled_node_features (class centers)
'''
"""
_, feat_num = features.shape
labels = get_class_set(labellist)
label_attribute_nodes = defaultdict(list)
......@@ -43,13 +47,14 @@ def get_label_attributes(train_mask_zs, nodeids, labellist, features):
label_attribute[int(label)] = np.mean(selected_features, axis=0)
return label_attribute
def get_labeled_nodes_label_attribute(train_mask_zs, labels, features, cuda):
''' Replace the original labels by their class-centers.
"""Replace the original labels by their class-centers.
For each label c in the training set, the following operations will be performed:
Get label_attribute{} through function get_label_attributes, then res[i, :] = label_attribute[c]
Input: train_mask_zs, labels, features
Output: Y_{semantic} [l, ft]: tensor
'''
"""
X = torch.LongTensor(range(features.shape[0]))
nodeids = []
labellist = []
......@@ -59,7 +64,12 @@ def get_labeled_nodes_label_attribute(train_mask_zs, labels, features, cuda):
labellist.append([str(i)])
# 1. get the semantic knowledge (class centers) of all seen classes
label_attribute = get_label_attributes(train_mask_zs=train_mask_zs, nodeids=nodeids, labellist=labellist, features=features.cpu().numpy())
label_attribute = get_label_attributes(
train_mask_zs=train_mask_zs,
nodeids=nodeids,
labellist=labellist,
features=features.cpu().numpy(),
)
# 2. replace original labels by their class centers (semantic knowledge)
res = np.zeros([len(nodeids), features.shape[1]])
......
import torch
import torch.nn as nn
from classify import evaluate_embeds
from label_utils import remove_unseen_classes_from_training, get_labeled_nodes_label_attribute
from utils import load_data, svd_feature, process_classids
from label_utils import (
get_labeled_nodes_label_attribute,
remove_unseen_classes_from_training,
)
from model import GCN, RECT_L
from utils import load_data, process_classids, svd_feature
def main(args):
g, features, labels, train_mask, test_mask, n_classes, cuda= load_data(args)
g, features, labels, train_mask, test_mask, n_classes, cuda = load_data(
args
)
# adopt any number of classes as the unseen classes (the first three classes by default)
removed_class=args.removed_class
if(len(removed_class)>n_classes):
raise ValueError('unseen number is greater than the number of classes: {}'.format(len(removed_class)))
removed_class = args.removed_class
if len(removed_class) > n_classes:
raise ValueError(
"unseen number is greater than the number of classes: {}".format(
len(removed_class)
)
)
for i in removed_class:
if i not in labels:
raise ValueError('class out of bounds: {}'.format(i))
raise ValueError("class out of bounds: {}".format(i))
# remove these unseen classes from the training set, to construct the zero-shot label setting
train_mask_zs = remove_unseen_classes_from_training(train_mask=train_mask, labels=labels, removed_class=removed_class)
print('after removing the unseen classes, seen class labeled node num:', sum(train_mask_zs).item())
train_mask_zs = remove_unseen_classes_from_training(
train_mask=train_mask, labels=labels, removed_class=removed_class
)
print(
"after removing the unseen classes, seen class labeled node num:",
sum(train_mask_zs).item(),
)
if args.model_opt == 'RECT-L':
model = RECT_L(g=g, in_feats=args.n_hidden, n_hidden=args.n_hidden, activation=nn.PReLU())
if args.model_opt == "RECT-L":
model = RECT_L(
g=g,
in_feats=args.n_hidden,
n_hidden=args.n_hidden,
activation=nn.PReLU(),
)
if cuda:
model.cuda()
features = svd_feature(features=features, d=args.n_hidden)
attribute_labels = get_labeled_nodes_label_attribute(train_mask_zs=train_mask_zs, labels=labels, features=features, cuda=cuda)
loss_fcn = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
attribute_labels = get_labeled_nodes_label_attribute(
train_mask_zs=train_mask_zs,
labels=labels,
features=features,
cuda=cuda,
)
loss_fcn = nn.MSELoss(reduction="sum")
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
for epoch in range(args.n_epochs):
model.train()
optimizer.zero_grad()
logits = model(features)
loss_train = loss_fcn(attribute_labels, logits[train_mask_zs])
print('Epoch {:d} | Train Loss {:.5f}'.format(epoch + 1, loss_train.item()))
print(
"Epoch {:d} | Train Loss {:.5f}".format(
epoch + 1, loss_train.item()
)
)
loss_train.backward()
optimizer.step()
model.eval()
embeds = model.embed(features)
elif args.model_opt == 'GCN':
model = GCN(g=g, in_feats=features.shape[1],
n_hidden=args.n_hidden, n_classes=n_classes-len(removed_class),
activation=nn.PReLU(), dropout=args.dropout)
elif args.model_opt == "GCN":
model = GCN(
g=g,
in_feats=features.shape[1],
n_hidden=args.n_hidden,
n_classes=n_classes - len(removed_class),
activation=nn.PReLU(),
dropout=args.dropout,
)
if cuda:
model.cuda()
loss_fcn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
for epoch in range(args.n_epochs):
model.train()
......@@ -56,42 +94,70 @@ def main(args):
labels_train = process_classids(labels_temp=labels[train_mask_zs])
loss_train = loss_fcn(logits[train_mask_zs], labels_train)
optimizer.zero_grad()
print('Epoch {:d} | Train Loss {:.5f}'.format(epoch + 1, loss_train.item()))
print(
"Epoch {:d} | Train Loss {:.5f}".format(
epoch + 1, loss_train.item()
)
)
loss_train.backward()
optimizer.step()
model.eval()
embeds = model.embed(features)
elif args.model_opt == 'NodeFeats':
elif args.model_opt == "NodeFeats":
embeds = svd_feature(features)
# evaluate the quality of embedding results with the original balanced labels, to assess the model performance (as suggested in the paper)
res = evaluate_embeds(features=embeds, labels=labels, train_mask=train_mask, test_mask=test_mask, n_classes=n_classes, cuda=cuda)
res = evaluate_embeds(
features=embeds,
labels=labels,
train_mask=train_mask,
test_mask=test_mask,
n_classes=n_classes,
cuda=cuda,
)
print("Test Accuracy of {:s}: {:.4f}".format(args.model_opt, res))
if __name__ == '__main__':
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='MODEL')
parser.add_argument("--model-opt", type=str, default='RECT-L',
choices=['RECT-L', 'GCN', 'NodeFeats'],
help="model option")
parser.add_argument("--dataset", type=str, default='cora',
choices=['cora', 'citeseer'],
help="dataset")
parser.add_argument("--dropout", type=float, default=0.0,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=0,
help="gpu")
parser.add_argument("--removed-class", type=int, nargs='*', default=[0, 1, 2],
help="remove the unseen classes")
parser.add_argument("--lr", type=float, default=1e-3,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=200,
help="number of hidden gcn units")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
parser = argparse.ArgumentParser(description="MODEL")
parser.add_argument(
"--model-opt",
type=str,
default="RECT-L",
choices=["RECT-L", "GCN", "NodeFeats"],
help="model option",
)
parser.add_argument(
"--dataset",
type=str,
default="cora",
choices=["cora", "citeseer"],
help="dataset",
)
parser.add_argument(
"--dropout", type=float, default=0.0, help="dropout probability"
)
parser.add_argument("--gpu", type=int, default=0, help="gpu")
parser.add_argument(
"--removed-class",
type=int,
nargs="*",
default=[0, 1, 2],
help="remove the unseen classes",
)
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument(
"--n-epochs", type=int, default=200, help="number of training epochs"
)
parser.add_argument(
"--n-hidden", type=int, default=200, help="number of hidden gcn units"
)
parser.add_argument(
"--weight-decay", type=float, default=5e-4, help="Weight for L2 loss"
)
args = parser.parse_args()
main(args)
import torch.nn as nn
from dgl.nn.pytorch import GraphConv
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv
class GCN(nn.Module):
def __init__(self, g, in_feats, n_hidden, n_classes, activation, dropout):
super(GCN, self).__init__()
......
import torch
import dgl
from dgl.data import CoraGraphDataset, CiteseerGraphDataset
from dgl.data import CiteseerGraphDataset, CoraGraphDataset
def load_data(args):
if args.dataset == 'cora':
if args.dataset == "cora":
data = CoraGraphDataset()
elif args.dataset == 'citeseer':
elif args.dataset == "citeseer":
data = CiteseerGraphDataset()
else:
raise ValueError('Unknown dataset: {}'.format(args.dataset))
raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0]
if args.gpu < 0:
cuda = False
else:
cuda = True
g = g.int().to(args.gpu)
features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
test_mask = g.ndata['test_mask']
features = g.ndata["feat"]
labels = g.ndata["label"]
train_mask = g.ndata["train_mask"]
test_mask = g.ndata["test_mask"]
g = dgl.add_self_loop(g)
return g, features, labels, train_mask, test_mask, data.num_classes, cuda
def svd_feature(features, d=200):
''' Get 200-dimensional node features, to avoid curse of dimensionality
'''
if( features.shape[1] <= d ): return features
"""Get 200-dimensional node features, to avoid curse of dimensionality"""
if features.shape[1] <= d:
return features
U, S, VT = torch.svd(features)
res = torch.mm(U[:, 0:d], torch.diag(S[0:d]))
return res
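# Illustrative use: high-dimensional bag-of-words features (e.g. Cora's 1433
# dimensions) are projected to d=200 via a truncated SVD, while low-dimensional
# inputs pass through unchanged. The sizes are only an example.
def _svd_feature_example():
    feats = torch.rand(2708, 1433)
    print(svd_feature(feats, d=200).shape)  # torch.Size([2708, 200])
    print(svd_feature(torch.rand(10, 50)).shape)  # torch.Size([10, 50]) -- unchanged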
def process_classids(labels_temp):
''' Reorder the remaining classes with unseen classes removed.
"""Reorder the remaining classes with unseen classes removed.
Input: the label only removing unseen classes
Output: the label with reordered classes
'''
"""
labeldict = {}
num=0
num = 0
for i in labels_temp:
labeldict[int(i)]=1
labellist=sorted(labeldict)
labeldict[int(i)] = 1
labellist = sorted(labeldict)
for label in labellist:
labeldict[int(label)]=num
num=num+1
labeldict[int(label)] = num
num = num + 1
for i in range(labels_temp.numel()):
labels_temp[i]=labeldict[int(labels_temp[i])]
labels_temp[i] = labeldict[int(labels_temp[i])]
return labels_temp
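# A toy illustration: after removing the unseen classes, the remaining class ids are
# re-indexed to a contiguous 0..k-1 range in sorted order (note the tensor is
# modified in place, hence the clone).
def _process_classids_example():
    labels_zs = torch.tensor([5, 3, 6, 3, 5])
    print(process_classids(labels_zs.clone()))  # tensor([1, 0, 2, 0, 1])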
......@@ -3,24 +3,26 @@ Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import numpy as np
import time
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from model import EntityClassify
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
def main(args):
# load graph data
if args.dataset == 'aifb':
if args.dataset == "aifb":
dataset = AIFBDataset()
elif args.dataset == 'mutag':
elif args.dataset == "mutag":
dataset = MUTAGDataset()
elif args.dataset == 'bgs':
elif args.dataset == "bgs":
dataset = BGSDataset()
elif args.dataset == 'am':
elif args.dataset == "am":
dataset = AMDataset()
else:
raise ValueError()
......@@ -28,11 +30,11 @@ def main(args):
g = dataset[0]
category = dataset.predict_category
num_classes = dataset.num_classes
train_mask = g.nodes[category].data.pop('train_mask')
test_mask = g.nodes[category].data.pop('test_mask')
train_mask = g.nodes[category].data.pop("train_mask")
test_mask = g.nodes[category].data.pop("test_mask")
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop('labels')
labels = g.nodes[category].data.pop("labels")
category_id = len(g.ntypes)
for i, ntype in enumerate(g.ntypes):
if ntype == category:
......@@ -40,8 +42,8 @@ def main(args):
# split dataset into train, validate, test
if args.validation:
val_idx = train_idx[:len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5:]
val_idx = train_idx[: len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5 :]
else:
val_idx = train_idx
......@@ -49,25 +51,29 @@ def main(args):
use_cuda = args.gpu >= 0 and th.cuda.is_available()
if use_cuda:
th.cuda.set_device(args.gpu)
g = g.to('cuda:%d' % args.gpu)
g = g.to("cuda:%d" % args.gpu)
labels = labels.cuda()
train_idx = train_idx.cuda()
test_idx = test_idx.cuda()
# create model
model = EntityClassify(g,
model = EntityClassify(
g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
dropout=args.dropout,
use_self_loop=args.use_self_loop)
use_self_loop=args.use_self_loop,
)
if use_cuda:
model.cuda()
# optimizer
optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2norm)
optimizer = th.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.l2norm
)
# training loop
print("start training...")
......@@ -85,11 +91,23 @@ def main(args):
if epoch > 5:
dur.append(t1 - t0)
train_acc = th.sum(logits[train_idx].argmax(dim=1) == labels[train_idx]).item() / len(train_idx)
train_acc = th.sum(
logits[train_idx].argmax(dim=1) == labels[train_idx]
).item() / len(train_idx)
val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
val_acc = th.sum(logits[val_idx].argmax(dim=1) == labels[val_idx]).item() / len(val_idx)
print("Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".
format(epoch, train_acc, loss.item(), val_acc, val_loss.item(), np.average(dur)))
val_acc = th.sum(
logits[val_idx].argmax(dim=1) == labels[val_idx]
).item() / len(val_idx)
print(
"Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format(
epoch,
train_acc,
loss.item(),
val_acc,
val_loss.item(),
np.average(dur),
)
)
print()
if args.model_path is not None:
th.save(model.state_dict(), args.model_path)
......@@ -97,37 +115,59 @@ def main(args):
model.eval()
logits = model.forward()[category]
test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
test_acc = th.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
test_acc = th.sum(
logits[test_idx].argmax(dim=1) == labels[test_idx]
).item() / len(test_idx)
print(
"Test Acc: {:.4f} | Test loss: {:.4f}".format(
test_acc, test_loss.item()
)
)
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN')
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden units")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-layers", type=int, default=2,
help="number of propagation rounds")
parser.add_argument("-e", "--n-epochs", type=int, default=50,
help="number of training epochs")
parser.add_argument("-d", "--dataset", type=str, required=True,
help="dataset to use")
parser.add_argument("--model_path", type=str, default=None,
help='path for save the model')
parser.add_argument("--l2norm", type=float, default=0,
help="l2 norm coef")
parser.add_argument("--use-self-loop", default=False, action='store_true',
help="include self feature as a special relation")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="RGCN")
parser.add_argument(
"--dropout", type=float, default=0, help="dropout probability"
)
parser.add_argument(
"--n-hidden", type=int, default=16, help="number of hidden units"
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument(
"--n-bases",
type=int,
default=-1,
help="number of filter weight matrices, default: -1 [use all]",
)
parser.add_argument(
"--n-layers", type=int, default=2, help="number of propagation rounds"
)
parser.add_argument(
"-e",
"--n-epochs",
type=int,
default=50,
help="number of training epochs",
)
parser.add_argument(
"-d", "--dataset", type=str, required=True, help="dataset to use"
)
parser.add_argument(
"--model_path", type=str, default=None, help="path for save the model"
)
parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
parser.add_argument(
"--use-self-loop",
default=False,
action="store_true",
help="include self feature as a special relation",
)
fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument('--validation', dest='validation', action='store_true')
fp.add_argument('--testing', dest='validation', action='store_false')
fp.add_argument("--validation", dest="validation", action="store_true")
fp.add_argument("--testing", dest="validation", action="store_false")
parser.set_defaults(validation=True)
args = parser.parse_args()
......