Unverified Commit 4fb50be4 authored by esang's avatar esang Committed by GitHub
Browse files

[Model] Point transformer (#3284)



* some modifications for pointnet2

* temporarily save changes

* move files to new directory point_transformer

* implement point transformer for classification

* restore train_cls in pointnet

* implement point transformer for partseg

* fix point transformer for nan loss

* modify point transformer for cls

* modify training setting

* update transformer for cls

* update code

* update code for latest performance

* update the example index

* some minor changes
Co-authored-by: Tong He <hetong007@gmail.com>
parent 79305862
......@@ -17,6 +17,10 @@ To quickly locate the examples of your interest, search for the tagged keywords
- <a name="correct_and_smooth"></a> Huang et al. Combining Label Propagation and Simple Models Out-performs Graph Neural Networks. [Paper link](https://arxiv.org/abs/2010.13993).
- Example code: [PyTorch](../examples/pytorch/correct_and_smooth)
- Tags: efficiency, node classification, label propagation
- <a name="point_transformer"></a> Zhao et al. Point Transformer. [Paper link](http://arxiv.org/abs/2012.09164).
- Example code: [PyTorch](../examples/pytorch/pointcloud/point_transformer)
- Tags: point cloud classification, point cloud part-segmentation
## 2020
- <a name="eeg-gcnn"></a> Wagh et al. EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. [Paper link](http://proceedings.mlr.press/v136/wagh20a.html).
......
import numpy as np
import warnings
import os
from torch.utils.data import Dataset

# Silence warnings raised while parsing the raw point-cloud text files.
warnings.filterwarnings('ignore')
def pc_normalize(pc):
    """Center a point cloud at the origin and scale it into the unit sphere.

    pc: (N, C) array of point coordinates.
    Returns the normalized (N, C) array.
    """
    centered = pc - pc.mean(axis=0)
    radius = np.sqrt((centered ** 2).sum(axis=1)).max()
    return centered / radius
def farthest_point_sample(point, npoint):
    """
    Farthest point sampler works as follows:
    1. Initialize the sample set S with a random point
    2. Pick point P not in S, which maximizes the distance d(P, S)
    3. Repeat step 2 until |S| = npoint
    Input:
        point: pointcloud data, [N, D] (the first 3 columns are xyz)
        npoint: number of samples
    Return:
        point: the sampled points, [npoint, D]
        (the original docstring wrongly described the return as indices)
    """
    N = point.shape[0]
    xyz = point[:, :3]
    # Use an integer index buffer directly; the original used a float array
    # and cast it to int at the end.
    centroids = np.zeros((npoint,), dtype=np.int64)
    # distance[j] = squared distance from point j to the current sample set,
    # maintained incrementally.
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        # Next sample: the point farthest from the set chosen so far.
        farthest = np.argmax(distance, -1)
    return point[centroids]
class ModelNetDataLoader(Dataset):
    def __init__(self, root, npoint=1024, split='train', fps=False,
                 normal_channel=True, cache_size=15000):
        """
        Input:
            root: the root path to the local data files
            npoint: number of points from each cloud
            split: which split of the data, 'train' or 'test'
            fps: whether to sample points with farthest point sampler
            normal_channel: whether to use additional channel
            cache_size: the cache size of in-memory point clouds
        """
        self.root = root
        self.npoints = npoint
        self.fps = fps
        self.normal_channel = normal_channel

        # Category names -> integer labels, read from the shape-name list.
        self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
        self.cat = [line.rstrip() for line in open(self.catfile)]
        self.classes = dict(zip(self.cat, range(len(self.cat))))

        # Both split lists are read up front (matching the original behavior).
        shape_ids = {
            s: [line.rstrip() for line in open(
                os.path.join(self.root, 'modelnet40_%s.txt' % s))]
            for s in ('train', 'test')
        }
        assert (split == 'train' or split == 'test')
        # A shape id looks like 'airplane_0001'; the category is everything
        # before the trailing index.
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
        # list of (shape_name, shape_txt_file_path) tuple
        self.datapath = [
            (name, os.path.join(self.root, name, sid) + '.txt')
            for name, sid in zip(shape_names, shape_ids[split])
        ]
        print('The size of %s data is %d' % (split, len(self.datapath)))

        # Bounded in-memory cache of parsed clouds.
        self.cache_size = cache_size
        self.cache = {}

    def __len__(self):
        return len(self.datapath)

    def _get_item(self, index):
        """Load (and possibly cache) one cloud; returns (points, label)."""
        try:
            return self.cache[index]
        except KeyError:
            pass
        shape_name, path = self.datapath[index]
        cls = np.array([self.classes[shape_name]]).astype(np.int32)
        point_set = np.loadtxt(path, delimiter=',').astype(np.float32)
        if self.fps:
            point_set = farthest_point_sample(point_set, self.npoints)
        else:
            point_set = point_set[0:self.npoints, :]
        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        if not self.normal_channel:
            point_set = point_set[:, 0:3]
        if len(self.cache) < self.cache_size:
            self.cache[index] = (point_set, cls)
        return point_set, cls

    def __getitem__(self, index):
        return self._get_item(index)
Point Transformer
====
> This model was implemented on August 27, 2021, when no official code had been released.
Thus we implemented this model based on the code from <https://github.com/qq456cvb/Point-Transformers>.
This is a reproduction of the paper: [Point Transformer](http://arxiv.org/abs/2012.09164).
# Performance
| Task | Dataset | Metric | Score - Paper | Score - DGL (Adam) | Score - DGL (SGD) | Time(s) - DGL |
|-----------------|------------|----------|------------------|-------------|-------------|-------------------|
| Classification | ModelNet40 | Accuracy | 93.7 | 92.0 | 91.5 | 117.0 |
| Part Segmentation | ShapeNet | mIoU | 86.6 | 84.3 | 85.1 | 260.0 |
+ Time(s) are the average training time per epoch, measured on EC2 p3.8xlarge instance w/ Tesla V100 GPU.
# How to Run
For point cloud classification, run with
```bash
python train_cls.py --opt [sgd/adam]
```
For point cloud part-segmentation, run with
```bash
python train_partseg.py --opt [sgd/adam]
```
import os, json, tqdm
import numpy as np
import dgl
from zipfile import ZipFile
from torch.utils.data import Dataset
from scipy.sparse import csr_matrix
from dgl.data.utils import download, get_download_dir
class ShapeNet(object):
    """
    Downloads (if necessary) and indexes the ShapeNet part-segmentation
    benchmark, exposing the official train/valid/test/trainval splits as
    ShapeNetDataset objects.
    """
    def __init__(self, num_points=2048, normal_channel=True):
        # num_points: points sampled per cloud by the split datasets.
        # normal_channel: keep normals (6 channels) or xyz only (3 channels).
        self.num_points = num_points
        self.normal_channel = normal_channel
        SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
        download_path = get_download_dir()
        data_filename = "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
        data_path = os.path.join(download_path, "shapenetcore_partanno_segmentation_benchmark_v0_normal")
        # Fetch and unpack the archive only if the extracted copy is absent.
        if not os.path.exists(data_path):
            local_path = os.path.join(download_path, data_filename)
            if not os.path.exists(local_path):
                download(SHAPENET_DOWNLOAD_URL, local_path, verify_ssl=False)
            with ZipFile(local_path) as z:
                z.extractall(path=download_path)
        # Map synset offset (the on-disk directory name) to the
        # human-readable category name.
        synset_file = "synsetoffset2category.txt"
        with open(os.path.join(data_path, synset_file)) as f:
            synset = [t.split('\n')[0].split('\t') for t in f.readlines()]
        self.synset_dict = {}
        for syn in synset:
            self.synset_dict[syn[1]] = syn[0]
        # Global part-label ids belonging to each object category.
        self.seg_classes = {'Airplane': [0, 1, 2, 3],
                            'Bag': [4, 5],
                            'Cap': [6, 7],
                            'Car': [8, 9, 10, 11],
                            'Chair': [12, 13, 14, 15],
                            'Earphone': [16, 17, 18],
                            'Guitar': [19, 20, 21],
                            'Knife': [22, 23],
                            'Lamp': [24, 25, 26, 27],
                            'Laptop': [28, 29],
                            'Motorbike': [30, 31, 32, 33, 34, 35],
                            'Mug': [36, 37],
                            'Pistol': [38, 39, 40],
                            'Rocket': [41, 42, 43],
                            'Skateboard': [44, 45, 46],
                            'Table': [47, 48, 49]}
        # The official shuffled split files list one cloud file per entry,
        # prefixed with 'shape_data/'; strip the prefix and resolve locally.
        train_split_json = 'shuffled_train_file_list.json'
        val_split_json = 'shuffled_val_file_list.json'
        test_split_json = 'shuffled_test_file_list.json'
        split_path = os.path.join(data_path, 'train_test_split')
        with open(os.path.join(split_path, train_split_json)) as f:
            tmp = f.read()
            self.train_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
        with open(os.path.join(split_path, val_split_json)) as f:
            tmp = f.read()
            self.val_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
        with open(os.path.join(split_path, test_split_json)) as f:
            tmp = f.read()
            self.test_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]

    def train(self):
        """Dataset over the training split."""
        return ShapeNetDataset(self, 'train', self.num_points, self.normal_channel)

    def valid(self):
        """Dataset over the validation split."""
        return ShapeNetDataset(self, 'valid', self.num_points, self.normal_channel)

    def trainval(self):
        """Dataset over train + validation combined."""
        return ShapeNetDataset(self, 'trainval', self.num_points, self.normal_channel)

    def test(self):
        """Dataset over the test split."""
        return ShapeNetDataset(self, 'test', self.num_points, self.normal_channel)
class ShapeNetDataset(Dataset):
    """Torch Dataset over one split of the ShapeNet part-segmentation data.

    Each item is (points, per-point part labels, category name); a fixed
    number of points is re-sampled (with replacement) on every access.
    """

    def __init__(self, shapenet, mode, num_points, normal_channel=True):
        """
        Input:
            shapenet: the ShapeNet object holding the per-split file lists
            mode: one of 'train', 'valid', 'test', 'trainval'
            num_points: number of points sampled from each cloud per access
            normal_channel: keep the 3 normal channels (dim 6) or xyz only (dim 3)
        """
        super(ShapeNetDataset, self).__init__()
        self.mode = mode
        self.num_points = num_points
        if not normal_channel:
            self.dim = 3
        else:
            self.dim = 6
        if mode == 'train':
            self.file_list = shapenet.train_file_list
        elif mode == 'valid':
            self.file_list = shapenet.val_file_list
        elif mode == 'test':
            self.file_list = shapenet.test_file_list
        elif mode == 'trainval':
            self.file_list = shapenet.train_file_list + shapenet.val_file_list
        else:
            # The original raised a bare string, which is itself a TypeError
            # in Python 3; raise a proper exception instead.
            raise ValueError("Not supported `mode`")
        data_list = []
        label_list = []
        category_list = []
        print('Loading data from split ' + self.mode)
        for fn in tqdm.tqdm(self.file_list, ascii=True):
            with open(fn) as f:
                # Each row: x y z nx ny nz seg_label.
                # np.float/np.int were removed in NumPy >= 1.24; they were
                # plain aliases of the builtins, so use those directly.
                data = np.array([t.split('\n')[0].split(' ')
                                 for t in f.readlines()]).astype(float)
            data_list.append(data[:, 0:self.dim])
            label_list.append(data[:, 6].astype(int))
            # The parent directory name is the synset offset -> category.
            category_list.append(shapenet.synset_dict[fn.split('/')[-2]])
        self.data = data_list
        self.label = label_list
        self.category = category_list

    def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2), size=3):
        """Randomly scale and shift a cloud (train-time augmentation)."""
        xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])
        xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])
        x = np.add(np.multiply(x, xyz1), xyz2).astype('float32')
        return x

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        # Sample a fixed-size subset (with replacement) so clouds batch evenly.
        inds = np.random.choice(self.data[i].shape[0], self.num_points,
                                replace=True)
        x = self.data[i][inds, :self.dim]
        y = self.label[i][inds]
        cat = self.category[i]
        if self.mode == 'train':
            x = self.translate(x, size=self.dim)
        x = x.astype(float)
        y = y.astype(int)
        return x, y, cat
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.geometry import farthest_point_sampler
'''
Part of the code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
def square_distance(src, dst):
    '''
    Compute pairwise squared Euclidean distances between two point sets,
    using the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2 s.d.

    src: (B, N, C); dst: (B, M, C) -> (B, N, M) distances.
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
    '''
    cross = torch.matmul(src, dst.permute(0, 2, 1))      # (B, N, M)
    src_sq = torch.sum(src ** 2, -1).unsqueeze(-1)       # (B, N, 1)
    dst_sq = torch.sum(dst ** 2, -1).unsqueeze(1)        # (B, 1, M)
    return src_sq + dst_sq - 2 * cross
def index_points(points, idx):
    '''
    Gather points by index along dimension 1.

    points: (B, N, C); idx: (B, S) or (B, S1, ..., Sk) integer indices.
    Returns the gathered points, shape idx.shape + (C,).
    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
    '''
    B = points.shape[0]
    # Batch index tensor broadcast to the same shape as idx.
    view_shape = [B] + [1] * (idx.dim() - 1)
    repeat_shape = [1] + list(idx.shape[1:])
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device)
    batch_indices = batch_indices.view(view_shape).repeat(repeat_shape)
    return points[batch_indices, idx, :]
class KNearNeighbors(nn.Module):
    '''
    Find the indices of the k nearest neighbors of each query centroid.
    '''

    def __init__(self, n_neighbor):
        super(KNearNeighbors, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        '''
        pos: (B, N, 3) positions; centroids: (B, S) indices into pos.
        Returns (B, S, k) neighbor indices, nearest first.
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
        '''
        query_pos = index_points(pos, centroids)
        d2 = square_distance(query_pos, pos)
        return d2.argsort(dim=-1)[..., :self.n_neighbor]
class KNNGraphBuilder(nn.Module):
    '''
    Build NN graph: one DGL graph per cloud connecting each sampled
    centroid to its k nearest neighbors, batched for message passing.
    '''
    def __init__(self, n_neighbor):
        super(KNNGraphBuilder, self).__init__()
        self.n_neighbor = n_neighbor
        self.knn = KNearNeighbors(n_neighbor)

    def forward(self, pos, centroids, feat=None):
        # pos: (B, N, 3); centroids: (B, S) indices; feat: optional (B, N, F).
        dev = pos.device
        group_idx = self.knn(pos, centroids)
        B, N, _ = pos.shape
        glist = []
        for i in range(B):
            # Mark centroid nodes so they can be recovered after update_all.
            center = torch.zeros((N)).to(dev)
            center[centroids[i]] = 1
            # Edges point from each neighbor (src) to its centroid (dst);
            # dst repeats each centroid once per neighbor.
            src = group_idx[i].contiguous().view(-1)
            dst = centroids[i].view(-1, 1).repeat(1, min(self.n_neighbor,
                                                         src.shape[0] // centroids.shape[1])).view(-1)
            # Compact node ids: keep only nodes that appear as an endpoint,
            # remapping ids via the inverse index from torch.unique.
            unified = torch.cat([src, dst])
            uniq, inv_idx = torch.unique(unified, return_inverse=True)
            src_idx = inv_idx[:src.shape[0]]
            dst_idx = inv_idx[src.shape[0]:]
            g = dgl.graph((src_idx, dst_idx))
            g.ndata['pos'] = pos[i][uniq]
            g.ndata['center'] = center[uniq]
            if feat is not None:
                g.ndata['feat'] = feat[i][uniq]
            glist.append(g)
        # One batched graph for the whole mini-batch.
        bg = dgl.batch(glist)
        return bg
class RelativePositionMessage(nn.Module):
    '''
    Edge UDF: each neighbor sends its offset from the centroid
    (concatenated with its feature when one is present).
    '''

    def __init__(self, n_neighbor):
        super(RelativePositionMessage, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        rel_pos = edges.src['pos'] - edges.dst['pos']
        if 'feat' not in edges.src:
            return {'agg_feat': rel_pos}
        return {'agg_feat': torch.cat([rel_pos, edges.src['feat']], 1)}
class KNNConv(nn.Module):
    '''
    Feature aggregation: a shared MLP (1x1 Conv2d + BatchNorm2d + ReLU per
    layer) applied to each centroid's gathered neighbor messages, followed
    by a max-pool over the neighbors.
    '''
    def __init__(self, sizes, batch_size):
        # sizes: channel widths of the MLP, sizes[0] being the input width.
        super(KNNConv, self).__init__()
        self.batch_size = batch_size
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for i in range(1, len(sizes)):
            self.conv.append(nn.Conv2d(sizes[i-1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))

    def forward(self, nodes):
        # nodes.mailbox['agg_feat']: (nodes, neighbors, feat) messages from
        # RelativePositionMessage.
        shape = nodes.mailbox['agg_feat'].shape
        # Regroup to (batch, feat, neighbors, points) so the 1x1 Conv2d
        # mixes only the channel dimension.
        h = nodes.mailbox['agg_feat'].view(
            self.batch_size, -1, shape[1], shape[2]).permute(0, 3, 2, 1)
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        # Max-pool over the neighbor axis.
        h = torch.max(h, 2)[0]
        feat_dim = h.shape[1]
        # Back to one row per graph node.
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {'new_feat': h}

    def group_all(self, pos, feat):
        '''
        Feature aggregation and pooling for the non-sampling layer
        '''
        if feat is not None:
            h = torch.cat([pos, feat], 2)
        else:
            h = pos
        B, N, D = h.shape
        _, _, C = pos.shape
        # Single placeholder centroid at the origin for the pooled cloud.
        new_pos = torch.zeros(B, 1, C)
        h = h.permute(0, 2, 1).view(B, -1, N, 1)
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        # Pool over all points -> one feature vector per cloud.
        h = torch.max(h[:, :, :, 0], 2)[0]  # [B,D]
        return new_pos, h
class TransitionDown(nn.Module):
    """
    The Transition Down Module: sample n_points centroids with farthest
    point sampling, build a kNN graph around them, and aggregate neighbor
    features into the centroids.
    """
    def __init__(self, n_points, batch_size, mlp_sizes, n_neighbors=64):
        super(TransitionDown, self).__init__()
        self.n_points = n_points
        self.frnn_graph = KNNGraphBuilder(n_neighbors)
        self.message = RelativePositionMessage(n_neighbors)
        self.conv = KNNConv(mlp_sizes, batch_size)
        self.batch_size = batch_size

    def forward(self, pos, feat):
        # pos: (B, N, 3); feat: (B, N, F). Returns downsampled (pos, feat).
        centroids = farthest_point_sampler(pos, self.n_points)
        g = self.frnn_graph(pos, centroids, feat)
        g.update_all(self.message, self.conv)
        # Keep only the centroid nodes (flagged by the graph builder).
        mask = g.ndata['center'] == 1
        pos_dim = g.ndata['pos'].shape[-1]
        feat_dim = g.ndata['new_feat'].shape[-1]
        pos_res = g.ndata['pos'][mask].view(self.batch_size, -1, pos_dim)
        feat_res = g.ndata['new_feat'][mask].view(
            self.batch_size, -1, feat_dim)
        return pos_res, feat_res
class FeaturePropagation(nn.Module):
    """
    The FeaturePropagation Layer: interpolate features from a sampled
    (coarse) point set back onto a denser point set, optionally refining
    them with a shared Conv1d/BatchNorm1d/ReLU MLP.
    """
    def __init__(self, input_dims, sizes):
        # sizes may be empty, in which case the layer only interpolates.
        super(FeaturePropagation, self).__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        sizes = [input_dims] + sizes
        for i in range(1, len(sizes)):
            self.convs.append(nn.Conv1d(sizes[i-1], sizes[i], 1))
            self.bns.append(nn.BatchNorm1d(sizes[i]))

    def forward(self, x1, x2, feat1, feat2):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
        Input:
            x1: input points position data, [B, N, C]
            x2: sampled input points position data, [B, S, C]
            feat1: input points data, [B, N, D]
            feat2: input points data, [B, S, D]
        Return:
            new_feat: upsampled points data, [B, D', N]
        """
        B, N, C = x1.shape
        _, S, _ = x2.shape
        if S == 1:
            # Only one source point: broadcast its feature to every target.
            interpolated_feat = feat2.repeat(1, N, 1)
        else:
            # Inverse-distance-weighted interpolation from the 3 nearest
            # sampled points.
            dists = square_distance(x1, x2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]
            dist_recip = 1.0 / (dists + 1e-8)  # epsilon avoids div-by-zero
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_feat = torch.sum(index_points(
                feat2, idx) * weight.view(B, N, 3, 1), dim=2)
        if feat1 is not None:
            # Skip connection with the dense level's own features.
            new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
        else:
            new_feat = interpolated_feat
        new_feat = new_feat.permute(0, 2, 1)  # [B, D, S]
        for i, conv in enumerate(self.convs):
            bn = self.bns[i]
            new_feat = F.relu(bn(conv(new_feat)))
        return new_feat
class SwapAxes(nn.Module):
    """Module wrapper around tensor transpose (default: swap dims 1 and 2)."""

    def __init__(self, dim1=1, dim2=2):
        super(SwapAxes, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x):
        return torch.transpose(x, self.dim1, self.dim2)
class TransitionUp(nn.Module):
    """
    The Transition Up Module: project coarse (dim1) and fine (dim2)
    features to dim_out, upsample the coarse ones onto the fine positions,
    and sum.
    """

    def __init__(self, dim1, dim2, dim_out):
        super(TransitionUp, self).__init__()

        def projection(in_dim):
            # Linear -> BatchNorm1d (the SwapAxes pair moves channels to
            # dim 1 for the norm) -> ReLU.
            return nn.Sequential(
                nn.Linear(in_dim, dim_out),
                SwapAxes(),
                nn.BatchNorm1d(dim_out),  # TODO
                SwapAxes(),
                nn.ReLU(),
            )

        self.fc1 = projection(dim1)
        self.fc2 = projection(dim2)
        self.fp = FeaturePropagation(-1, [])

    def forward(self, pos1, feat1, pos2, feat2):
        coarse = self.fc1(feat1)
        fine = self.fc2(feat2)
        upsampled = self.fp(pos2, pos1, None, coarse).transpose(1, 2)
        return upsampled + fine
import torch
from torch import nn
import numpy as np
from helper import square_distance, index_points, TransitionDown, TransitionUp
'''
Part of the code are adapted from
https://github.com/qq456cvb/Point-Transformers
'''
class PointTransformerBlock(nn.Module):
    """
    Vector self-attention over each point's k nearest neighbors, with a
    residual connection (Zhao et al., "Point Transformer").
    """
    def __init__(self, input_dim, n_neighbors, transformer_dim=None):
        super(PointTransformerBlock, self).__init__()
        if transformer_dim is None:
            transformer_dim = input_dim
        # Project into / out of the attention dimension.
        self.fc1 = nn.Linear(input_dim, transformer_dim)
        self.fc2 = nn.Linear(transformer_dim, input_dim)
        # Positional-encoding MLP on relative coordinates.
        self.fc_delta = nn.Sequential(
            nn.Linear(3, transformer_dim),
            nn.ReLU(),
            nn.Linear(transformer_dim, transformer_dim)
        )
        # Attention-weight MLP applied to the subtraction relation.
        self.fc_gamma = nn.Sequential(
            nn.Linear(transformer_dim, transformer_dim),
            nn.ReLU(),
            nn.Linear(transformer_dim, transformer_dim)
        )
        # Query / key / value projections.
        self.w_qs = nn.Linear(transformer_dim, transformer_dim, bias=False)
        self.w_ks = nn.Linear(transformer_dim, transformer_dim, bias=False)
        self.w_vs = nn.Linear(transformer_dim, transformer_dim, bias=False)
        self.n_neighbors = n_neighbors

    def forward(self, x, pos):
        # x: (B, N, input_dim) features; pos: (B, N, 3) coordinates.
        dists = square_distance(pos, pos)
        knn_idx = dists.argsort()[:, :, :self.n_neighbors]  # b x n x k
        knn_pos = index_points(pos, knn_idx)
        h = self.fc1(x)
        q, k, v = self.w_qs(h), index_points(
            self.w_ks(h), knn_idx), index_points(self.w_vs(h), knn_idx)
        pos_enc = self.fc_delta(pos[:, :, None] - knn_pos)  # b x n x k x f
        # Vector attention: per-channel weights from q - k + positional enc.
        attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
        # Softmax over the neighbor axis, scaled like dot-product attention.
        attn = torch.softmax(attn / np.sqrt(k.size(-1)),
                             dim=-2)  # b x n x k x f
        # Channel-wise weighted sum over the k neighbors.
        res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc)
        res = self.fc2(res) + x  # residual connection
        return res, attn
class PointTransformer(nn.Module):
    """
    Backbone: an MLP plus transformer block at full resolution, followed
    by n_blocks of (TransitionDown, PointTransformerBlock) pairs, each
    halving resolution by downsampling_rate and doubling channels.
    """

    def __init__(self, n_points, batch_size, feature_dim=3, n_blocks=4, downsampling_rate=4, hidden_dim=32, transformer_dim=None, n_neighbors=16):
        super(PointTransformer, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.ptb = PointTransformerBlock(
            hidden_dim, n_neighbors, transformer_dim)
        self.transition_downs = nn.ModuleList()
        self.transformers = nn.ModuleList()
        for depth in range(1, n_blocks + 1):
            channels = hidden_dim * 2 ** depth
            points = n_points // (downsampling_rate ** depth)
            self.transition_downs.append(
                TransitionDown(points, batch_size,
                               [channels // 2 + 3, channels, channels],
                               n_neighbors=n_neighbors))
            self.transformers.append(
                PointTransformerBlock(channels, n_neighbors, transformer_dim))

    def forward(self, x):
        # Positions are the first 3 channels; the full input is the feature.
        pos = x[:, :, :3] if x.shape[-1] > 3 else x
        feat = x
        h = self.fc(feat)
        h, _ = self.ptb(h, pos)
        # Collect (pos, feat) at every resolution for segmentation decoders.
        hidden_state = [(pos, h)]
        for down, transformer in zip(self.transition_downs, self.transformers):
            pos, h = down(pos, h)
            h, _ = transformer(h, pos)
            hidden_state.append((pos, h))
        return h, hidden_state
class PointTransformerCLS(nn.Module):
    """
    Classification head on the PointTransformer backbone: global average
    pooling over the coarsest features, then an MLP to class logits.
    """

    def __init__(self, out_classes, batch_size, n_points=1024, feature_dim=3, n_blocks=4, downsampling_rate=4, hidden_dim=32, transformer_dim=None, n_neighbors=16):
        super(PointTransformerCLS, self).__init__()
        self.backbone = PointTransformer(
            n_points, batch_size, feature_dim, n_blocks, downsampling_rate, hidden_dim, transformer_dim, n_neighbors)
        head = nn.Sequential(
            nn.Linear(hidden_dim * 2 ** (n_blocks), 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, out_classes)
        )
        # The original registered this head under both attribute names
        # ('out' and 'fc2'); keep the alias so existing checkpoints with
        # both key prefixes still load.
        self.out = head
        self.fc2 = head

    def forward(self, x):
        h, _ = self.backbone(x)
        pooled = torch.mean(h, dim=1)
        return self.out(pooled)
class PointTransformerSeg(nn.Module):
    """
    Part-segmentation network: PointTransformer backbone, a bottleneck at
    the coarsest level, a mirrored stack of TransitionUp + transformer
    pairs back to full resolution, and a per-point classification head.
    """
    def __init__(self, out_classes, batch_size, n_points=2048, feature_dim=3, n_blocks=4, downsampling_rate=4, hidden_dim=32, transformer_dim=None, n_neighbors=16):
        super().__init__()
        self.backbone = PointTransformer(
            n_points, batch_size, feature_dim, n_blocks, downsampling_rate, hidden_dim, transformer_dim, n_neighbors)
        # Bottleneck MLP at the coarsest resolution (32 * 2**n_blocks channels).
        self.fc = nn.Sequential(
            nn.Linear(32 * 2 ** n_blocks, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 32 * 2 ** n_blocks)
        )
        self.ptb = PointTransformerBlock(
            32 * 2 ** n_blocks, n_neighbors, transformer_dim)
        self.n_blocks = n_blocks
        # Decoder: one (TransitionUp, transformer) pair per encoder level,
        # built coarsest-first to mirror the encoder.
        self.transition_ups = nn.ModuleList()
        self.transformers = nn.ModuleList()
        for i in reversed(range(n_blocks)):
            block_hidden_dim = 32 * 2 ** i
            self.transition_ups.append(
                TransitionUp(block_hidden_dim * 2, block_hidden_dim, block_hidden_dim))
            self.transformers.append(PointTransformerBlock(
                block_hidden_dim, n_neighbors, transformer_dim))
        # Head input: 32 point-feature channels + 16 extra channels from
        # cat_vec (presumably a one-hot object-category vector — confirm
        # against the training script).
        self.out = nn.Sequential(
            nn.Linear(32+16, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_classes)
        )

    def forward(self, x, cat_vec=None):
        # NOTE(review): despite the default, cat_vec does not appear to be
        # optional — torch.cat below fails if it is None. Verify callers
        # always pass a category vector.
        _, hidden_state = self.backbone(x)
        pos, h = hidden_state[-1]
        h, _ = self.ptb(self.fc(h), pos)
        for i in range(self.n_blocks):
            # hidden_state[-i-2] is the matching finer encoder level.
            h = self.transition_ups[i](
                pos, h, hidden_state[- i - 2][0], hidden_state[- i - 2][1])
            pos = hidden_state[- i - 2][0]
            h, _ = self.transformers[i](h, pos)
        return self.out(torch.cat([h, cat_vec], dim=-1))
class PartSegLoss(nn.Module):
    """Cross-entropy over per-point part logits of shape (B, C, N)."""

    def __init__(self, eps=0.2):
        super(PartSegLoss, self).__init__()
        # NOTE(review): eps is stored but never used (label smoothing was
        # presumably intended); kept for interface compatibility.
        self.eps = eps
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits, y):
        num_classes = logits.shape[1]
        # (B, C, N) -> (B*N, C): every point becomes one classification
        # sample; y must be the matching flat label vector.
        flat = logits.permute(0, 2, 1).contiguous().view(-1, num_classes)
        return self.loss(flat, y)
'''
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
'''
import numpy as np
def normalize_data(batch_data):
    """ Normalize the batch data, use coordinates of the block centered at origin,
        Input:
            BxNxC array
        Output:
            BxNxC array
    """
    normalized = np.zeros(batch_data.shape)
    for i, cloud in enumerate(batch_data):
        centered = cloud - cloud.mean(axis=0)
        radius = np.sqrt((centered ** 2).sum(axis=1)).max()
        normalized[i] = centered / radius
    return normalized
def shuffle_data(data, labels):
    """ Shuffle data and labels.
        Input:
          data: B,N,... numpy array
          label: B,... numpy array
        Return:
          Shuffled data, label and shuffle indices
    """
    perm = np.arange(len(labels))
    np.random.shuffle(perm)
    return data[perm, ...], labels[perm], perm
def shuffle_points(batch_data):
    """ Shuffle orders of points in each point cloud -- changes FPS behavior.
        Use the same shuffling idx for the entire batch.
        Input:
          BxNxC array
        Output:
          BxNxC array
    """
    order = np.arange(batch_data.shape[1])
    np.random.shuffle(order)
    return batch_data[:, order, :]
def rotate_point_cloud(batch_data):
    """ Randomly rotate the point clouds to augument the dataset
        rotation is per shape based along up direction
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for idx, cloud in enumerate(batch_data):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        # Rotation about the y (up) axis.
        rot = np.array([[c, 0, s],
                        [0, 1, 0],
                        [-s, 0, c]])
        rotated[idx] = np.dot(cloud.reshape((-1, 3)), rot)
    return rotated
def rotate_point_cloud_z(batch_data):
    """ Randomly rotate the point clouds to augument the dataset
        rotation is per shape based along the z axis
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for idx, cloud in enumerate(batch_data):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        # Rotation about the z axis.
        rot = np.array([[c, s, 0],
                        [-s, c, 0],
                        [0, 0, 1]])
        rotated[idx] = np.dot(cloud.reshape((-1, 3)), rot)
    return rotated
def rotate_point_cloud_with_normal(batch_xyz_normal):
    ''' Randomly rotate XYZ, normal point cloud.
        Input:
            batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
        Output:
            B,N,6, rotated XYZ, normal point cloud
        Note: rotates in place and also returns the same array.
    '''
    for k in range(batch_xyz_normal.shape[0]):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        # Rotation about the y (up) axis, applied to both xyz and normals.
        rot = np.array([[c, 0, s],
                        [0, 1, 0],
                        [-s, 0, c]])
        batch_xyz_normal[k, :, 0:3] = np.dot(
            batch_xyz_normal[k, :, 0:3].reshape((-1, 3)), rot)
        batch_xyz_normal[k, :, 3:6] = np.dot(
            batch_xyz_normal[k, :, 3:6].reshape((-1, 3)), rot)
    return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """ Randomly perturb the point clouds (and normals) by small rotations
        Input:
          BxNx6 array, original batch of point clouds and point normals
        Return:
          BxNx6 array, rotated batch of point clouds and normals
    """
    def _small_rotation():
        # Small Euler angles, clipped, composed as Rz @ Ry @ Rx.
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3),
                             -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(ax), -np.sin(ax)],
                       [0, np.sin(ax), np.cos(ax)]])
        Ry = np.array([[np.cos(ay), 0, np.sin(ay)],
                       [0, 1, 0],
                       [-np.sin(ay), 0, np.cos(ay)]])
        Rz = np.array([[np.cos(az), -np.sin(az), 0],
                       [np.sin(az), np.cos(az), 0],
                       [0, 0, 1]])
        return np.dot(Rz, np.dot(Ry, Rx))

    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        R = _small_rotation()
        rotated[k, :, 0:3] = np.dot(batch_data[k, :, 0:3].reshape((-1, 3)), R)
        rotated[k, :, 3:6] = np.dot(batch_data[k, :, 3:6].reshape((-1, 3)), R)
    return rotated
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """ Rotate the point cloud along up direction with certain angle.
        Input:
          BxNx3 array, original batch of point clouds
          rotation_angle: scalar rotation (radians) about the y axis
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    c = np.cos(rotation_angle)
    s = np.sin(rotation_angle)
    # Same rotation for every cloud, so build the matrix once.
    rot = np.array([[c, 0, s],
                    [0, 1, 0],
                    [-s, 0, c]])
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotated[k, :, 0:3] = np.dot(batch_data[k, :, 0:3].reshape((-1, 3)), rot)
    return rotated
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
    """ Rotate the point cloud along up direction with certain angle.
        Input:
          BxNx6 array, original batch of point clouds with normal
          rotation_angle: scalar rotation (radians) about the y axis
        Return:
          BxNx6 array, rotated batch of point clouds with normal
    """
    c = np.cos(rotation_angle)
    s = np.sin(rotation_angle)
    # Same rotation for every cloud; applied to both xyz and normals.
    rot = np.array([[c, 0, s],
                    [0, 1, 0],
                    [-s, 0, c]])
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotated[k, :, 0:3] = np.dot(batch_data[k, :, 0:3].reshape((-1, 3)), rot)
        rotated[k, :, 3:6] = np.dot(batch_data[k, :, 3:6].reshape((-1, 3)), rot)
    return rotated
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """ Randomly perturb the point clouds by small rotations
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k, cloud in enumerate(batch_data):
        # Small Euler angles, clipped, composed as Rz @ Ry @ Rx.
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3),
                             -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(ax), -np.sin(ax)],
                       [0, np.sin(ax), np.cos(ax)]])
        Ry = np.array([[np.cos(ay), 0, np.sin(ay)],
                       [0, 1, 0],
                       [-np.sin(ay), 0, np.cos(ay)]])
        Rz = np.array([[np.cos(az), -np.sin(az), 0],
                       [np.sin(az), np.cos(az), 0],
                       [0, 0, 1]])
        R = np.dot(Rz, np.dot(Ry, Rx))
        rotated[k] = np.dot(cloud.reshape((-1, 3)), R)
    return rotated
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """ Randomly jitter points. jittering is per point.
        Input:
          BxNx3 array, original batch of point clouds
          sigma: std-dev of the Gaussian noise
          clip: absolute bound on the noise (must be > 0)
        Return:
          BxNx3 array, jittered batch of point clouds
        Raises:
          ValueError: if clip is not positive
    """
    # Validate with a real exception instead of `assert`, which is
    # silently stripped under `python -O`.
    if clip <= 0:
        raise ValueError("clip must be positive, got %r" % (clip,))
    B, N, C = batch_data.shape
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data
def shift_point_cloud(batch_data, shift_range=0.1):
    """ Randomly shift point cloud. Shift is per point cloud.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, shifted batch of point clouds (modified in place)
    """
    B = batch_data.shape[0]
    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
    # Broadcast one (3,) shift over every point of each cloud, in place.
    batch_data += shifts[:, np.newaxis, :]
    return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """ Randomly scale the point cloud. Scale is per point cloud.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, scaled batch of point clouds (modified in place)
    """
    B = batch_data.shape[0]
    scales = np.random.uniform(scale_low, scale_high, B)
    # Broadcast one scalar scale over each cloud, in place.
    batch_data *= scales[:, np.newaxis, np.newaxis]
    return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    ''' Randomly replace a subset of each cloud's points with its first point.

    batch_pc: BxNx3 array (modified in place and also returned).
    max_dropout_ratio: upper bound of the per-cloud dropout ratio.
    '''
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio  # 0~0.875
        drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0]
        if len(drop_idx) > 0:
            # Duplicating the first point keeps the cloud size constant.
            # (A dead re-draw of dropout_ratio, which the original itself
            # flagged with "# not need", has been removed.)
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]  # set to the first point
    return batch_pc
from point_transformer import PointTransformerCLS
from ModelNetDataLoader import ModelNetDataLoader
import provider
import argparse
import os
import tqdm
from functools import partial
from dgl.data.utils import download, get_download_dir
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import time
# cuDNN is disabled globally (reason not recorded in the source; possibly
# to avoid issues with this model's variable-size kernels — TODO confirm).
torch.backends.cudnn.enabled = False

parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=200)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--opt', type=str, default='adam')
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size

# Download and unpack ModelNet40 unless a local copy already exists
# (or an explicit --dataset-path was given).
data_filename = 'modelnet40_normal_resampled.zip'
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
    get_download_dir(), 'modelnet40_normal_resampled')
if not os.path.exists(local_path):
    download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip',
             download_path, verify_ssl=False)
    from zipfile import ZipFile
    with ZipFile(download_path) as z:
        z.extractall(path=get_download_dir())

# DataLoader factory with the shared loader settings baked in.
CustomDataLoader = partial(
    DataLoader,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True)
def train(net, opt, scheduler, train_loader, dev):
    """Run one classification training epoch with on-the-fly augmentation.

    Augments each batch in numpy (dropout, scale, jitter, shift on the
    xyz channels), then optimizes `net` with cross-entropy and reports
    running average loss/accuracy. Steps `scheduler` once at the end.
    """
    net.train()
    criterion = nn.CrossEntropyLoss()
    running_loss = 0
    batch_count = 0
    running_correct = 0
    example_count = 0
    epoch_start = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as progress:
        for points, label in progress:
            # Augmentation happens on the CPU as numpy arrays.
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(
                points[:, :, 0:3])
            points[:, :, 0:3] = provider.jitter_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.tensor(points)
            label = label[:, 0]
            batch_examples = label.shape[0]
            points = points.to(dev)
            label = label.to(dev).squeeze().long()
            opt.zero_grad()
            logits = net(points)
            loss = criterion(logits, label)
            loss.backward()
            opt.step()
            preds = logits.max(1)[1]
            batch_count += 1
            example_count += batch_examples
            running_loss += loss.item()
            running_correct += (preds == label).sum().item()
            progress.set_postfix({
                'AvgLoss': '%.5f' % (running_loss / batch_count),
                'AvgAcc': '%.5f' % (running_correct / example_count)})
    print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
        running_loss / batch_count, running_correct / example_count,
        time.time() - epoch_start))
    scheduler.step()
def evaluate(net, test_loader, dev):
    """Compute classification accuracy on `test_loader`.

    Runs `net` in eval mode with gradients disabled and returns the
    overall accuracy (correct / seen).
    """
    net.eval()
    correct_total = 0
    seen = 0
    eval_start = time.time()
    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as progress:
            for points, label in progress:
                label = label[:, 0]
                batch_examples = label.shape[0]
                points = points.to(dev)
                label = label.to(dev).squeeze().long()
                preds = net(points).max(1)[1]
                correct_total += (preds == label).sum().item()
                seen += batch_examples
                progress.set_postfix({
                    'AvgAcc': '%.5f' % (correct_total / seen)})
    print("[Test] AvgAcc: {:.5}, Time: {:.5}s".format(
        correct_total / seen, time.time() - eval_start))
    return correct_total / seen
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 40 ModelNet classes; feature_dim=6 = xyz + normals.
net = PointTransformerCLS(40, batch_size, feature_dim=6)
net = net.to(dev)
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))
# BUG FIX: the option labels were swapped -- '--opt adam' built an SGD
# optimizer and '--opt sgd' built Adam. Each label now selects the
# optimizer it names, and an unknown value fails fast instead of raising
# a NameError later.
if args.opt == 'sgd':
    # The optimizer strategy described in paper:
    opt = torch.optim.SGD(net.parameters(), lr=0.01,
                          momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        opt, milestones=[120, 160], gamma=0.1)
elif args.opt == 'adam':
    # The optimizer strategy proposed by
    # https://github.com/qq456cvb/Point-Transformers:
    opt = torch.optim.Adam(
        net.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.3)
else:
    raise ValueError('Unsupported value for --opt: %r' % args.opt)
# 1024 points per cloud, as in the original PointNet++ setup.
train_dataset = ModelNetDataLoader(local_path, 1024, split='train')
test_dataset = ModelNetDataLoader(local_path, 1024, split='test')
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)
best_test_acc = 0
for epoch in range(args.num_epochs):
    print("Epoch #{}: ".format(epoch))
    train(net, opt, scheduler, train_loader, dev)
    # Evaluate after every epoch (the original `(epoch + 1) % 1 == 0`
    # guard was always true and has been dropped).
    test_acc = evaluate(net, test_loader, dev)
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        if args.save_model_path:
            torch.save(net.state_dict(), args.save_model_path)
    print('Current test acc: %.5f (best: %.5f)' % (
        test_acc, best_test_acc))
    print()
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import dgl
from functools import partial
import tqdm
import argparse
import time
from ShapeNet import ShapeNet
from point_transformer import PointTransformerSeg, PartSegLoss
# Command-line options for the part-segmentation training script.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=250)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=16)
# When set, logs meshes/scalars to TensorBoard (see the writer below).
parser.add_argument('--tensorboard', action='store_true')
parser.add_argument('--opt', type=str, default='adam')
args = parser.parse_args()
num_workers = args.num_workers
batch_size = args.batch_size
def collate(samples):
    """Collate (graph, category) pairs into a batched DGLGraph plus the
    list of category names, for use as a DataLoader collate_fn."""
    graphs = [g for g, _ in samples]
    categories = [c for _, c in samples]
    return dgl.batch(graphs), categories
# DataLoader preset shared by the train/test loaders below; drop_last
# keeps every batch at exactly `batch_size` shapes.
CustomDataLoader = partial(
    DataLoader,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True)
def train(net, opt, scheduler, train_loader, dev):
    """Run one part-segmentation training epoch.

    Relies on the module-level globals ``shapenet`` (for the category
    list) and ``L`` (the loss function). Returns the last batch's data
    and predictions together with the epoch's average loss, accuracy and
    wall-clock time, so the caller can log them.
    """
    category_list = sorted(list(shapenet.seg_classes.keys()))
    one_hot = np.eye(16)
    net.train()
    loss_sum = 0
    batch_count = 0
    correct_sum = 0
    point_count = 0
    tic = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as progress:
        for data, label, cat in progress:
            num_examples = data.shape[0]
            data = data.to(dev, dtype=torch.float)
            label = label.to(dev, dtype=torch.long).view(-1)
            opt.zero_grad()
            # One-hot encode the object category and tile it per point.
            cat_ind = [category_list.index(c) for c in cat]
            cat_tensor = torch.tensor(one_hot[cat_ind]).to(
                dev, dtype=torch.float).repeat(1, 2048)
            cat_tensor = cat_tensor.view(num_examples, -1, 16)
            logits = net(data, cat_tensor).permute(0, 2, 1)
            loss = L(logits, label)
            loss.backward()
            opt.step()
            preds = logits.max(1)[1]
            point_count += num_examples * 2048  # 2048 points per shape
            loss_sum += loss.item()
            batch_count += 1
            correct_sum += (preds.view(-1) == label).sum().item()
            AvgLoss = loss_sum / batch_count
            AvgAcc = correct_sum / point_count
            progress.set_postfix({
                'AvgLoss': '%.5f' % AvgLoss,
                'AvgAcc': '%.5f' % AvgAcc})
    scheduler.step()
    toc = time.time()
    print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
        loss_sum / batch_count, correct_sum / point_count, toc - tic))
    return data, preds, AvgLoss, AvgAcc, toc - tic
def mIoU(preds, label, cat, cat_miou, seg_classes):
    """Accumulate per-shape mean IoU into ``cat_miou``.

    For each shape i, computes the IoU of every part class belonging to
    its category ``cat[i]`` and averages them; the result is added to
    ``cat_miou[cat[i]]`` as ``[iou_sum, shape_count]``.

    Parameters
    ----------
    preds, label : np.ndarray
        Integer part predictions / ground truth of shape (B, N).
    cat : sequence of str
        Category name of each shape in the batch.
    cat_miou : dict
        Running ``{category: [iou_sum, shape_count]}`` accumulator.
    seg_classes : dict
        ``{category: [part class ids]}`` mapping.

    Returns
    -------
    dict
        The updated ``cat_miou`` accumulator.
    """
    for i in range(preds.shape[0]):
        parts = seg_classes[cat[i]]
        iou_sum = 0
        for part in parts:
            pred_set = set(np.where(preds[i, :] == part)[0])
            label_set = set(np.where(label[i, :] == part)[0])
            union = len(pred_set | label_set)
            if union == 0:
                # Part absent from both prediction and ground truth
                # counts as a perfect match.
                iou_sum += 1
            else:
                iou_sum += len(pred_set & label_set) / union
        cat_miou[cat[i]][0] += iou_sum / len(parts)
        cat_miou[cat[i]][1] += 1
    return cat_miou
def evaluate(net, test_loader, dev, per_cat_verbose=False):
    """Evaluate part-segmentation mIoU on the test set.

    Uses the module-level global ``shapenet`` for the category list and
    per-category part classes. Returns ``(mIoU, per-category mIoU)``.

    Parameters
    ----------
    net : nn.Module
        The segmentation network.
    test_loader : DataLoader
        Yields (data, label, category) batches.
    dev : torch.device
        Device to run on.
    per_cat_verbose : bool
        When True, prints a per-category mIoU breakdown.
    """
    category_list = sorted(list(shapenet.seg_classes.keys()))
    eye_mat = np.eye(16)
    net.eval()
    cat_miou = {}
    for k in shapenet.seg_classes.keys():
        cat_miou[k] = [0, 0]
    miou = 0
    count = 0
    per_cat_miou = 0
    per_cat_count = 0
    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as tq:
            for data, label, cat in tq:
                num_examples = data.shape[0]
                data = data.to(dev, dtype=torch.float)
                label = label.to(dev, dtype=torch.long)
                cat_ind = [category_list.index(c) for c in cat]
                # One-hot category encoding, tiled over the 2048 points.
                cat_tensor = torch.tensor(eye_mat[cat_ind]).to(
                    dev, dtype=torch.float).repeat(1, 2048)
                cat_tensor = cat_tensor.view(
                    num_examples, -1, 16)
                logits = net(data, cat_tensor).permute(0, 2, 1)
                _, preds = logits.max(1)
                cat_miou = mIoU(preds.cpu().numpy(),
                                label.view(num_examples, -1).cpu().numpy(),
                                cat, cat_miou, shapenet.seg_classes)
                # BUG FIX: recompute the running statistics from scratch
                # each batch. Previously these totals were never reset, so
                # the cumulative sums in cat_miou were re-added on every
                # batch, skewing both the progress-bar numbers and the
                # returned mIoU values.
                miou = 0
                count = 0
                per_cat_miou = 0
                per_cat_count = 0
                for _, v in cat_miou.items():
                    if v[1] > 0:
                        miou += v[0]
                        count += v[1]
                        per_cat_miou += v[0] / v[1]
                        per_cat_count += 1
                tq.set_postfix({
                    'mIoU': '%.5f' % (miou / count),
                    'per Category mIoU': '%.5f' % (per_cat_miou / per_cat_count)})
    print("[Test] mIoU: %.5f, per Category mIoU: %.5f" %
          (miou / count, per_cat_miou / per_cat_count))
    if per_cat_verbose:
        print("-" * 60)
        print("Per-Category mIoU:")
        for k, v in cat_miou.items():
            if v[1] > 0:
                print("%s mIoU=%.5f" % (k, v[0] / v[1]))
            else:
                print("%s mIoU=%.5f" % (k, 1))
        print("-" * 60)
    return miou / count, per_cat_miou / per_cat_count
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 50 part classes across the 16 ShapeNet categories.
net = PointTransformerSeg(50, batch_size)
net = net.to(dev)
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))
# BUG FIX: the option labels were swapped -- '--opt adam' built an SGD
# optimizer and '--opt sgd' built Adam. Each label now selects the
# optimizer it names, and an unknown value fails fast instead of raising
# a NameError later.
if args.opt == 'sgd':
    # The optimizer strategy described in paper:
    opt = torch.optim.SGD(net.parameters(), lr=0.01,
                          momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        opt, milestones=[120, 160], gamma=0.1)
elif args.opt == 'adam':
    # The optimizer strategy proposed by
    # https://github.com/qq456cvb/Point-Transformers:
    opt = torch.optim.Adam(
        net.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.3)
else:
    raise ValueError('Unsupported value for --opt: %r' % args.opt)
L = PartSegLoss()
# 2048 points per shape; normals are not used as input features here.
shapenet = ShapeNet(2048, normal_channel=False)
train_loader = CustomDataLoader(shapenet.trainval())
test_loader = CustomDataLoader(shapenet.test())
# Tensorboard
# Optional TensorBoard logging; only imported/created when requested so
# the script runs without the tensorboard/torchvision dependencies.
if args.tensorboard:
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import datasets, transforms
    writer = SummaryWriter()
# Select 50 distinct colors for different parts
# RGB triples indexed by part-class id (0..49); used by paint() below to
# visualize predictions as colored meshes.
color_map = torch.tensor([
    [47, 79, 79], [139, 69, 19], [112, 128, 144], [85, 107, 47], [139, 0, 0], [
        128, 128, 0], [72, 61, 139], [0, 128, 0], [188, 143, 143], [60, 179, 113],
    [205, 133, 63], [0, 139, 139], [70, 130, 180], [205, 92, 92], [154, 205, 50], [
        0, 0, 139], [50, 205, 50], [250, 250, 250], [218, 165, 32], [139, 0, 139],
    [10, 10, 10], [176, 48, 96], [72, 209, 204], [153, 50, 204], [255, 69, 0], [
        255, 145, 0], [0, 0, 205], [255, 255, 0], [0, 255, 0], [233, 150, 122],
    [220, 20, 60], [0, 191, 255], [160, 32, 240], [192, 192, 192], [173, 255, 47], [
        218, 112, 214], [216, 191, 216], [255, 127, 80], [255, 0, 255], [100, 149, 237],
    [128, 128, 128], [221, 160, 221], [144, 238, 144], [123, 104, 238], [255, 160, 122], [
        175, 238, 238], [238, 130, 238], [127, 255, 212], [255, 218, 185], [255, 105, 180],
])
# paint each point according to its pred
def paint(batched_points):
    """Map per-point part predictions to RGB colors.

    Parameters
    ----------
    batched_points : torch.Tensor
        Integer part predictions of shape (B, N).

    Returns
    -------
    torch.Tensor
        Colors of shape (B, N, 3), looked up in the global ``color_map``.
    """
    B, N = batched_points.shape  # also validates the input is 2-D
    # Advanced indexing already yields (B, N, 3); the former .squeeze(2)
    # was a no-op on the size-3 color axis and has been removed.
    return color_map[batched_points]
# Main training loop: evaluate every 5 epochs (plus epoch 0) and keep
# the checkpoint with the best test mIoU.
best_test_miou = 0
best_test_per_cat_miou = 0
for epoch in range(args.num_epochs):
    print("Epoch #{}: ".format(epoch))
    data, preds, AvgLoss, AvgAcc, training_time = train(
        net, opt, scheduler, train_loader, dev)
    if (epoch + 1) % 5 == 0 or epoch == 0:
        test_miou, test_per_cat_miou = evaluate(
            net, test_loader, dev, True)
        if test_miou > best_test_miou:
            best_test_miou = test_miou
            best_test_per_cat_miou = test_per_cat_miou
            if args.save_model_path:
                torch.save(net.state_dict(), args.save_model_path)
        print('Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)' % (
            test_miou, best_test_miou, test_per_cat_miou, best_test_per_cat_miou))
    # Tensorboard
    if args.tensorboard:
        # Log the last training batch colored by its predictions.
        colored = paint(preds)
        writer.add_mesh('data', vertices=data,
                        colors=colored, global_step=epoch)
        writer.add_scalar('training time for one epoch',
                          training_time, global_step=epoch)
        writer.add_scalar('AvgLoss', AvgLoss, global_step=epoch)
        writer.add_scalar('AvgAcc', AvgAcc, global_step=epoch)
        # NOTE(review): the eval above also runs at epoch 0, but that
        # result is not logged here because this guard omits epoch == 0.
        if (epoch + 1) % 5 == 0:
            writer.add_scalar('test mIoU', test_miou, global_step=epoch)
            writer.add_scalar('best test mIoU',
                              best_test_miou, global_step=epoch)
    print()
......@@ -5,7 +5,7 @@ from torch.autograd import Variable
import numpy as np
import dgl
import dgl.function as fn
from dgl.geometry.pytorch import farthest_point_sampler
from dgl.geometry import farthest_point_sampler # dgl.geometry.pytorch -> dgl.geometry
'''
Part of the code are adapted from
......
......@@ -247,6 +247,6 @@ def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
if len(drop_idx)>0:
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 # not need
batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
return batch_pc
from pointnet2 import PointNet2SSGCls, PointNet2MSGCls
from pointnet_cls import PointNetCls
from ModelNetDataLoader import ModelNetDataLoader
import provider
import argparse
import os
import urllib
import tqdm
from functools import partial
from dgl.data.utils import download, get_download_dir
import dgl
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch
torch.backends.cudnn.enabled = False
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import dgl
from dgl.data.utils import download, get_download_dir
from functools import partial
import tqdm
import urllib
import os
import argparse
# from dataset import ModelNet
import provider
from ModelNetDataLoader import ModelNetDataLoader
from pointnet_cls import PointNetCls
from pointnet2 import PointNet2SSGCls, PointNet2MSGCls
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='pointnet')
......@@ -34,7 +34,8 @@ batch_size = args.batch_size
data_filename = 'modelnet40_normal_resampled.zip'
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(get_download_dir(), 'modelnet40_normal_resampled')
local_path = args.dataset_path or os.path.join(
get_download_dir(), 'modelnet40_normal_resampled')
if not os.path.exists(local_path):
download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip',
......@@ -50,6 +51,7 @@ CustomDataLoader = partial(
shuffle=True,
drop_last=True)
def train(net, opt, scheduler, train_loader, dev):
net.train()
......@@ -63,7 +65,8 @@ def train(net, opt, scheduler, train_loader, dev):
for data, label in tq:
data = data.data.numpy()
data = provider.random_point_dropout(data)
data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.random_scale_point_cloud(
data[:, :, 0:3])
data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
data = torch.tensor(data)
......@@ -91,6 +94,7 @@ def train(net, opt, scheduler, train_loader, dev):
'AvgAcc': '%.5f' % (total_correct / count)})
scheduler.step()
def evaluate(net, test_loader, dev):
net.eval()
......@@ -100,7 +104,7 @@ def evaluate(net, test_loader, dev):
with torch.no_grad():
with tqdm.tqdm(test_loader, ascii=True) as tq:
for data, label in tq:
label = label[:,0]
label = label[:, 0]
num_examples = label.shape[0]
data, label = data.to(dev), label.to(dev).squeeze().long()
logits = net(data)
......@@ -115,6 +119,7 @@ def evaluate(net, test_loader, dev):
return total_correct / count
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.model == 'pointnet':
......@@ -134,8 +139,10 @@ scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)
train_dataset = ModelNetDataLoader(local_path, 1024, split='train')
test_dataset = ModelNetDataLoader(local_path, 1024, split='test')
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)
best_test_acc = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment