Unverified Commit ff345c2e authored by wcyjames, committed by GitHub

[Example] update PointNet and PointNet++ examples for Part Segmentation (#2547)



* [Model] update PointNet example for Part Segmentation

* Fixed issues with pointnet examples

* update the README

* Added image

* Fixed README and tensorboard arguments

* clean

* Add timing

* Update README.md

* Update README.md

Fixed a typo
Co-authored-by: Tong He <hetong007@gmail.com>
parent b1840f49
PointNet and PointNet++ for Point Cloud Classification and Segmentation
====
This is a reproduction of the papers
...@@ -7,13 +7,23 @@ This is a reproduction of the papers
# Performance
## Classification
| Model           | Dataset    | Metric   | Score - PyTorch | Score - DGL | Time(s) - PyTorch | Time(s) - DGL |
|-----------------|------------|----------|-----------------|-------------|-------------------|---------------|
| PointNet        | ModelNet40 | Accuracy | 89.2 (Official) | 89.3        | 181.8             | 95.0          |
| PointNet++(SSG) | ModelNet40 | Accuracy | 92.4            | 93.3        | 182.6             | 133.7         |
| PointNet++(MSG) | ModelNet40 | Accuracy | 92.8            | 93.3        | 383.6             | 240.5         |

## Part Segmentation
| Model           | Dataset  | Metric | Score - PyTorch | Score - DGL | Time(s) - PyTorch | Time(s) - DGL |
|-----------------|----------|--------|-----------------|-------------|-------------------|---------------|
| PointNet        | ShapeNet | mIoU   | 84.3            | 83.6        | 251.6             | 234.0         |
| PointNet++(SSG) | ShapeNet | mIoU   | 84.9            | 84.5        | 361.7             | 240.1         |
| PointNet++(MSG) | ShapeNet | mIoU   | 85.4            | 84.6        | 817.3             | 821.8         |
+ Scores in the PyTorch column are collected from [this repo](https://github.com/yanx27/Pointnet_Pointnet2_pytorch).
+ Time(s) is the average training time per epoch, measured on an EC2 g4dn.4xlarge instance with a Tesla T4 GPU.
# How to Run
For point cloud classification, run with
...@@ -27,3 +37,13 @@ For point cloud part-segmentation, run with
```python
python train_partseg.py
```
## To Visualize Part Segmentation in TensorBoard
![Screenshot](vis.png)
First install TensorBoard with ``pip install tensorboard``, then run
```python
python train_partseg.py --tensorboard
```
To display the results in TensorBoard, run
``tensorboard --logdir=runs``
...@@ -39,7 +39,7 @@ def index_points(points, idx):
class FixedRadiusNearNeighbors(nn.Module):
    '''
    Ball Query - Find the neighbors within a fixed radius
    '''
    def __init__(self, radius, n_neighbor):
        super(FixedRadiusNearNeighbors, self).__init__()
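For reference, the ball-query semantics can be sketched as follows. This is a minimal illustration in the style of the reference PyTorch implementation, not this module's actual code; `ball_query` and its argument names are hypothetical:

```python
import torch

def ball_query(pos, center_pos, radius, n_neighbor):
    # pos: [B, N, 3] all points; center_pos: [B, S, 3] query centers.
    # Returns [B, S, n_neighbor] neighbor indices; if fewer than n_neighbor
    # points fall inside the ball, the remaining slots repeat the first hit.
    B, N, _ = pos.shape
    S = center_pos.shape[1]
    dist = torch.cdist(center_pos, pos)                        # [B, S, N]
    idx = torch.arange(N, device=pos.device).repeat(B, S, 1)   # [B, S, N]
    idx = idx.masked_fill(dist > radius, N)                    # mark misses as N
    idx = idx.sort(dim=-1)[0][:, :, :n_neighbor]               # in-radius points first
    first = idx[:, :, :1].expand(-1, -1, n_neighbor)
    return torch.where(idx == N, first, idx)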
...@@ -129,32 +129,34 @@ class PointNetConv(nn.Module):
    def forward(self, nodes):
        shape = nodes.mailbox['agg_feat'].shape
        h = nodes.mailbox['agg_feat'].view(self.batch_size, -1, shape[1], shape[2]).permute(0, 3, 2, 1)
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        h = torch.max(h, 2)[0]
        feat_dim = h.shape[1]
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {'new_feat': h}
    def group_all(self, pos, feat):
        '''
        Feature aggregation and pooling for the non-sampling layer
        '''
        if feat is not None:
            h = torch.cat([pos, feat], 2)
        else:
            h = pos
        B, N, D = h.shape
        _, _, C = pos.shape
        new_pos = torch.zeros(B, 1, C)
        h = h.permute(0, 2, 1).view(B, -1, N, 1)
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        h = torch.max(h[:, :, :, 0], 2)[0]  # [B, D]
        return new_pos, h
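Note that `group_all` now returns a `(new_pos, h)` pair, with `new_pos` a dummy centroid at the origin, so the final set-abstraction layer exposes the same interface as the sampling layers; callers that only need the global feature unpack it as `_, h = self.sa_module3(pos, feat)` (see the classifier changes below).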
class SAModule(nn.Module):
    """
...@@ -178,6 +180,7 @@ class SAModule(nn.Module):
        centroids = self.fps(pos)
        g = self.frnn_graph(pos, centroids, feat)
        g.update_all(self.message, self.conv)

        mask = g.ndata['center'] == 1
        pos_dim = g.ndata['pos'].shape[-1]
        feat_dim = g.ndata['new_feat'].shape[-1]
...@@ -207,6 +210,7 @@ class SAMSGModule(nn.Module):
    def forward(self, pos, feat):
        centroids = self.fps(pos)
        feat_res_list = []

        for i in range(self.group_size):
            g = self.frnn_graph_list[i](pos, centroids, feat)
            g.update_all(self.message_list[i], self.conv_list[i])
...@@ -217,9 +221,62 @@ class SAMSGModule(nn.Module):
            pos_res = g.ndata['pos'][mask].view(self.batch_size, -1, pos_dim)
            feat_res = g.ndata['new_feat'][mask].view(self.batch_size, -1, feat_dim)
            feat_res_list.append(feat_res)

        feat_res = torch.cat(feat_res_list, 2)
        return pos_res, feat_res
class PointNet2FP(nn.Module):
    """
    The Feature Propagation Layer
    """
    def __init__(self, input_dims, sizes):
        super(PointNet2FP, self).__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        sizes = [input_dims] + sizes
        for i in range(1, len(sizes)):
            self.convs.append(nn.Conv1d(sizes[i-1], sizes[i], 1))
            self.bns.append(nn.BatchNorm1d(sizes[i]))

    def forward(self, x1, x2, feat1, feat2):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
        Input:
            x1: input points position data, [B, N, C]
            x2: sampled input points position data, [B, S, C]
            feat1: input points feature data, [B, N, D]
            feat2: sampled input points feature data, [B, S, D]
        Return:
            new_feat: upsampled points data, [B, D', N]
        """
        B, N, C = x1.shape
        _, S, _ = x2.shape

        if S == 1:
            interpolated_feat = feat2.repeat(1, N, 1)
        else:
            # Inverse-distance-weighted interpolation from the 3 nearest sampled points
            dists = square_distance(x1, x2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_feat = torch.sum(index_points(feat2, idx) * weight.view(B, N, 3, 1), dim=2)

        if feat1 is not None:
            new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
        else:
            new_feat = interpolated_feat

        new_feat = new_feat.permute(0, 2, 1)  # [B, D, N]
        for i, conv in enumerate(self.convs):
            bn = self.bns[i]
            new_feat = F.relu(bn(conv(new_feat)))
        return new_feat
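The helpers `square_distance` and `index_points` are defined earlier in this file (the diff elides them). For reference, a minimal sketch of a pairwise squared-distance helper consistent with the shapes used above, assuming the common formulation from the referenced repo:

```python
import torch

def square_distance(src, dst):
    # src: [B, N, C], dst: [B, M, C] -> [B, N, M] pairwise squared distances.
    # Expands ||s - d||^2 = ||s||^2 + ||d||^2 - 2 s.d, avoiding a [B, N, M, C] buffer.
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(src.shape[0], -1, 1)
    dist += torch.sum(dst ** 2, -1).view(dst.shape[0], 1, -1)
    return dist
```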
class PointNet2SSGCls(nn.Module):
    def __init__(self, output_classes, batch_size, input_dims=3, dropout_prob=0.4):
        super(PointNet2SSGCls, self).__init__()
...@@ -249,7 +306,7 @@ class PointNet2SSGCls(nn.Module):
            feat = None
        pos, feat = self.sa_module1(pos, feat)
        pos, feat = self.sa_module2(pos, feat)
        _, h = self.sa_module3(pos, feat)
        h = self.mlp1(h)
        h = self.bn1(h)
...@@ -296,7 +353,7 @@ class PointNet2MSGCls(nn.Module):
            feat = None
        pos, feat = self.sa_msg_module1(pos, feat)
        pos, feat = self.sa_msg_module2(pos, feat)
        _, h = self.sa_module3(pos, feat)
        h = self.mlp1(h)
        h = self.bn1(h)
...
import torch
import torch.nn as nn
import torch.nn.functional as F

from pointnet2 import SAModule, SAMSGModule, PointNet2FP


class PointNet2SSGPartSeg(nn.Module):
    def __init__(self, output_classes, batch_size, input_dims=6):
        super(PointNet2SSGPartSeg, self).__init__()
        # if normal_channel == True, input_dims = 6 + 3
        self.input_dims = input_dims

        self.sa_module1 = SAModule(512, batch_size, 0.2, [input_dims, 64, 64, 128], n_neighbor=32)
        self.sa_module2 = SAModule(128, batch_size, 0.4, [128 + 3, 128, 128, 256])
        self.sa_module3 = SAModule(None, batch_size, None, [256 + 3, 256, 512, 1024],
                                   group_all=True)

        self.fp3 = PointNet2FP(1280, [256, 256])
        self.fp2 = PointNet2FP(384, [256, 128])
        # if normal_channel == True, 128 + 16 + 6 + 3
        self.fp1 = PointNet2FP(128 + 16 + 6, [128, 128, 128])

        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, output_classes, 1)

    def forward(self, x, cat_vec=None):
        if x.shape[-1] > 3:
            l0_pos = x[:, :, :3]
            l0_feat = x
        else:
            l0_pos = x
            l0_feat = x

        # Set Abstraction layers
        l1_pos, l1_feat = self.sa_module1(l0_pos, l0_feat)  # l1_feat: [B, N, D]
        l2_pos, l2_feat = self.sa_module2(l1_pos, l1_feat)
        l3_pos, l3_feat = self.sa_module3(l2_pos, l2_feat)  # [B, N, C], [B, D]

        # Feature Propagation layers
        l2_feat = self.fp3(l2_pos, l3_pos, l2_feat, l3_feat.unsqueeze(1))  # l2_feat: [B, D, N]
        l1_feat = self.fp2(l1_pos, l2_pos, l1_feat, l2_feat.permute(0, 2, 1))
        l0_feat = torch.cat([cat_vec.permute(0, 2, 1), l0_pos, l0_feat], 2)
        l0_feat = self.fp1(l0_pos, l1_pos, l0_feat, l1_feat.permute(0, 2, 1))

        # FC layers
        feat = F.relu(self.bn1(self.conv1(l0_feat)))
        out = self.drop1(feat)
        out = self.conv2(out)  # [B, output_classes, N]
        return out
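A hedged smoke test of the forward interface. The shapes are assumptions mirroring `train_partseg.py`: a positions-only input of shape [B, N, 3] (the first set-abstraction layer concatenates position and feature, reaching `input_dims=6`), and a per-point one-hot category tensor of shape [B, 16, N] as consumed by `fp1`:

```python
# Hypothetical usage sketch; batch_size passed to the model must match the batch.
net = PointNet2SSGPartSeg(output_classes=50, batch_size=2, input_dims=6)
x = torch.randn(2, 2048, 3)           # [B, N, 3] point positions
cat_vec = torch.zeros(2, 16, 2048)    # [B, 16, N] one-hot object category
cat_vec[:, 3, :] = 1                  # e.g. both shapes belong to category 3
out = net(x, cat_vec)                 # [2, 50, 2048] per-point part logits
```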
class PointNet2MSGPartSeg(nn.Module):
    def __init__(self, output_classes, batch_size, input_dims=6):
        super(PointNet2MSGPartSeg, self).__init__()
        self.sa_msg_module1 = SAMSGModule(512, batch_size, [0.1, 0.2, 0.4], [32, 64, 128],
                                          [[input_dims, 32, 32, 64], [input_dims, 64, 64, 128],
                                           [input_dims, 64, 96, 128]])
        self.sa_msg_module2 = SAMSGModule(128, batch_size, [0.4, 0.8], [64, 128],
                                          [[128+128+64+3, 128, 128, 256], [128+128+64+3, 128, 196, 256]])
        self.sa_module3 = SAModule(None, batch_size, None, [512 + 3, 256, 512, 1024],
                                   group_all=True)

        self.fp3 = PointNet2FP(1536, [256, 256])
        self.fp2 = PointNet2FP(576, [256, 128])
        # if normal_channel == True, 150 + 3
        self.fp1 = PointNet2FP(150, [128, 128])

        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, output_classes, 1)

    def forward(self, x, cat_vec=None):
        if x.shape[-1] > 3:
            l0_pos = x[:, :, :3]
            l0_feat = x
        else:
            l0_pos = x
            l0_feat = x

        # Set Abstraction layers
        l1_pos, l1_feat = self.sa_msg_module1(l0_pos, l0_feat)
        l2_pos, l2_feat = self.sa_msg_module2(l1_pos, l1_feat)
        l3_pos, l3_feat = self.sa_module3(l2_pos, l2_feat)

        # Feature Propagation layers
        l2_feat = self.fp3(l2_pos, l3_pos, l2_feat, l3_feat.unsqueeze(1))
        l1_feat = self.fp2(l1_pos, l2_pos, l1_feat, l2_feat.permute(0, 2, 1))
        l0_feat = torch.cat([cat_vec.permute(0, 2, 1), l0_pos, l0_feat], 2)
        l0_feat = self.fp1(l0_pos, l1_pos, l0_feat, l1_feat.permute(0, 2, 1))

        # FC layers
        feat = F.relu(self.bn1(self.conv1(l0_feat)))
        out = self.drop1(feat)
        out = self.conv2(out)
        return out
...@@ -12,17 +12,21 @@ import tqdm
import urllib
import os
import argparse
import time

from ShapeNet import ShapeNet
from pointnet_partseg import PointNetPartSeg, PartSegLoss
from pointnet2_partseg import PointNet2MSGPartSeg, PointNet2SSGPartSeg

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='pointnet')
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=250)
parser.add_argument('--num-workers', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--tensorboard', action='store_true')
args = parser.parse_args()

num_workers = args.num_workers
...@@ -48,6 +52,7 @@ def train(net, opt, scheduler, train_loader, dev):
    num_batches = 0
    total_correct = 0
    count = 0
    start = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label, cat in tq:
            num_examples = data.shape[0]
...@@ -72,10 +77,15 @@ def train(net, opt, scheduler, train_loader, dev):
            correct = (preds.view(-1) == label).sum().item()
            total_correct += correct

            AvgLoss = total_loss / num_batches
            AvgAcc = total_correct / count
            tq.set_postfix({
                'AvgLoss': '%.5f' % AvgLoss,
                'AvgAcc': '%.5f' % AvgAcc})
    scheduler.step()
    end = time.time()
    return data, preds, AvgLoss, AvgAcc, end - start
def mIoU(preds, label, cat, cat_miou, seg_classes):
    for i in range(preds.shape[0]):
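The diff elides the rest of `mIoU`. For context, ShapeNet part mIoU is conventionally computed per shape by averaging IoU over the part ids that belong to that shape's category; a hedged sketch (not the script's exact code):

```python
import numpy as np

def shape_iou(pred, label, parts):
    # pred, label: [N] part ids for one shape; parts: the part ids
    # belonging to this shape's category (e.g. seg_classes[cat]).
    ious = []
    for p in parts:
        inter = np.sum((pred == p) & (label == p))
        union = np.sum((pred == p) | (label == p))
        # Convention: a part absent from both prediction and label scores 1.
        ious.append(1.0 if union == 0 else inter / union)
    return float(np.mean(ious))
```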
...@@ -145,8 +155,13 @@ def evaluate(net, test_loader, dev, per_cat_verbose=False):
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dev = "cpu"

if args.model == 'pointnet':
    net = PointNetPartSeg(50, 3, 2048)
elif args.model == 'pointnet2_ssg':
    net = PointNet2SSGPartSeg(50, batch_size, input_dims=6)
elif args.model == 'pointnet2_msg':
    net = PointNet2MSGPartSeg(50, batch_size, input_dims=6)

net = net.to(dev)
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))
...@@ -160,11 +175,31 @@ shapenet = ShapeNet(2048, normal_channel=False)
train_loader = CustomDataLoader(shapenet.trainval())
test_loader = CustomDataLoader(shapenet.test())

# TensorBoard
if args.tensorboard:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()

# Select 50 distinct colors for different parts
color_map = torch.tensor([
    [47, 79, 79], [139, 69, 19], [112, 128, 144], [85, 107, 47], [139, 0, 0], [128, 128, 0], [72, 61, 139], [0, 128, 0], [188, 143, 143], [60, 179, 113],
    [205, 133, 63], [0, 139, 139], [70, 130, 180], [205, 92, 92], [154, 205, 50], [0, 0, 139], [50, 205, 50], [250, 250, 250], [218, 165, 32], [139, 0, 139],
    [10, 10, 10], [176, 48, 96], [72, 209, 204], [153, 50, 204], [255, 69, 0], [255, 145, 0], [0, 0, 205], [255, 255, 0], [0, 255, 0], [233, 150, 122],
    [220, 20, 60], [0, 191, 255], [160, 32, 240], [192, 192, 192], [173, 255, 47], [218, 112, 214], [216, 191, 216], [255, 127, 80], [255, 0, 255], [100, 149, 237],
    [128, 128, 128], [221, 160, 221], [144, 238, 144], [123, 104, 238], [255, 160, 122], [175, 238, 238], [238, 130, 238], [127, 255, 212], [255, 218, 185], [255, 105, 180],
])

# Paint each point according to its predicted part id
def paint(batched_points):
    B, N = batched_points.shape
    colored = color_map[batched_points].squeeze(2)
    return colored
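A quick hedged check of `paint` with dummy predictions (`add_mesh` below expects per-vertex RGB values in the 0-255 range):

```python
# Hypothetical smoke test: color a fake batch of per-point predictions.
dummy_preds = torch.randint(0, 50, (2, 2048))  # [B, N] part ids in [0, 50)
colors = paint(dummy_preds)                    # [B, N, 3] RGB values, 0..255
assert colors.shape == (2, 2048, 3)
```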
best_test_miou = 0
best_test_per_cat_miou = 0

for epoch in range(args.num_epochs):
    data, preds, AvgLoss, AvgAcc, training_time = train(net, opt, scheduler, train_loader, dev)
    if (epoch + 1) % 5 == 0:
        print('Epoch #%d Testing' % epoch)
        test_miou, test_per_cat_miou = evaluate(net, test_loader, dev, (epoch + 1) % 5 == 0)
...@@ -175,3 +210,13 @@ for epoch in range(args.num_epochs):
            torch.save(net.state_dict(), args.save_model_path)
        print('Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)' % (
            test_miou, best_test_miou, test_per_cat_miou, best_test_per_cat_miou))
    # TensorBoard
    if args.tensorboard:
        colored = paint(preds)
        writer.add_mesh('data', vertices=data, colors=colored, global_step=epoch)
        writer.add_scalar('training time for one epoch', training_time, global_step=epoch)
        writer.add_scalar('AvgLoss', AvgLoss, global_step=epoch)
        writer.add_scalar('AvgAcc', AvgAcc, global_step=epoch)
        if (epoch + 1) % 5 == 0:
            writer.add_scalar('test mIoU', test_miou, global_step=epoch)
            writer.add_scalar('best test mIoU', best_test_miou, global_step=epoch)