"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "cf0ccd406e38a63bdb984578ba742ca3c8ab81b8"
Unverified Commit 558673e1 authored by Tong He's avatar Tong He Committed by GitHub
Browse files

[Model] PointNet and PointNet++ for point cloud (#1510)



* commit patch

* commit patch

* pointnet basic

* fix data

* reorg

* reorg

* temp status

* remove validate set

* add partseg data and model

* partseg miou

* clean up

* fix loss

* network definition match paper

* fix

* fix miou

* update data format

* fix

* fix

* working pointnet ssg cls

* avoid some pytorch bug

* fix script

* update hyperparams

* add msg module

* try different dataset

* update new dataset info

* quick fix to subgraph

* fix speed

* update training

* update

* fix bs

* update docstring

* update

* update

* remove parallel reduction in fps

* switch to kernel fps, training is 30% faster
Co-authored-by: Ubuntu <ubuntu@ip-172-31-20-181.us-west-2.compute.internal>
parent 8a20b6c1
import numpy as np
import warnings
import os
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
def pc_normalize(pc):
    """Center a point cloud at the origin and scale it into the unit sphere.

    Parameters
    ----------
    pc : numpy.ndarray, shape (N, 3)
        Input point coordinates.

    Returns
    -------
    numpy.ndarray
        The same points with zero centroid and maximum point norm 1.
    """
    centered = pc - np.mean(pc, axis=0)
    max_norm = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    return centered / max_norm
def farthest_point_sample(point, npoint):
    """Select `npoint` points via farthest point sampling (FPS).

    Farthest point sampler works as follows:
    1. Initialize the sample set S with a random point
    2. Pick point P not in S, which maximizes the distance d(P, S)
    3. Repeat step 2 until |S| = npoint

    Only the first three columns (xyz) are used for distance; any extra
    feature columns are carried along in the output rows.

    Input:
        point: pointcloud data, [N, D]
        npoint: number of samples
    Return:
        sampled pointcloud, [npoint, D] -- the selected rows of `point`
        (NOT their indices; the original docstring was wrong about this)
    """
    N, _ = point.shape
    xyz = point[:, :3]
    # Indices of chosen samples; integer dtype avoids the float round-trip
    # (the original stored them in a float array and cast back).
    centroids = np.zeros((npoint,), dtype=np.int64)
    # distance[j] = squared distance from point j to the current sample set.
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)  # random seed point
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        # Tighten each point's distance-to-set, then pick the farthest one.
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
    return point[centroids]
class ModelNetDataLoader(Dataset):
    # ModelNet40 point-cloud classification dataset reader with an
    # in-memory cache of processed samples.
    def __init__(self, root, npoint=1024, split='train', fps=False,
                 normal_channel=True, cache_size=15000):
        """
        Input:
            root: the root path to the local data files
            npoint: number of points from each cloud
            split: which split of the data, 'train' or 'test'
            fps: whether to sample points with farthest point sampler
            normal_channel: whether to use additional channel
            cache_size: the cache size of in-memory point clouds
        """
        self.root = root
        self.npoints = npoint
        self.fps = fps
        # Category names, one per line; a name's line index is its class id.
        self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
        self.cat = [line.rstrip() for line in open(self.catfile)]
        self.classes = dict(zip(self.cat, range(len(self.cat))))
        self.normal_channel = normal_channel
        # Split files list one shape id (e.g. 'chair_0001') per line.
        shape_ids = {}
        shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
        shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
        assert (split == 'train' or split == 'test')
        # Drop the trailing '_NNNN' counter to recover the category name.
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
        # list of (shape_name, shape_txt_file_path) tuple
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                         in range(len(shape_ids[split]))]
        print('The size of %s data is %d'%(split,len(self.datapath)))
        # Bounded cache of already-processed (point_set, cls) pairs.
        self.cache_size = cache_size
        self.cache = {}

    def __len__(self):
        # Number of shapes in the selected split.
        return len(self.datapath)

    def _get_item(self, index):
        # Return (point_set, cls): float32 points of shape (npoints, 3 or 6)
        # and the class id as an int32 array of shape (1,).
        if index in self.cache:
            point_set, cls = self.cache[index]
        else:
            fn = self.datapath[index]
            cls = self.classes[self.datapath[index][0]]
            cls = np.array([cls]).astype(np.int32)
            # Each row of the txt file: x,y,z[,nx,ny,nz], comma-separated.
            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
            if self.fps:
                # Farthest point sampling gives better spatial coverage.
                point_set = farthest_point_sample(point_set, self.npoints)
            else:
                # Cheap alternative: keep the first npoints rows.
                point_set = point_set[0:self.npoints,:]
            # Normalize xyz into the unit sphere; normal columns untouched.
            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
            if not self.normal_channel:
                point_set = point_set[:, 0:3]
            # Cache the fully processed sample (only up to cache_size items).
            if len(self.cache) < self.cache_size:
                self.cache[index] = (point_set, cls)
        return point_set, cls

    def __getitem__(self, index):
        return self._get_item(index)
(diff artifact: no newline at end of file)
PointNet and PointNet++ for Point Cloud Classification
====
This is a reproduction of the papers
- [PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation](https://arxiv.org/abs/1612.00593).
- [PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space](https://arxiv.org/abs/1706.02413).
# Performance
| Model | Dataset | Metric | Score |
|-----------------|------------|----------|-------|
| PointNet | ModelNet40 | Accuracy | 89.3 |
| PointNet | ShapeNet | mIoU | 83.6 |
| PointNet++(SSG) | ModelNet40 | Accuracy | 93.26 |
| PointNet++(MSG) | ModelNet40 | Accuracy | 93.26 |
# How to Run
For point cloud classification, run with
```python
python train_cls.py
```
For point cloud part-segmentation, run with
```python
python train_partseg.py
```
import os, json, tqdm
import numpy as np
import dgl
from zipfile import ZipFile
from torch.utils.data import Dataset
from scipy.sparse import csr_matrix
from dgl.data.utils import download, get_download_dir
class ShapeNet(object):
    """ShapeNet part-segmentation dataset manager.

    Downloads and extracts the ShapeNet part-annotation archive on first
    use, reads the split files, and exposes the four splits as
    ShapeNetDataset objects via train()/valid()/trainval()/test().
    """
    def __init__(self, num_points=2048, normal_channel=True):
        self.num_points = num_points
        self.normal_channel = normal_channel
        SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
        download_path = get_download_dir()
        data_filename = "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
        data_path = os.path.join(download_path, "shapenetcore_partanno_segmentation_benchmark_v0_normal")
        # Download + unpack only when the extracted directory is missing.
        if not os.path.exists(data_path):
            local_path = os.path.join(download_path, data_filename)
            if not os.path.exists(local_path):
                download(SHAPENET_DOWNLOAD_URL, local_path)
            with ZipFile(local_path) as z:
                z.extractall(path=download_path)
        # Map synset offset (directory name) -> human-readable category name.
        synset_file = "synsetoffset2category.txt"
        with open(os.path.join(data_path, synset_file)) as f:
            synset = [t.split('\n')[0].split('\t') for t in f.readlines()]
        self.synset_dict = {}
        for syn in synset:
            self.synset_dict[syn[1]] = syn[0]
        # Global part-label ids grouped by object category (50 parts total).
        self.seg_classes = {'Airplane': [0, 1, 2, 3],
                            'Bag': [4, 5],
                            'Cap': [6, 7],
                            'Car': [8, 9, 10, 11],
                            'Chair': [12, 13, 14, 15],
                            'Earphone': [16, 17, 18],
                            'Guitar': [19, 20, 21],
                            'Knife': [22, 23],
                            'Lamp': [24, 25, 26, 27],
                            'Laptop': [28, 29],
                            'Motorbike': [30, 31, 32, 33, 34, 35],
                            'Mug': [36, 37],
                            'Pistol': [38, 39, 40],
                            'Rocket': [41, 42, 43],
                            'Skateboard': [44, 45, 46],
                            'Table': [47, 48, 49]}
        # Official train/val/test splits; each JSON lists 'shape_data/<syn>/<id>'.
        train_split_json = 'shuffled_train_file_list.json'
        val_split_json = 'shuffled_val_file_list.json'
        test_split_json = 'shuffled_test_file_list.json'
        split_path = os.path.join(data_path, 'train_test_split')
        with open(os.path.join(split_path, train_split_json)) as f:
            tmp = f.read()
            self.train_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
        with open(os.path.join(split_path, val_split_json)) as f:
            tmp = f.read()
            self.val_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
        with open(os.path.join(split_path, test_split_json)) as f:
            tmp = f.read()
            self.test_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]

    def train(self):
        # Training split only.
        return ShapeNetDataset(self, 'train', self.num_points, self.normal_channel)

    def valid(self):
        # Validation split only.
        return ShapeNetDataset(self, 'valid', self.num_points, self.normal_channel)

    def trainval(self):
        # Training + validation splits combined.
        return ShapeNetDataset(self, 'trainval', self.num_points, self.normal_channel)

    def test(self):
        # Held-out test split.
        return ShapeNetDataset(self, 'test', self.num_points, self.normal_channel)
class ShapeNetDataset(Dataset):
    """One split of the ShapeNet part-segmentation data.

    Loads every listed cloud into memory; each item is (points, per-point
    part labels, category name), with `num_points` points resampled with
    replacement on every access.
    """
    def __init__(self, shapenet, mode, num_points, normal_channel=True):
        super(ShapeNetDataset, self).__init__()
        self.mode = mode
        self.num_points = num_points
        # 3 columns = xyz only; 6 = xyz plus the surface normal.
        self.dim = 6 if normal_channel else 3
        if mode == 'train':
            self.file_list = shapenet.train_file_list
        elif mode == 'valid':
            self.file_list = shapenet.val_file_list
        elif mode == 'test':
            self.file_list = shapenet.test_file_list
        elif mode == 'trainval':
            self.file_list = shapenet.train_file_list + shapenet.val_file_list
        else:
            # BUG FIX: the original raised a plain string, which is itself
            # a TypeError in Python 3; raise a proper exception instead.
            raise ValueError("Not supported `mode`")
        data_list = []
        label_list = []
        category_list = []
        print('Loading data from split ' + self.mode)
        for fn in tqdm.tqdm(self.file_list, ascii=True):
            with open(fn) as f:
                # Rows: x y z nx ny nz part_label (space-separated).
                # BUG FIX: np.float / np.int were removed from NumPy; use
                # the explicit fixed-width dtypes (same values as before).
                data = np.array([t.split('\n')[0].split(' ') for t in f.readlines()]).astype(np.float64)
            data_list.append(data[:, 0:self.dim])
            label_list.append(data[:, 6].astype(np.int64))
            # Parent directory is the synset offset -> category name.
            category_list.append(shapenet.synset_dict[fn.split('/')[-2]])
        self.data = data_list
        self.label = label_list
        self.category = category_list

    def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2), size=3):
        """Random per-axis scale and shift augmentation (train mode only)."""
        xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])
        xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])
        x = np.add(np.multiply(x, xyz1), xyz2).astype('float32')
        return x

    def __len__(self):
        # Number of clouds in this split.
        return len(self.data)

    def __getitem__(self, i):
        # Resample num_points points (with replacement) from cloud i.
        inds = np.random.choice(self.data[i].shape[0], self.num_points, replace=True)
        x = self.data[i][inds, :self.dim]
        y = self.label[i][inds]
        cat = self.category[i]
        if self.mode == 'train':
            x = self.translate(x, size=self.dim)
        x = x.astype(np.float64)
        y = y.astype(np.int64)
        return x, y, cat
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import dgl
import dgl.function as fn
from dgl.geometry.pytorch import FarthestPointSampler
'''
Part of the code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
def square_distance(src, dst):
    '''
    Pairwise squared Euclidean distance between two batched point sets,
    using the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 a.b.

    src: (B, N, C), dst: (B, M, C) -> returns (B, N, M).
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
    '''
    B, N, _ = src.shape
    M = dst.shape[1]
    cross = torch.matmul(src, dst.permute(0, 2, 1))
    sq_src = torch.sum(src ** 2, -1).view(B, N, 1)
    sq_dst = torch.sum(dst ** 2, -1).view(B, 1, M)
    return sq_src + sq_dst - 2 * cross
def index_points(points, idx):
    '''
    Gather points[b, idx[b]] for every batch element b.

    points: (B, N, C); idx: (B, ...) integer index tensor.
    Returns a tensor of shape idx.shape + (C,).
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
    '''
    device = points.device
    B = points.shape[0]
    # Batch index broadcast to the same shape as idx.
    view_shape = [B] + [1] * (idx.dim() - 1)
    repeat_shape = [1] + list(idx.shape[1:])
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    return points[batch_indices, idx, :]
class FixedRadiusNearNeighbors(nn.Module):
    '''
    Find the neighbors with-in a fixed radius
    '''
    def __init__(self, radius, n_neighbor):
        super(FixedRadiusNearNeighbors, self).__init__()
        self.radius = radius          # ball-query radius
        self.n_neighbor = n_neighbor  # max neighbors kept per center

    def forward(self, pos, centroids):
        '''
        Ball query: for each centroid, return indices of up to n_neighbor
        points within `radius`, padding short lists with the first hit.
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

        pos: (B, N, 3) point coordinates; centroids: (B, S) indices.
        Returns group_idx: (B, S, n_neighbor) neighbor indices.
        '''
        device = pos.device
        B, N, _ = pos.shape
        center_pos = index_points(pos, centroids)
        _, S, _ = center_pos.shape
        # Start with every point index as a candidate for every center.
        group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
        sqrdists = square_distance(center_pos, pos)
        # Mark out-of-radius points with the sentinel N (sorts to the end).
        group_idx[sqrdists > self.radius ** 2] = N
        group_idx = group_idx.sort(dim=-1)[0][:, :, :self.n_neighbor]
        # Replace any remaining sentinels with the first in-radius index.
        group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, self.n_neighbor])
        mask = group_idx == N
        group_idx[mask] = group_first[mask]
        return group_idx
class FixedRadiusNNGraph(nn.Module):
    '''
    Build NN graph
    '''
    def __init__(self, radius, n_neighbor):
        super(FixedRadiusNNGraph, self).__init__()
        self.radius = radius
        self.n_neighbor = n_neighbor
        self.frnn = FixedRadiusNearNeighbors(radius, n_neighbor)

    def forward(self, pos, centroids, feat=None):
        # Build one DGL graph per batch element with neighbor->centroid
        # edges, then batch them.
        # pos: (B, N, 3); centroids: (B, S) indices; feat: (B, N, C) or None.
        dev = pos.device
        group_idx = self.frnn(pos, centroids)
        B, N, _ = pos.shape
        glist = []
        for i in range(B):
            # Flag centroid nodes so they can be recovered after update_all.
            center = torch.zeros((N)).to(dev)
            center[centroids[i]] = 1
            # Each centroid receives edges from its n_neighbor neighbors.
            src = group_idx[i].contiguous().view(-1)
            dst = centroids[i].view(-1, 1).repeat(1, self.n_neighbor).view(-1)
            # Compact node ids: keep only nodes that appear in some edge.
            unified = torch.cat([src, dst])
            uniq, inv_idx = torch.unique(unified, return_inverse=True)
            src_idx = inv_idx[:src.shape[0]]
            dst_idx = inv_idx[src.shape[0]:]
            # NOTE(review): dgl.DGLGraph((src, dst), readonly=True) is the
            # legacy DGL constructor -- newer DGL uses dgl.graph(); confirm
            # against the pinned DGL version.
            g = dgl.DGLGraph((src_idx.cpu(), dst_idx.cpu()), readonly=True)
            g.ndata['pos'] = pos[i][uniq]
            g.ndata['center'] = center[uniq]
            if feat is not None:
                g.ndata['feat'] = feat[i][uniq]
            glist.append(g)
        bg = dgl.batch(glist)
        return bg
class RelativePositionMessage(nn.Module):
    '''
    Message function: each neighbor sends its position relative to the
    receiving center, concatenated with its own feature when present.
    '''
    def __init__(self, n_neighbor):
        super(RelativePositionMessage, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        # Relative offset from the destination (center) to the source.
        rel_pos = edges.src['pos'] - edges.dst['pos']
        if 'feat' not in edges.src:
            return {'agg_feat': rel_pos}
        return {'agg_feat': torch.cat([rel_pos, edges.src['feat']], 1)}
class PointNetConv(nn.Module):
    '''
    Feature aggregation
    '''
    def __init__(self, sizes, batch_size):
        super(PointNetConv, self).__init__()
        self.batch_size = batch_size
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        # 1x1 Conv2d stack = a shared MLP applied to every
        # (center, neighbor) feature vector.
        for i in range(1, len(sizes)):
            self.conv.append(nn.Conv2d(sizes[i-1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))

    def forward(self, nodes):
        # DGL reduce function. nodes.mailbox['agg_feat'] is
        # (num_nodes, n_neighbor, C); reshape so the shared MLP sees
        # (batch, C, nodes_per_graph, n_neighbor).
        shape = nodes.mailbox['agg_feat'].shape
        h = nodes.mailbox['agg_feat'].view(self.batch_size, -1, shape[1], shape[2]).permute(0, 3, 1, 2)
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        # Max over the neighbor axis, then flatten back to per-node rows.
        h = torch.max(h, 3)[0]
        feat_dim = h.shape[1]
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {'new_feat': h}

    def group_all(self, pos, feat):
        '''
        Feature aggregation and pooling for the non-sampling (group-all)
        layer: runs the shared MLP on every point, then max-pools over all
        points to one global feature per batch element.
        '''
        if feat is not None:
            h = torch.cat([pos, feat], 2)
        else:
            h = pos
        shape = h.shape
        # (B, N, C) -> (B, C, N, 1) so the Conv2d stack applies per point.
        h = h.permute(0, 2, 1).view(shape[0], shape[2], shape[1], 1)
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        # Global max over all points -> (B, C_out).
        h = torch.max(h[:, :, :, 0], 2)[0]
        return h
class SAModule(nn.Module):
    """
    The Set Abstraction Layer (PointNet++): sample centroids with FPS,
    group neighbors within a radius, and aggregate features with a shared
    MLP + max-pool. With group_all=True it pools all points into one
    global feature instead.
    """
    def __init__(self, npoints, batch_size, radius, mlp_sizes, n_neighbor=64,
                 group_all=False):
        super(SAModule, self).__init__()
        self.group_all = group_all
        if not group_all:
            # Sampling and grouping are only needed when not pooling all.
            self.fps = FarthestPointSampler(npoints)
            self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor)
        self.message = RelativePositionMessage(n_neighbor)
        self.conv = PointNetConv(mlp_sizes, batch_size)
        self.batch_size = batch_size

    def forward(self, pos, feat):
        # pos: (B, N, 3); feat: (B, N, C) or None.
        if self.group_all:
            # Returns a single (B, C_out) global feature.
            return self.conv.group_all(pos, feat)
        centroids = self.fps(pos)
        g = self.frnn_graph(pos, centroids, feat)
        g.update_all(self.message, self.conv)
        # Keep only centroid nodes and restore the batch dimension.
        mask = g.ndata['center'] == 1
        pos_dim = g.ndata['pos'].shape[-1]
        feat_dim = g.ndata['new_feat'].shape[-1]
        pos_res = g.ndata['pos'][mask].view(self.batch_size, -1, pos_dim)
        feat_res = g.ndata['new_feat'][mask].view(self.batch_size, -1, feat_dim)
        return pos_res, feat_res
class SAMSGModule(nn.Module):
    """
    The Set Abstraction Multi-Scale Grouping Layer: one FPS centroid set
    shared by several ball-query scales; per-scale features are computed
    independently and concatenated along the channel axis.
    """
    def __init__(self, npoints, batch_size, radius_list, n_neighbor_list, mlp_sizes_list):
        super(SAMSGModule, self).__init__()
        self.batch_size = batch_size
        self.group_size = len(radius_list)  # number of scales
        self.fps = FarthestPointSampler(npoints)
        self.frnn_graph_list = nn.ModuleList()
        self.message_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        # One graph builder / message / conv stack per scale.
        for i in range(self.group_size):
            self.frnn_graph_list.append(FixedRadiusNNGraph(radius_list[i],
                                                           n_neighbor_list[i]))
            self.message_list.append(RelativePositionMessage(n_neighbor_list[i]))
            self.conv_list.append(PointNetConv(mlp_sizes_list[i], batch_size))

    def forward(self, pos, feat):
        # Centroids are sampled once and reused for every scale.
        centroids = self.fps(pos)
        feat_res_list = []
        for i in range(self.group_size):
            g = self.frnn_graph_list[i](pos, centroids, feat)
            g.update_all(self.message_list[i], self.conv_list[i])
            mask = g.ndata['center'] == 1
            pos_dim = g.ndata['pos'].shape[-1]
            feat_dim = g.ndata['new_feat'].shape[-1]
            if i == 0:
                # Centroid positions are identical across scales, so they
                # only need to be extracted once.
                pos_res = g.ndata['pos'][mask].view(self.batch_size, -1, pos_dim)
            feat_res = g.ndata['new_feat'][mask].view(self.batch_size, -1, feat_dim)
            feat_res_list.append(feat_res)
        # Concatenate per-scale features along the channel dimension.
        feat_res = torch.cat(feat_res_list, 2)
        return pos_res, feat_res
class PointNet2SSGCls(nn.Module):
    """PointNet++ classifier with single-scale grouping (SSG).

    Three set-abstraction stages (512 -> 128 -> group-all) followed by a
    three-layer fully-connected head producing class logits.
    """
    def __init__(self, output_classes, batch_size, input_dims=3, dropout_prob=0.4):
        super(PointNet2SSGCls, self).__init__()
        self.input_dims = input_dims
        # Stage 1: 512 centroids, radius 0.2.
        self.sa_module1 = SAModule(512, batch_size, 0.2, [input_dims, 64, 64, 128])
        # Stage 2: 128 centroids, radius 0.4; +3 because the message
        # concatenates the relative xyz onto the 128-d stage-1 features.
        self.sa_module2 = SAModule(128, batch_size, 0.4, [128 + 3, 128, 128, 256])
        # Stage 3: pool all remaining points into one 1024-d global feature.
        self.sa_module3 = SAModule(None, batch_size, None, [256 + 3, 256, 512, 1024],
                                   group_all=True)
        self.mlp1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(dropout_prob)
        self.mlp2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(dropout_prob)
        self.mlp_out = nn.Linear(256, output_classes)

    def forward(self, x):
        # x: (B, N, 3) or (B, N, 3 + C); split coordinates from features.
        if x.shape[-1] > 3:
            pos = x[:, :, :3]
            feat = x[:, :, 3:]
        else:
            pos = x
            feat = None
        pos, feat = self.sa_module1(pos, feat)
        pos, feat = self.sa_module2(pos, feat)
        h = self.sa_module3(pos, feat)
        # FC head: 1024 -> 512 -> 256 -> logits, with BN + dropout.
        h = self.mlp1(h)
        h = self.bn1(h)
        h = F.relu(h)
        h = self.drop1(h)
        h = self.mlp2(h)
        h = self.bn2(h)
        h = F.relu(h)
        h = self.drop2(h)
        out = self.mlp_out(h)
        return out
class PointNet2MSGCls(nn.Module):
    """PointNet++ classifier with multi-scale grouping (MSG).

    Two multi-scale set-abstraction stages followed by a global group-all
    stage and a fully-connected classification head.
    """
    def __init__(self, output_classes, batch_size, input_dims=3, dropout_prob=0.4):
        super(PointNet2MSGCls, self).__init__()
        self.input_dims = input_dims
        # Stage 1: 512 centroids, three radii with their own neighbor
        # counts and per-scale MLPs.
        self.sa_msg_module1 = SAMSGModule(512, batch_size, [0.1, 0.2, 0.4], [16, 32, 128],
                                          [[input_dims, 32, 32, 64], [input_dims, 64, 64, 128],
                                           [input_dims, 64, 96, 128]])
        # Stage 2: 128 centroids; 320 = 64 + 128 + 128 concatenated from
        # stage 1, +3 for the relative xyz added by the message function.
        self.sa_msg_module2 = SAMSGModule(128, batch_size, [0.2, 0.4, 0.8], [32, 64, 128],
                                          [[320 + 3, 64, 64, 128], [320 + 3, 128, 128, 256],
                                           [320 + 3, 128, 128, 256]])
        # Final stage: pool all points into a single 1024-d global feature
        # (640 = 128 + 256 + 256 from stage 2).
        self.sa_module3 = SAModule(None, batch_size, None, [640 + 3, 256, 512, 1024],
                                   group_all=True)
        self.mlp1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(dropout_prob)
        self.mlp2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(dropout_prob)
        self.mlp_out = nn.Linear(256, output_classes)

    def forward(self, x):
        # x: (B, N, 3) or (B, N, 3 + C); split coordinates from features.
        if x.shape[-1] > 3:
            pos = x[:, :, :3]
            feat = x[:, :, 3:]
        else:
            pos = x
            feat = None
        pos, feat = self.sa_msg_module1(pos, feat)
        pos, feat = self.sa_msg_module2(pos, feat)
        h = self.sa_module3(pos, feat)
        # FC head: 1024 -> 512 -> 256 -> logits, with BN + dropout.
        h = self.mlp1(h)
        h = self.bn1(h)
        h = F.relu(h)
        h = self.drop1(h)
        h = self.mlp2(h)
        h = self.bn2(h)
        h = F.relu(h)
        h = self.drop2(h)
        out = self.mlp_out(h)
        return out
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
class PointNetCls(nn.Module):
    """Vanilla PointNet classifier.

    Shared per-point MLPs (1x1 Conv1d) followed by a global max-pool and a
    fully-connected head, with optional input/feature alignment T-Nets as
    in the PointNet paper.
    """
    def __init__(self, output_classes, input_dims=3, conv1_dim=64,
                 dropout_prob=0.5, use_transform=True):
        super(PointNetCls, self).__init__()
        self.input_dims = input_dims
        # First shared-MLP stack: input_dims -> conv1_dim x3.
        self.conv1 = nn.ModuleList()
        self.conv1.append(nn.Conv1d(input_dims, conv1_dim, 1))
        self.conv1.append(nn.Conv1d(conv1_dim, conv1_dim, 1))
        self.conv1.append(nn.Conv1d(conv1_dim, conv1_dim, 1))
        self.bn1 = nn.ModuleList()
        self.bn1.append(nn.BatchNorm1d(conv1_dim))
        self.bn1.append(nn.BatchNorm1d(conv1_dim))
        self.bn1.append(nn.BatchNorm1d(conv1_dim))
        # Second stack lifts features to conv1_dim * 16 (1024 by default).
        self.conv2 = nn.ModuleList()
        self.conv2.append(nn.Conv1d(conv1_dim, conv1_dim * 2, 1))
        self.conv2.append(nn.Conv1d(conv1_dim * 2, conv1_dim * 16, 1))
        self.bn2 = nn.ModuleList()
        self.bn2.append(nn.BatchNorm1d(conv1_dim * 2))
        self.bn2.append(nn.BatchNorm1d(conv1_dim * 16))
        # Global max-pool over points.
        # NOTE(review): kernel size is conv1_dim * 16, which equals the
        # point count only when N == 1024 -- confirm against callers.
        self.maxpool = nn.MaxPool1d(conv1_dim * 16)
        self.pool_feat_len = conv1_dim * 16
        # Classification head: 1024 -> 512 -> 256 -> output_classes.
        self.mlp3 = nn.ModuleList()
        self.mlp3.append(nn.Linear(conv1_dim * 16, conv1_dim * 8))
        self.mlp3.append(nn.Linear(conv1_dim * 8, conv1_dim * 4))
        self.bn3 = nn.ModuleList()
        self.bn3.append(nn.BatchNorm1d(conv1_dim * 8))
        self.bn3.append(nn.BatchNorm1d(conv1_dim * 4))
        # NOTE(review): dropout rate is fixed at 0.3 here; the
        # dropout_prob parameter is not used by this module.
        self.dropout = nn.Dropout(0.3)
        self.mlp_out = nn.Linear(conv1_dim * 4, output_classes)
        self.use_transform = use_transform
        if use_transform:
            # T-Nets aligning the raw input and the conv1_dim feature space.
            self.transform1 = TransformNet(input_dims)
            self.trans_bn1 = nn.BatchNorm1d(input_dims)
            self.transform2 = TransformNet(conv1_dim)
            self.trans_bn2 = nn.BatchNorm1d(conv1_dim)

    def forward(self, x):
        # x: (B, N, input_dims); returns (B, output_classes) logits.
        batch_size = x.shape[0]
        h = x.permute(0, 2, 1)  # -> (B, C, N) for Conv1d
        if self.use_transform:
            # Apply the learned input alignment matrix.
            trans = self.transform1(h)
            h = h.transpose(2, 1)
            h = torch.bmm(h, trans)
            h = h.transpose(2, 1)
            h = F.relu(self.trans_bn1(h))
        for conv, bn in zip(self.conv1, self.bn1):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        if self.use_transform:
            # Align the intermediate per-point feature space.
            trans = self.transform2(h)
            h = h.transpose(2, 1)
            h = torch.bmm(h, trans)
            h = h.transpose(2, 1)
            h = F.relu(self.trans_bn2(h))
        for conv, bn in zip(self.conv2, self.bn2):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        # Global feature: max over points, flattened to (B, pool_feat_len).
        h = self.maxpool(h).view(-1, self.pool_feat_len)
        for mlp, bn in zip(self.mlp3, self.bn3):
            h = mlp(h)
            h = bn(h)
            h = F.relu(h)
            h = self.dropout(h)
        out = self.mlp_out(h)
        return out
class TransformNet(nn.Module):
    """T-Net: predict an (input_dims x input_dims) alignment matrix.

    The network regresses an offset from the identity matrix, so the
    predicted transform starts near identity (as in the PointNet paper).
    """
    def __init__(self, input_dims=3, conv1_dim=64):
        super(TransformNet, self).__init__()
        # Shared per-point MLP: input_dims -> 64 -> 128 -> 1024 (defaults).
        self.conv = nn.ModuleList()
        self.conv.append(nn.Conv1d(input_dims, conv1_dim, 1))
        self.conv.append(nn.Conv1d(conv1_dim, conv1_dim * 2, 1))
        self.conv.append(nn.Conv1d(conv1_dim * 2, conv1_dim * 16, 1))
        self.bn = nn.ModuleList()
        self.bn.append(nn.BatchNorm1d(conv1_dim))
        self.bn.append(nn.BatchNorm1d(conv1_dim * 2))
        self.bn.append(nn.BatchNorm1d(conv1_dim * 16))
        # Global max-pool over points (kernel = conv1_dim * 16).
        self.maxpool = nn.MaxPool1d(conv1_dim * 16)
        self.pool_feat_len = conv1_dim * 16
        # Regression head: 1024 -> 512 -> 256 -> input_dims^2.
        self.mlp2 = nn.ModuleList()
        self.mlp2.append(nn.Linear(conv1_dim * 16, conv1_dim * 8))
        self.mlp2.append(nn.Linear(conv1_dim * 8, conv1_dim * 4))
        self.bn2 = nn.ModuleList()
        self.bn2.append(nn.BatchNorm1d(conv1_dim * 8))
        self.bn2.append(nn.BatchNorm1d(conv1_dim * 4))
        self.input_dims = input_dims
        self.mlp_out = nn.Linear(conv1_dim * 4, input_dims * input_dims)

    def forward(self, h):
        # h: (B, input_dims, N) -> (B, input_dims, input_dims).
        batch_size = h.shape[0]
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        h = self.maxpool(h).view(-1, self.pool_feat_len)
        for mlp, bn in zip(self.mlp2, self.bn2):
            h = mlp(h)
            h = bn(h)
            h = F.relu(h)
        out = self.mlp_out(h)
        # Add the identity so the transform starts near identity.
        # FIX: replaces the deprecated torch.autograd.Variable and the
        # CPU-then-.cuda() round-trip; torch.eye on out's device/dtype is
        # numerically identical and device-safe.
        iden = torch.eye(self.input_dims, dtype=out.dtype, device=out.device).flatten()
        out = out + iden.view(1, self.input_dims * self.input_dims)
        return out.view(-1, self.input_dims, self.input_dims)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
class PointNetPartSeg(nn.Module):
    """PointNet part-segmentation network.

    Per-point shared MLPs with optional input/feature T-Nets; intermediate
    per-point features, the global max-pooled feature, and (optionally) a
    one-hot category vector are concatenated and decoded into per-point
    part logits.
    """
    def __init__(self, output_classes, input_dims=3, num_points=2048,
                 use_transform=True):
        super(PointNetPartSeg, self).__init__()
        self.input_dims = input_dims
        # First shared-MLP stack: input_dims -> 64 -> 128 -> 128.
        self.conv1 = nn.ModuleList()
        self.conv1.append(nn.Conv1d(input_dims, 64, 1))
        self.conv1.append(nn.Conv1d(64, 128, 1))
        self.conv1.append(nn.Conv1d(128, 128, 1))
        self.bn1 = nn.ModuleList()
        self.bn1.append(nn.BatchNorm1d(64))
        self.bn1.append(nn.BatchNorm1d(128))
        self.bn1.append(nn.BatchNorm1d(128))
        # Second stack: 128 -> 512.
        self.conv2 = nn.ModuleList()
        self.conv2.append(nn.Conv1d(128, 512, 1))
        self.bn2 = nn.ModuleList()
        self.bn2.append(nn.BatchNorm1d(512))
        # Lift to 2048 channels before the global max-pool.
        self.conv_max = nn.Conv1d(512, 2048, 1)
        self.bn_max = nn.BatchNorm1d(2048)
        self.maxpool = nn.MaxPool1d(num_points)
        self.pool_feat_len = 2048
        # Decoder input = global 2048 + skip features (64 + 128*3 + 512)
        # + 16-d one-hot category vector.
        self.conv3 = nn.ModuleList()
        self.conv3.append(nn.Conv1d(2048 + 64 + 128*3 + 512 + 16, 256, 1))
        self.conv3.append(nn.Conv1d(256, 256, 1))
        self.conv3.append(nn.Conv1d(256, 128, 1))
        self.bn3 = nn.ModuleList()
        self.bn3.append(nn.BatchNorm1d(256))
        self.bn3.append(nn.BatchNorm1d(256))
        self.bn3.append(nn.BatchNorm1d(128))
        self.conv_out = nn.Conv1d(128, output_classes, 1)
        self.use_transform = use_transform
        if use_transform:
            # T-Nets aligning the raw input and the 128-d feature space.
            self.transform1 = TransformNet(self.input_dims)
            self.trans_bn1 = nn.BatchNorm1d(self.input_dims)
            self.transform2 = TransformNet(128)
            self.trans_bn2 = nn.BatchNorm1d(128)

    def forward(self, x, cat_vec=None):
        # x: (B, N, input_dims); cat_vec: (B, 16, N) one-hot category or None.
        # Returns per-point logits of shape (B, output_classes, N).
        batch_size = x.shape[0]
        h = x.permute(0, 2, 1)  # -> (B, C, N) for Conv1d
        num_points = h.shape[2]
        if self.use_transform:
            # Align the raw input with the learned transform.
            trans = self.transform1(h)
            h = h.transpose(2, 1)
            h = torch.bmm(h, trans)
            h = h.transpose(2, 1)
            h = F.relu(self.trans_bn1(h))
        # Collect intermediate per-point features for the decoder skip path.
        mid_feat = []
        for conv, bn in zip(self.conv1, self.bn1):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
            mid_feat.append(h)
        if self.use_transform:
            # Align the 128-d feature space.
            trans = self.transform2(h)
            h = h.transpose(2, 1)
            h = torch.bmm(h, trans)
            h = h.transpose(2, 1)
            h = F.relu(self.trans_bn2(h))
        # Appended unconditionally: the decoder's input width (128*3)
        # counts this 128-d feature whether or not it was transformed.
        mid_feat.append(h)
        for conv, bn in zip(self.conv2, self.bn2):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
            mid_feat.append(h)
        # Global feature: max over points, broadcast back to every point.
        h = self.conv_max(h)
        h = self.bn_max(h)
        h = self.maxpool(h).view(batch_size, -1, 1).repeat(1, 1, num_points)
        mid_feat.append(h)
        if cat_vec is not None:
            mid_feat.append(cat_vec)
        # Decode the concatenated features into per-point logits.
        h = torch.cat(mid_feat, 1)
        for conv, bn in zip(self.conv3, self.bn3):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        out = self.conv_out(h)
        return out
class TransformNet(nn.Module):
    """T-Net for the part-segmentation model: predict an
    (input_dims x input_dims) alignment matrix.

    Regresses an offset from the identity matrix so the transform starts
    near identity (as in the PointNet paper).
    """
    def __init__(self, input_dims=3, num_points=2048):
        super(TransformNet, self).__init__()
        # Shared per-point MLP: input_dims -> 64 -> 128 -> 1024.
        self.conv = nn.ModuleList()
        self.conv.append(nn.Conv1d(input_dims, 64, 1))
        self.conv.append(nn.Conv1d(64, 128, 1))
        self.conv.append(nn.Conv1d(128, 1024, 1))
        self.bn = nn.ModuleList()
        self.bn.append(nn.BatchNorm1d(64))
        self.bn.append(nn.BatchNorm1d(128))
        self.bn.append(nn.BatchNorm1d(1024))
        # Global max-pool over the num_points axis.
        self.maxpool = nn.MaxPool1d(num_points)
        self.pool_feat_len = 1024
        # Regression head: 1024 -> 512 -> 256 -> input_dims^2.
        self.mlp2 = nn.ModuleList()
        self.mlp2.append(nn.Linear(1024, 512))
        self.mlp2.append(nn.Linear(512, 256))
        self.bn2 = nn.ModuleList()
        self.bn2.append(nn.BatchNorm1d(512))
        self.bn2.append(nn.BatchNorm1d(256))
        self.input_dims = input_dims
        self.mlp_out = nn.Linear(256, input_dims * input_dims)

    def forward(self, h):
        # h: (B, input_dims, num_points) -> (B, input_dims, input_dims).
        batch_size = h.shape[0]
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        h = self.maxpool(h).view(-1, self.pool_feat_len)
        for mlp, bn in zip(self.mlp2, self.bn2):
            h = mlp(h)
            h = bn(h)
            h = F.relu(h)
        out = self.mlp_out(h)
        # Add the identity so the transform starts near identity.
        # FIX: replaces the deprecated torch.autograd.Variable and the
        # CPU-then-.cuda() round-trip; torch.eye on out's device/dtype is
        # numerically identical and device-safe.
        iden = torch.eye(self.input_dims, dtype=out.dtype, device=out.device).flatten()
        out = out + iden.view(1, self.input_dims * self.input_dims)
        return out.view(-1, self.input_dims, self.input_dims)
class PartSegLoss(nn.Module):
    """Cross-entropy loss over per-point part logits.

    Expects logits of shape (B, C, N) and flattened integer labels of
    shape (B * N,).
    """
    def __init__(self, eps=0.2):
        super(PartSegLoss, self).__init__()
        self.eps = eps  # kept for interface compatibility; not used below
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits, y):
        num_classes = logits.shape[1]
        # (B, C, N) -> (B * N, C): one classification row per point.
        flat_logits = logits.permute(0, 2, 1).contiguous().view(-1, num_classes)
        return self.loss(flat_logits, y)
'''
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
'''
import numpy as np
def normalize_data(batch_data):
    """ Normalize the batch data, use coordinates of the block centered at origin,
        Input:
            BxNxC array
        Output:
            BxNxC float64 array
    """
    # Vectorized over the batch: numerically equivalent to the original
    # per-cloud Python loop, without Python-level iteration.
    centered = batch_data - np.mean(batch_data, axis=1, keepdims=True)
    # Farthest-point norm per cloud, shaped (B, 1, 1) for broadcasting.
    scale = np.max(np.sqrt(np.sum(centered ** 2, axis=2)), axis=1)[:, None, None]
    # astype keeps the original float64 output dtype (np.zeros default).
    return (centered / scale).astype(np.float64)
def shuffle_data(data, labels):
    """ Shuffle data and labels in unison.
        Input:
          data: B,N,... numpy array
          label: B,... numpy array
        Return:
          shuffled data, label and shuffle indices
    """
    perm = np.random.permutation(len(labels))
    return data[perm, ...], labels[perm], perm
def shuffle_points(batch_data):
    """ Shuffle the point order within each cloud -- changes FPS behavior.
        One shared permutation is used for the entire batch.
        Input:
          BxNxC array
        Output:
          BxNxC array
    """
    perm = np.random.permutation(batch_data.shape[1])
    return batch_data[:, perm, :]
def rotate_point_cloud(batch_data):
    """ Randomly rotate each point cloud about the up (Y) axis to augment
        the dataset; rotation is sampled per shape.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k, shape_pc in enumerate(batch_data):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        rot = np.array([[c, 0, s],
                        [0, 1, 0],
                        [-s, 0, c]])
        rotated[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rot)
    return rotated
def rotate_point_cloud_z(batch_data):
    """ Randomly rotate each point cloud about the Z axis to augment the
        dataset; rotation is sampled per shape.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k, shape_pc in enumerate(batch_data):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        rot = np.array([[c, s, 0],
                        [-s, c, 0],
                        [0, 0, 1]])
        rotated[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rot)
    return rotated
def rotate_point_cloud_with_normal(batch_xyz_normal):
    ''' Randomly rotate XYZ + normal point clouds about the Y axis,
        in place (the input array is modified and returned).
        Input:
          batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
        Output:
          B,N,6, rotated XYZ, normal point cloud
    '''
    for k in range(batch_xyz_normal.shape[0]):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        rot = np.array([[c, 0, s],
                        [0, 1, 0],
                        [-s, 0, c]])
        xyz = batch_xyz_normal[k, :, 0:3]
        nrm = batch_xyz_normal[k, :, 3:6]
        # Points and normals are rotated by the same matrix.
        batch_xyz_normal[k, :, 0:3] = np.dot(xyz.reshape((-1, 3)), rot)
        batch_xyz_normal[k, :, 3:6] = np.dot(nrm.reshape((-1, 3)), rot)
    return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """ Randomly perturb each cloud (points and normals) with a small
        rotation about all three axes.
        Input:
          BxNx6 array, original batch of point clouds and point normals
        Return:
          BxNx6 array, rotated batch of point clouds with normals
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        # Small clipped Gaussian angles, one per axis.
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(ax), -np.sin(ax)],
                       [0, np.sin(ax), np.cos(ax)]])
        Ry = np.array([[np.cos(ay), 0, np.sin(ay)],
                       [0, 1, 0],
                       [-np.sin(ay), 0, np.cos(ay)]])
        Rz = np.array([[np.cos(az), -np.sin(az), 0],
                       [np.sin(az), np.cos(az), 0],
                       [0, 0, 1]])
        R = np.dot(Rz, np.dot(Ry, Rx))
        rotated[k, :, 0:3] = np.dot(batch_data[k, :, 0:3].reshape((-1, 3)), R)
        rotated[k, :, 3:6] = np.dot(batch_data[k, :, 3:6].reshape((-1, 3)), R)
    return rotated
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """ Rotate every cloud in the batch about the up (Y) axis by a fixed angle.
        Input:
          BxNx3 array, original batch of point clouds
          rotation_angle: scalar angle in radians
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    # The matrix depends only on the (fixed) angle, so build it once.
    c = np.cos(rotation_angle)
    s = np.sin(rotation_angle)
    rot = np.array([[c, 0, s],
                    [0, 1, 0],
                    [-s, 0, c]])
    for k in range(batch_data.shape[0]):
        rotated[k, :, 0:3] = np.dot(batch_data[k, :, 0:3].reshape((-1, 3)), rot)
    return rotated
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
    """ Rotate clouds (with normals) about the up (Y) axis by a fixed angle.
        Input:
          BxNx6 array, original batch of point clouds with normal
          rotation_angle: scalar angle in radians
        Return:
          BxNx6 array, rotated batch of point clouds with normal
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    # Fixed angle -> one rotation matrix for the whole batch.
    c = np.cos(rotation_angle)
    s = np.sin(rotation_angle)
    rot = np.array([[c, 0, s],
                    [0, 1, 0],
                    [-s, 0, c]])
    for k in range(batch_data.shape[0]):
        # Points and normals get the same rotation.
        rotated[k, :, 0:3] = np.dot(batch_data[k, :, 0:3].reshape((-1, 3)), rot)
        rotated[k, :, 3:6] = np.dot(batch_data[k, :, 3:6].reshape((-1, 3)), rot)
    return rotated
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """ Randomly perturb each cloud with a small rotation about all three axes.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        # Small clipped Gaussian angles, one per axis.
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(ax), -np.sin(ax)],
                       [0, np.sin(ax), np.cos(ax)]])
        Ry = np.array([[np.cos(ay), 0, np.sin(ay)],
                       [0, 1, 0],
                       [-np.sin(ay), 0, np.cos(ay)]])
        Rz = np.array([[np.cos(az), -np.sin(az), 0],
                       [np.sin(az), np.cos(az), 0],
                       [0, 0, 1]])
        R = np.dot(Rz, np.dot(Ry, Rx))
        rotated[k, ...] = np.dot(batch_data[k, ...].reshape((-1, 3)), R)
    return rotated
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """ Randomly jitter points. jittering is per point.
    Input:
      BxNx3 array, original batch of point clouds
    Return:
      BxNx3 array, jittered batch of point clouds
    """
    B, N, C = batch_data.shape
    assert(clip > 0)
    # Independent Gaussian noise per coordinate, truncated to [-clip, clip];
    # the input batch is left untouched.
    noise = np.clip(sigma * np.random.randn(B, N, C), -clip, clip)
    return noise + batch_data
def shift_point_cloud(batch_data, shift_range=0.1):
    """ Randomly shift point cloud. Shift is per point cloud.
    Input:
      BxNx3 array, original batch of point clouds
    Return:
      BxNx3 array, shifted batch of point clouds
    """
    B, N, C = batch_data.shape
    # One random XYZ offset per cloud; applied in place to every point of
    # that cloud via broadcasting, then the same array is returned.
    shifts = np.random.uniform(-shift_range, shift_range, (B,3))
    batch_data += shifts[:, None, :]
    return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """ Randomly scale the point cloud. Scale is per point cloud.
    Input:
      BxNx3 array, original batch of point clouds
    Return:
      BxNx3 array, scaled batch of point clouds
    """
    B, N, C = batch_data.shape
    # One random scale factor per cloud; applied in place to all points and
    # coordinates of that cloud via broadcasting.
    scales = np.random.uniform(scale_low, scale_high, B)
    batch_data *= scales[:, None, None]
    return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    ''' Randomly replace a random fraction of each cloud's points with that
    cloud's first point (simulates occlusion), mutating batch_pc in place.

    batch_pc: BxNx3 array
    max_dropout_ratio: upper bound on the per-cloud dropout probability
    Returns the (mutated) batch_pc.
    '''
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
        drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
        if len(drop_idx)>0:
            # BUGFIX: the original re-drew dropout_ratio here; the value was
            # never used and only consumed an extra draw from the RNG stream.
            batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
    return batch_pc
import torch
torch.backends.cudnn.enabled = False
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import dgl
from dgl.data.utils import download, get_download_dir
from functools import partial
import tqdm
import urllib
import os
import argparse
# from dataset import ModelNet
import provider
from ModelNetDataLoader import ModelNetDataLoader
from pointnet_cls import PointNetCls
from pointnet2 import PointNet2SSGCls, PointNet2MSGCls
# Command-line configuration for the ModelNet40 classification training script.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='pointnet')
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=200)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=32)
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size

# Download and unpack the ModelNet40 (normal-resampled) archive into the DGL
# download directory unless --dataset-path points at an existing copy.
data_filename = 'modelnet40_normal_resampled.zip'
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(get_download_dir(), 'modelnet40_normal_resampled')

if not os.path.exists(local_path):
    download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip', download_path)
    from zipfile import ZipFile

    with ZipFile(download_path) as z:
        z.extractall(path=get_download_dir())

# DataLoader factory with the worker/batch settings baked in; shuffles and
# drops the last incomplete batch.
CustomDataLoader = partial(
        DataLoader,
        num_workers=num_workers,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
def train(net, opt, scheduler, train_loader, dev):
    """Run one classification training epoch, logging running loss/accuracy."""
    net.train()
    criterion = nn.CrossEntropyLoss()
    running_loss = 0
    running_correct = 0
    batches_seen = 0
    examples_seen = 0
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label in tq:
            # Augment on CPU with numpy, in the same order as the reference:
            # dropout -> scale -> jitter -> shift (xyz channels only).
            pts = data.data.numpy()
            pts = provider.random_point_dropout(pts)
            pts[:, :, 0:3] = provider.random_scale_point_cloud(pts[:, :, 0:3])
            pts[:, :, 0:3] = provider.jitter_point_cloud(pts[:, :, 0:3])
            pts[:, :, 0:3] = provider.shift_point_cloud(pts[:, :, 0:3])
            inputs = torch.tensor(pts).to(dev)
            num = label[:, 0].shape[0]
            targets = label[:, 0].to(dev).squeeze().long()

            opt.zero_grad()
            logits = net(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            opt.step()

            preds = logits.max(1)[1]
            batches_seen += 1
            examples_seen += num
            running_loss += loss.item()
            running_correct += (preds == targets).sum().item()
            tq.set_postfix({
                'AvgLoss': '%.5f' % (running_loss / batches_seen),
                'AvgAcc': '%.5f' % (running_correct / examples_seen)})
    # One LR-scheduler step per epoch.
    scheduler.step()
def evaluate(net, test_loader, dev):
    """Compute classification accuracy over test_loader; returns a float in [0, 1]."""
    net.eval()
    correct_so_far = 0
    seen = 0
    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as tq:
            for data, label in tq:
                label = label[:, 0]
                batch_n = label.shape[0]
                inputs = data.to(dev)
                targets = label.to(dev).squeeze().long()
                preds = net(inputs).max(1)[1]
                correct_so_far += (preds == targets).sum().item()
                seen += batch_n
                tq.set_postfix({
                    'AvgAcc': '%.5f' % (correct_so_far / seen)})
    return correct_so_far / seen
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Select the classifier architecture (40 ModelNet classes, xyz + normal input).
if args.model == 'pointnet':
    net = PointNetCls(40, input_dims=6)
elif args.model == 'pointnet2_ssg':
    net = PointNet2SSGCls(40, batch_size, input_dims=6)
elif args.model == 'pointnet2_msg':
    net = PointNet2MSGCls(40, batch_size, input_dims=6)

net = net.to(dev)
# Optionally resume from a checkpoint.
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)

scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)

# 1024 points sampled per shape for both splits.
train_dataset = ModelNetDataLoader(local_path, 1024, split='train')
test_dataset = ModelNetDataLoader(local_path, 1024, split='test')
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)

# Train, evaluating every epoch and checkpointing on best test accuracy.
best_test_acc = 0

for epoch in range(args.num_epochs):
    train(net, opt, scheduler, train_loader, dev)
    if (epoch + 1) % 1 == 0:
        print('Epoch #%d Testing' % epoch)
        test_acc = evaluate(net, test_loader, dev)
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            if args.save_model_path:
                torch.save(net.state_dict(), args.save_model_path)
        print('Current test acc: %.5f (best: %.5f)' % (
            test_acc, best_test_acc))
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import dgl
from dgl.data.utils import download, get_download_dir
from functools import partial
import tqdm
import urllib
import os
import argparse
from ShapeNet import ShapeNet
from pointnet_partseg import PointNetPartSeg, PartSegLoss
# Command-line configuration for the ShapeNet part-segmentation training script.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=250)
parser.add_argument('--num-workers', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=16)
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size
def collate(samples):
    """Merge (graph, category) pairs into one batched DGL graph plus the category list."""
    graphs = [g for g, _ in samples]
    cats = [c for _, c in samples]
    return dgl.batch(graphs), cats
# DataLoader factory with the worker/batch settings baked in; shuffles and
# drops the last incomplete batch.
CustomDataLoader = partial(
        DataLoader,
        num_workers=num_workers,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
def train(net, opt, scheduler, train_loader, dev):
    """Run one part-segmentation training epoch; logs running loss and per-point accuracy."""
    # Category names sorted so every call indexes the one-hot rows consistently.
    category_list = sorted(list(shapenet.seg_classes.keys()))
    eye_mat = np.eye(16)
    net.train()
    running_loss = 0
    batches = 0
    correct_points = 0
    points_seen = 0
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label, cat in tq:
            bsz = data.shape[0]
            data = data.to(dev, dtype=torch.float)
            label = label.to(dev, dtype=torch.long).view(-1)
            opt.zero_grad()
            # One-hot object-category feature, tiled to all 2048 points per cloud
            # and permuted to (batch, 16, points) for the network.
            cat_ind = [category_list.index(c) for c in cat]
            one_hot = torch.tensor(eye_mat[cat_ind]).to(dev, dtype=torch.float).repeat(1, 2048)
            one_hot = one_hot.view(bsz, -1, 16).permute(0,2,1)
            logits = net(data, one_hot)
            loss = L(logits, label)
            loss.backward()
            opt.step()

            preds = logits.max(1)[1]
            points_seen += bsz * 2048
            running_loss += loss.item()
            batches += 1
            correct_points += (preds.view(-1) == label).sum().item()
            tq.set_postfix({
                'AvgLoss': '%.5f' % (running_loss / batches),
                'AvgAcc': '%.5f' % (correct_points / points_seen)})
    # One LR-scheduler step per epoch.
    scheduler.step()
def mIoU(preds, label, cat, cat_miou, seg_classes):
for i in range(preds.shape[0]):
shape_iou = 0
n = len(seg_classes[cat[i]])
for cls in seg_classes[cat[i]]:
pred_set = set(np.where(preds[i,:] == cls)[0])
label_set = set(np.where(label[i,:] == cls)[0])
union = len(pred_set.union(label_set))
inter = len(pred_set.intersection(label_set))
if union == 0:
shape_iou += 1
else:
shape_iou += inter / union
shape_iou /= n
cat_miou[cat[i]][0] += shape_iou
cat_miou[cat[i]][1] += 1
return cat_miou
def evaluate(net, test_loader, dev, per_cat_verbose=False):
    """Evaluate part-segmentation quality on test_loader.

    Returns (instance-averaged mIoU, category-averaged mIoU). Categories not
    seen in the test set are skipped (reported as 1.0 in verbose output).
    """
    category_list = sorted(list(shapenet.seg_classes.keys()))
    eye_mat = np.eye(16)
    net.eval()
    # Per-category running [iou_sum, shape_count], filled by mIoU().
    cat_miou = {}
    for k in shapenet.seg_classes.keys():
        cat_miou[k] = [0, 0]
    miou = 0
    count = 0
    per_cat_miou = 0
    per_cat_count = 0
    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as tq:
            for data, label, cat in tq:
                num_examples = data.shape[0]
                data = data.to(dev, dtype=torch.float)
                label = label.to(dev, dtype=torch.long)
                # One-hot category feature, tiled per point (2048 points/cloud).
                cat_ind = [category_list.index(c) for c in cat]
                cat_tensor = torch.tensor(eye_mat[cat_ind]).to(dev, dtype=torch.float).repeat(1, 2048)
                cat_tensor = cat_tensor.view(num_examples, -1, 16).permute(0,2,1)
                logits = net(data, cat_tensor)
                _, preds = logits.max(1)
                cat_miou = mIoU(preds.cpu().numpy(),
                                label.view(num_examples, -1).cpu().numpy(),
                                cat, cat_miou, shapenet.seg_classes)
                # BUGFIX: recompute the totals from scratch every batch. The
                # original kept adding the *cumulative* cat_miou sums onto the
                # previous totals each batch, inflating the returned mIoU
                # whenever the loader had more than one batch.
                miou = 0
                count = 0
                per_cat_miou = 0
                per_cat_count = 0
                for _, v in cat_miou.items():
                    if v[1] > 0:
                        miou += v[0]
                        count += v[1]
                        per_cat_miou += v[0] / v[1]
                        per_cat_count += 1
                tq.set_postfix({
                    'mIoU': '%.5f' % (miou / count),
                    # BUGFIX: this key previously printed the instance mIoU.
                    'per Category mIoU': '%.5f' % (per_cat_miou / per_cat_count)})
    if per_cat_verbose:
        print("Per-Category mIoU:")
        for k, v in cat_miou.items():
            if v[1] > 0:
                print("%s mIoU=%.5f" % (k, v[0] / v[1]))
            else:
                print("%s mIoU=%.5f" % (k, 1))
    return miou / count, per_cat_miou / per_cat_count
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dev = "cpu"

# 50 part classes over ShapeNet, xyz-only input, 2048 points per shape.
net = PointNetPartSeg(50, 3, 2048)

net = net.to(dev)
# Optionally resume from a checkpoint.
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

opt = optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.5)
L = PartSegLoss()

# trainval split for training, test split for evaluation; normals disabled.
shapenet = ShapeNet(2048, normal_channel=False)
train_loader = CustomDataLoader(shapenet.trainval())
test_loader = CustomDataLoader(shapenet.test())

# Train, evaluating every 5 epochs and checkpointing on best instance mIoU.
best_test_miou = 0
best_test_per_cat_miou = 0

for epoch in range(args.num_epochs):
    train(net, opt, scheduler, train_loader, dev)
    if (epoch + 1) % 5 == 0:
        print('Epoch #%d Testing' % epoch)
        test_miou, test_per_cat_miou = evaluate(net, test_loader, dev, (epoch + 1) % 5 ==0)
        if test_miou > best_test_miou:
            best_test_miou = test_miou
            best_test_per_cat_miou = test_per_cat_miou
            if args.save_model_path:
                torch.save(net.state_dict(), args.save_model_path)
        print('Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)' % (
            test_miou, best_test_miou, test_per_cat_miou, best_test_per_cat_miou))
......@@ -40,7 +40,7 @@ class FarthestPointSampler(nn.Module):
B, N, C = pos.shape
pos = pos.reshape(-1, C)
dist = th.zeros((B * N), dtype=pos.dtype, device=device)
start_idx = th.randint(0, N - 1, (B, ), dtype=th.int, device=device)
result = th.zeros((self.npoints * B), dtype=th.int, device=device)
start_idx = th.randint(0, N - 1, (B, ), dtype=th.long, device=device)
result = th.zeros((self.npoints * B), dtype=th.long, device=device)
farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result)
return result.reshape(B, self.npoints)
......@@ -51,8 +51,8 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size
// the last sampled point
int64_t sample_idx = (int64_t)(ret_data[ret_start + i]);
FloatType dist_max = (FloatType)(-1.);
int64_t dist_argmax = 0;
dist_argmax_ht[thread_idx] = 0;
dist_max_ht[thread_idx] = (FloatType)(-1.);
// multi-thread distance calculation
for (auto j = thread_idx; j < point_in_batch; j += THREADS) {
......@@ -67,36 +67,24 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size
dist_data[array_start + j] = one_dist;
}
if (dist_data[array_start + j] > dist_max) {
dist_argmax = j;
dist_max = dist_data[array_start + j];
if (dist_data[array_start + j] > dist_max_ht[thread_idx]) {
dist_argmax_ht[thread_idx] = j;
dist_max_ht[thread_idx] = dist_data[array_start + j];
}
}
dist_max_ht[thread_idx] = dist_max;
dist_argmax_ht[thread_idx] = dist_argmax;
/*
* \brief Parallel Reduction
*
* Suppose the maximum is dist_max_ht[k], where 0 <= k < THREAD.
* After loop at j = 1, the maximum is propagated to [k-1].
* After loop at j = 2, the maximum is propagated to the range [k-3] to [k].
* After loop at j = 4, the maximum is propagated to the range [k-7] to [k].
* After loop at any j < THREADS, we can see [k - 2*j + 1] to [k] are all covered by the maximum.
* The max value of j is at least floor(THREAD / 2), and it is sufficient to cover [0] with the maximum.
*/
for (auto j = 1; j < THREADS; j *= 2) {
__syncthreads();
if ((thread_idx + j) < THREADS && dist_max_ht[thread_idx] < dist_max_ht[thread_idx + j]) {
dist_max_ht[thread_idx] = dist_max_ht[thread_idx + j];
dist_argmax_ht[thread_idx] = dist_argmax_ht[thread_idx + j];
}
}
__syncthreads();
if (thread_idx == 0) {
ret_data[ret_start + i + 1] = (IdType)(dist_argmax_ht[0]);
FloatType best = dist_max_ht[0];
int64_t best_idx = dist_argmax_ht[0];
for (auto j = 1; j < THREADS; j++) {
if (dist_max_ht[j] > best) {
best = dist_max_ht[j];
best_idx = dist_argmax_ht[j];
}
}
ret_data[ret_start + i + 1] = (IdType)(best_idx);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment