Unverified Commit 3fef5d27 authored by esang's avatar esang Committed by GitHub
Browse files

[Model] PCT (#3339)



* publish pct

* add train_cls

* add readme

* update opt for point transformer

* update the example index

* update for comments
Co-authored-by: Tong He <hetong007@gmail.com>
parent e7ea0f53
...@@ -20,7 +20,9 @@ To quickly locate the examples of your interest, search for the tagged keywords ...@@ -20,7 +20,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
- <a name="point_transformer"></a> Zhao et al. Point Transformer. [Paper link](http://arxiv.org/abs/2012.09164). - <a name="point_transformer"></a> Zhao et al. Point Transformer. [Paper link](http://arxiv.org/abs/2012.09164).
- Example code: [PyTorch](../examples/pytorch/pointcloud/point_transformer) - Example code: [PyTorch](../examples/pytorch/pointcloud/point_transformer)
- Tags: point cloud classification, point cloud part-segmentation - Tags: point cloud classification, point cloud part-segmentation
- <a name="pct"></a> Guo et al. PCT: Point cloud transformer. [Paper link](http://arxiv.org/abs/2012.09688).
- Example code: [PyTorch](../examples/pytorch/pointcloud/pct)
- Tags: point cloud classification, point cloud part-segmentation
## 2020 ## 2020
- <a name="eeg-gcnn"></a> Wagh et al. EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. [Paper link](http://proceedings.mlr.press/v136/wagh20a.html). - <a name="eeg-gcnn"></a> Wagh et al. EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. [Paper link](http://proceedings.mlr.press/v136/wagh20a.html).
......
import numpy as np
import warnings
import os
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
def pc_normalize(pc):
    """Center a point cloud at the origin and scale it into the unit sphere.

    pc: (N, 3) array. Returns the normalized (N, 3) array.
    """
    centered = pc - np.mean(pc, axis=0)
    # Radius of the farthest point from the centroid.
    radius = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    return centered / radius
def farthest_point_sample(point, npoint):
    """
    Sub-sample a point cloud with farthest point sampling (FPS).

    Works as follows:
        1. Initialize the sample set S with a random point
        2. Pick point P not in S, which maximizes the distance d(P, S)
        3. Repeat step 2 until |S| = npoint

    Input:
        point: pointcloud data, [N, D]; only the first 3 columns (xyz)
               are used for the distance computation
        npoint: number of samples
    Return:
        point: the sampled points, [npoint, D]
        (Note: earlier docstring said "index" — the function returns the
        gathered points, not the indices.)
    """
    N = point.shape[0]
    xyz = point[:, :3]
    # Integer index buffer; avoids the float array + cast round-trip.
    centroids = np.zeros((npoint,), dtype=np.int64)
    # distance[j] = squared distance from point j to the current sample set.
    distance = np.full((N,), 1e10)
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        # Next sample: the point farthest from everything chosen so far.
        farthest = np.argmax(distance, -1)
    return point[centroids]
class ModelNetDataLoader(Dataset):
    def __init__(self, root, npoint=1024, split='train', fps=False,
                 normal_channel=True, cache_size=15000):
        """
        Input:
            root: the root path to the local data files
            npoint: number of points from each cloud
            split: which split of the data, 'train' or 'test'
            fps: whether to sample points with farthest point sampler
            normal_channel: whether to use additional channel
            cache_size: the cache size of in-memory point clouds
        """
        self.root = root
        self.npoints = npoint
        self.fps = fps
        self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
        # Use context managers: the original code opened these three files
        # without ever closing them, leaking the handles.
        with open(self.catfile) as f:
            self.cat = [line.rstrip() for line in f]
        self.classes = dict(zip(self.cat, range(len(self.cat))))
        self.normal_channel = normal_channel

        shape_ids = {}
        with open(os.path.join(self.root, 'modelnet40_train.txt')) as f:
            shape_ids['train'] = [line.rstrip() for line in f]
        with open(os.path.join(self.root, 'modelnet40_test.txt')) as f:
            shape_ids['test'] = [line.rstrip() for line in f]

        assert (split == 'train' or split == 'test')
        # File names look like "<shape_name>_<index>"; strip the trailing index.
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
        # list of (shape_name, shape_txt_file_path) tuple
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                         in range(len(shape_ids[split]))]
        print('The size of %s data is %d' % (split, len(self.datapath)))

        self.cache_size = cache_size
        self.cache = {}  # index -> (point_set, cls); bounded by cache_size

    def __len__(self):
        return len(self.datapath)

    def _get_item(self, index):
        """Load (and cache) one point cloud plus its integer class label."""
        if index in self.cache:
            point_set, cls = self.cache[index]
        else:
            fn = self.datapath[index]
            cls = self.classes[self.datapath[index][0]]
            cls = np.array([cls]).astype(np.int32)
            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
            if self.fps:
                point_set = farthest_point_sample(point_set, self.npoints)
            else:
                # Take the first npoints rows (files are pre-resampled).
                point_set = point_set[0:self.npoints, :]

            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
            if not self.normal_channel:
                point_set = point_set[:, 0:3]

            if len(self.cache) < self.cache_size:
                self.cache[index] = (point_set, cls)
        return point_set, cls

    def __getitem__(self, index):
        return self._get_item(index)
PCT
====
This is a reproduction of the paper: [PCT: Point cloud transformer](http://arxiv.org/abs/2012.09688).
# Performance
| Task | Dataset | Metric | Score - Paper | Score - DGL (Adam) | Time(s) - DGL |
|-----------------|------------|----------|------------------|-------------|-------------------|
| Classification | ModelNet40 | Accuracy | 93.2 | 92.1 | 740.0 |
| Part Segmentation | ShapeNet | mIoU | 86.4 | 85.6 | 390.0 |
+ Time(s) are the average training time per epoch, measured on EC2 g4dn.12xlarge instance w/ Tesla T4 GPU.
+ We run the code with the preprocessing used in [PointNet++](../pointnet). We can only get 84.5 for classification if we use the preprocessing described in the paper:
> During training, a random translation in [−0.2, 0.2], a random anisotropic scaling in [0.67, 1.5] and a random input dropout were applied to augment the input data.
# How to Run
For point cloud classification, run with
```python
python train_cls.py
```
For point cloud part-segmentation, run with
```python
python train_partseg.py
```
import os, json, tqdm
import numpy as np
import dgl
from zipfile import ZipFile
from torch.utils.data import Dataset
from scipy.sparse import csr_matrix
from dgl.data.utils import download, get_download_dir
class ShapeNet(object):
    """Downloads and indexes the ShapeNet part-segmentation benchmark.

    Holds the train/val/test file lists plus category metadata; call
    train()/valid()/trainval()/test() to obtain Dataset objects.
    """

    def __init__(self, num_points=2048, normal_channel=True):
        # num_points: points sampled per shape by the datasets built from this.
        # normal_channel: keep 3 normal channels in addition to xyz.
        self.num_points = num_points
        self.normal_channel = normal_channel
        SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
        download_path = get_download_dir()
        data_filename = "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
        data_path = os.path.join(download_path, "shapenetcore_partanno_segmentation_benchmark_v0_normal")
        # Download and extract the archive only if not already present on disk.
        if not os.path.exists(data_path):
            local_path = os.path.join(download_path, data_filename)
            if not os.path.exists(local_path):
                download(SHAPENET_DOWNLOAD_URL, local_path, verify_ssl=False)
            with ZipFile(local_path) as z:
                z.extractall(path=download_path)
        # Map synset offset (directory name) -> human-readable category name.
        synset_file = "synsetoffset2category.txt"
        with open(os.path.join(data_path, synset_file)) as f:
            synset = [t.split('\n')[0].split('\t') for t in f.readlines()]
        self.synset_dict = {}
        for syn in synset:
            self.synset_dict[syn[1]] = syn[0]
        # Global part-label ids owned by each object category (50 parts total).
        self.seg_classes = {'Airplane': [0, 1, 2, 3],
                            'Bag': [4, 5],
                            'Cap': [6, 7],
                            'Car': [8, 9, 10, 11],
                            'Chair': [12, 13, 14, 15],
                            'Earphone': [16, 17, 18],
                            'Guitar': [19, 20, 21],
                            'Knife': [22, 23],
                            'Lamp': [24, 25, 26, 27],
                            'Laptop': [28, 29],
                            'Motorbike': [30, 31, 32, 33, 34, 35],
                            'Mug': [36, 37],
                            'Pistol': [38, 39, 40],
                            'Rocket': [41, 42, 43],
                            'Skateboard': [44, 45, 46],
                            'Table': [47, 48, 49]}
        # Pre-shuffled split definitions shipped with the dataset; entries
        # look like "shape_data/<synset>/<id>", which we turn into local paths.
        train_split_json = 'shuffled_train_file_list.json'
        val_split_json = 'shuffled_val_file_list.json'
        test_split_json = 'shuffled_test_file_list.json'
        split_path = os.path.join(data_path, 'train_test_split')
        with open(os.path.join(split_path, train_split_json)) as f:
            tmp = f.read()
            self.train_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
        with open(os.path.join(split_path, val_split_json)) as f:
            tmp = f.read()
            self.val_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
        with open(os.path.join(split_path, test_split_json)) as f:
            tmp = f.read()
            self.test_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]

    def train(self):
        """Dataset over the train split."""
        return ShapeNetDataset(self, 'train', self.num_points, self.normal_channel)

    def valid(self):
        """Dataset over the validation split."""
        return ShapeNetDataset(self, 'valid', self.num_points, self.normal_channel)

    def trainval(self):
        """Dataset over train + validation combined."""
        return ShapeNetDataset(self, 'trainval', self.num_points, self.normal_channel)

    def test(self):
        """Dataset over the test split."""
        return ShapeNetDataset(self, 'test', self.num_points, self.normal_channel)
class ShapeNetDataset(Dataset):
    """One split of the ShapeNet part-segmentation dataset.

    Each item is (points (num_points, dim), per-point part labels
    (num_points,), category name). Points are re-sampled with replacement to
    `num_points` on every access; the train split also gets a random
    anisotropic scale + shift augmentation.
    """

    def __init__(self, shapenet, mode, num_points, normal_channel=True):
        """
        Input:
            shapenet: a ShapeNet object holding the file lists
            mode: one of 'train', 'valid', 'test', 'trainval'
            num_points: number of points sampled per shape
            normal_channel: keep the 3 normal channels in addition to xyz
        """
        super(ShapeNetDataset, self).__init__()
        self.mode = mode
        self.num_points = num_points
        if not normal_channel:
            self.dim = 3
        else:
            self.dim = 6
        if mode == 'train':
            self.file_list = shapenet.train_file_list
        elif mode == 'valid':
            self.file_list = shapenet.val_file_list
        elif mode == 'test':
            self.file_list = shapenet.test_file_list
        elif mode == 'trainval':
            self.file_list = shapenet.train_file_list + shapenet.val_file_list
        else:
            # Raising a plain string is a TypeError in Python 3 -- raise a
            # real exception with a useful message instead.
            raise ValueError("Not supported `mode`: %r" % (mode,))

        data_list = []
        label_list = []
        category_list = []
        print('Loading data from split ' + self.mode)
        for fn in tqdm.tqdm(self.file_list, ascii=True):
            with open(fn) as f:
                # Columns: x y z nx ny nz part_label.
                # np.float / np.int were removed in NumPy 1.24; use explicit
                # dtypes (np.float was an alias for Python float == float64).
                data = np.array([t.split('\n')[0].split(' ')
                                 for t in f.readlines()]).astype(np.float64)
            data_list.append(data[:, 0:self.dim])
            label_list.append(data[:, 6].astype(np.int64))
            category_list.append(shapenet.synset_dict[fn.split('/')[-2]])
        self.data = data_list
        self.label = label_list
        self.category = category_list

    def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2), size=3):
        """Random per-axis scale + shift augmentation; returns float32."""
        xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])
        xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])
        x = np.add(np.multiply(x, xyz1), xyz2).astype('float32')
        return x

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        # Sample with replacement so clouds smaller than num_points still work.
        inds = np.random.choice(self.data[i].shape[0], self.num_points, replace=True)
        x = self.data[i][inds, :self.dim]
        y = self.label[i][inds]
        cat = self.category[i]
        if self.mode == 'train':
            x = self.translate(x, size=self.dim)
        x = x.astype(np.float64)
        y = y.astype(np.int64)
        return x, y, cat
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.geometry import farthest_point_sampler
'''
Part of the code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
def square_distance(src, dst):
    '''
    Pairwise squared Euclidean distance between two point sets.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    src: (B, N, C), dst: (B, M, C) -> (B, N, M)
    '''
    B, N, _ = src.shape
    M = dst.shape[1]
    # |s - d|^2 = |s|^2 - 2 s.d + |d|^2, computed batch-wise.
    cross = torch.matmul(src, dst.permute(0, 2, 1))
    sq_src = torch.sum(src ** 2, -1).view(B, N, 1)
    sq_dst = torch.sum(dst ** 2, -1).view(B, 1, M)
    return sq_src - 2 * cross + sq_dst
def index_points(points, idx):
    '''
    Gather points[b, idx[b, ...], :] for every batch element.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    points: (B, N, C); idx: (B, ...) long indices -> (B, ..., C)
    '''
    B = points.shape[0]
    # Batch-index tensor shaped (B, 1, ..., 1) then broadcast to idx's shape.
    view_shape = [B] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device)
    batch_indices = batch_indices.view(view_shape).expand(list(idx.shape))
    return points[batch_indices, idx, :]
class KNearNeighbors(nn.Module):
    '''
    Find the k nearest neighbors of a set of centroid points.
    '''

    def __init__(self, n_neighbor):
        super(KNearNeighbors, self).__init__()
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        '''
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

        pos: (B, N, 3) point positions; centroids: (B, S) sampled indices.
        Returns (B, S, k) indices of each centroid's k nearest points.
        '''
        center_pos = index_points(pos, centroids)
        dists = square_distance(center_pos, pos)
        # Sort ascending by distance and keep the closest k columns.
        return dists.argsort(dim=-1)[:, :, :self.n_neighbor]
class KNNGraphBuilder(nn.Module):
    '''
    Build NN graph: one DGL graph per batch element with edges from each
    centroid's k nearest neighbors (src) to the centroid itself (dst),
    batched into a single graph.
    '''
    def __init__(self, n_neighbor):
        super(KNNGraphBuilder, self).__init__()
        self.n_neighbor = n_neighbor
        self.knn = KNearNeighbors(n_neighbor)
    def forward(self, pos, centroids, feat=None):
        # pos: (B, N, 3); centroids: (B, S) indices; feat: optional (B, N, F)
        # per-point features copied onto the graph nodes.
        dev = pos.device
        group_idx = self.knn(pos, centroids)
        B, N, _ = pos.shape
        glist = []
        for i in range(B):
            # 1.0 marks centroid nodes so they can be recovered downstream.
            center = torch.zeros((N)).to(dev)
            center[centroids[i]] = 1
            src = group_idx[i].contiguous().view(-1)
            # Repeat each centroid once per neighbor so src/dst line up.
            dst = centroids[i].view(-1, 1).repeat(1, min(self.n_neighbor,
                src.shape[0] // centroids.shape[1])).view(-1)
            # Compact the touched node ids into a contiguous [0, num_unique)
            # range so the per-sample graph only contains referenced nodes.
            unified = torch.cat([src, dst])
            uniq, inv_idx = torch.unique(unified, return_inverse=True)
            src_idx = inv_idx[:src.shape[0]]
            dst_idx = inv_idx[src.shape[0]:]
            g = dgl.graph((src_idx, dst_idx))
            g.ndata['pos'] = pos[i][uniq]
            g.ndata['center'] = center[uniq]
            if feat is not None:
                g.ndata['feat'] = feat[i][uniq]
            glist.append(g)
        bg = dgl.batch(glist)
        return bg
class KNNMessage(nn.Module):
    '''
    Compute the input feature from neighbors.

    DGL message function: emits, per edge, the neighbor-minus-center feature
    difference concatenated with the neighbor feature.
    '''
    def __init__(self, n_neighbor):
        super(KNNMessage, self).__init__()
        self.n_neighbor = n_neighbor
    def forward(self, edges):
        # Edge feature: difference between neighbor (src) and center (dst).
        norm = edges.src['feat'] - edges.dst['feat']
        # NOTE(review): this guard is dead code -- 'feat' was already read on
        # the line above, so either it is present here or a KeyError has
        # already been raised. The `else` branch is unreachable.
        if 'feat' in edges.src:
            res = torch.cat([norm, edges.src['feat']], 1)
        else:
            res = norm
        return {'agg_feat': res}
class KNNConv(nn.Module):
    '''
    Feature aggregation: a shared pointwise MLP (1x1 Conv2d + BN + ReLU per
    stage) over every neighbor message, followed by max-pooling over the
    neighbors. Used as a DGL reduce function.
    '''
    def __init__(self, sizes):
        # sizes: channel widths, e.g. [in, hidden, out]; one conv per step.
        super(KNNConv, self).__init__()
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for i in range(1, len(sizes)):
            self.conv.append(nn.Conv2d(sizes[i-1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))
    def forward(self, nodes):
        shape = nodes.mailbox['agg_feat'].shape
        # Reorder the mailbox so channels land on Conv2d's channel axis and
        # neighbors on a spatial axis (max-pooled below).
        h = nodes.mailbox['agg_feat'].view(
            shape[0], -1, shape[1], shape[2]).permute(0, 3, 2, 1)
        for conv, bn in zip(self.conv, self.bn):
            h = conv(h)
            h = bn(h)
            h = F.relu(h)
        # Max over the neighbor dimension.
        h = torch.max(h, 2)[0]
        feat_dim = h.shape[1]
        # Flatten back to one feature row per node.
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {'new_feat': h}
class TransitionDown(nn.Module):
    """
    The Transition Down Module: sub-sample the cloud with farthest point
    sampling, then aggregate k-NN neighborhood features onto the samples.
    """
    def __init__(self, in_channels, out_channels, n_neighbor=64):
        super(TransitionDown, self).__init__()
        self.frnn_graph = KNNGraphBuilder(n_neighbor)
        self.message = KNNMessage(n_neighbor)
        self.conv = KNNConv([in_channels, out_channels, out_channels])
    def forward(self, pos, feat, n_point):
        # pos: (B, N, 3); feat: (B, N, F); n_point: number of FPS samples.
        # Returns (sampled positions, aggregated features) for the samples.
        batch_size = pos.shape[0]
        centroids = farthest_point_sampler(pos, n_point)
        g = self.frnn_graph(pos, centroids, feat)
        # Message passing: edge differences -> shared MLP -> neighbor max.
        g.update_all(self.message, self.conv)
        # Keep only centroid nodes (marked center == 1 by the graph builder).
        mask = g.ndata['center'] == 1
        pos_dim = g.ndata['pos'].shape[-1]
        feat_dim = g.ndata['new_feat'].shape[-1]
        pos_res = g.ndata['pos'][mask].view(batch_size, -1, pos_dim)
        feat_res = g.ndata['new_feat'][mask].view(
            batch_size, -1, feat_dim)
        return pos_res, feat_res
import torch
from torch import nn
from helper import TransitionDown
'''
Part of the code are adapted from
https://github.com/MenghaoGuo/PCT
'''
class PCTPositionEmbedding(nn.Module):
    """Four stacked offset self-attention layers with a learned position
    embedding projected from the raw xyz coordinates."""

    def __init__(self, channels=256):
        super(PCTPositionEmbedding, self).__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.conv_pos = nn.Conv1d(3, channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(channels)
        self.sa1 = SALayerCLS(channels)
        self.sa2 = SALayerCLS(channels)
        self.sa3 = SALayerCLS(channels)
        self.sa4 = SALayerCLS(channels)
        self.relu = nn.ReLU()

    def forward(self, x, xyz):
        """x: (B, C, N) features; xyz: (B, N, 3) coordinates.
        Returns the four attention outputs concatenated along channels."""
        # Project raw coordinates into feature space as a position embedding.
        pos_emb = self.conv_pos(xyz.permute(0, 2, 1))
        h = self.relu(self.bn1(self.conv1(x)))  # B, D, N
        outs = []
        for sa in (self.sa1, self.sa2, self.sa3, self.sa4):
            h = sa(h, pos_emb)
            outs.append(h)
        return torch.cat(outs, dim=1)
class SALayerCLS(nn.Module):
    """Offset self-attention layer (classification variant) from PCT.

    Query and key share one weight matrix (tied below); the residual update
    uses the *offset* x - attention(x) rather than attention(x) itself.
    """
    def __init__(self, channels):
        super(SALayerCLS, self).__init__()
        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        # Tie query and key projections to the same parameters.
        self.q_conv.weight = self.k_conv.weight
        self.v_conv = nn.Conv1d(channels, channels, 1)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, xyz):
        # x: (B, C, N) features; xyz: (B, C, N) position embedding added to x.
        x = x + xyz
        x_q = self.q_conv(x).permute(0, 2, 1)  # b, n, c
        x_k = self.k_conv(x)  # b, c, n
        x_v = self.v_conv(x)
        energy = torch.bmm(x_q, x_k)  # b, n, n
        attention = self.softmax(energy)
        # Double normalization: softmax over the last dim, then L1 over dim 1
        # (the PCT normalization trick; 1e-9 guards against division by zero).
        attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
        x_r = torch.bmm(x_v, attention)  # b, c, n
        # Offset attention: transform the residual (x - x_r), then add back.
        x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
        x = x + x_r
        return x
class SALayerSeg(nn.Module):
    """Offset self-attention layer (segmentation variant) from PCT.

    Same computation as SALayerCLS except no position embedding is added.
    """
    def __init__(self, channels):
        super(SALayerSeg, self).__init__()
        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        # Tie query and key projections to the same parameters.
        self.q_conv.weight = self.k_conv.weight
        self.v_conv = nn.Conv1d(channels, channels, 1)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x):
        # x: (B, C, N) -> (B, C, N).
        x_q = self.q_conv(x).permute(0, 2, 1)  # b, n, c
        x_k = self.k_conv(x)  # b, c, n
        x_v = self.v_conv(x)
        energy = torch.bmm(x_q, x_k)  # b, n, n
        attention = self.softmax(energy)
        # Double normalization (softmax, then L1 over dim 1), as in PCT.
        attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
        x_r = torch.bmm(x_v, attention)  # b, c, n
        # Offset attention: transform the residual (x - x_r), then add back.
        x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
        x = x + x_r
        return x
class PointTransformerCLS(nn.Module):
    """PCT classification network.

    Two TransitionDown stages (512 then 256 points) feed four stacked offset
    self-attention layers; global max pooling and an MLP head produce the
    per-class logits.
    """
    def __init__(self, output_channels=40):
        super(PointTransformerCLS, self).__init__()
        self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        # Down-sampling + k-NN neighborhood feature aggregation stages.
        self.g_op0 = TransitionDown(in_channels=128, out_channels=128, n_neighbor=32)
        self.g_op1 = TransitionDown(in_channels=256, out_channels=256, n_neighbor=32)
        self.pt_last = PCTPositionEmbedding()
        self.relu = nn.ReLU()
        # 1280 = 4 x 256 concatenated attention outputs + 256 skip features.
        self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
                                       nn.BatchNorm1d(1024),
                                       nn.LeakyReLU(negative_slope=0.2))
        # Classification head.
        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=0.5)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=0.5)
        self.linear3 = nn.Linear(256, output_channels)
    def forward(self, x):
        # x: (B, N, C) with xyz in the first 3 channels; returns (B, output_channels).
        # NOTE(review): features come from x[..., 3:] and conv1 expects 3
        # channels, so the input presumably carries xyz + normals -- confirm.
        xyz = x[..., :3]
        x = x[..., 3:].permute(0, 2, 1)
        batch_size, _, _ = x.size()
        x = self.relu(self.bn1(self.conv1(x)))  # B, D, N
        x = self.relu(self.bn2(self.conv2(x)))  # B, D, N
        x = x.permute(0, 2, 1)
        new_xyz, feature_0 = self.g_op0(xyz, x, n_point=512)
        new_xyz, feature_1 = self.g_op1(new_xyz, feature_0, n_point=256)
        # add position embedding on each layer
        x = self.pt_last(feature_1, new_xyz)
        # Skip connection: fuse attention outputs with the stage-2 features.
        x = torch.cat([x, feature_1], dim=1)
        x = self.conv_fuse(x)
        # Global max pooling over points, then the MLP classification head.
        x, _ = torch.max(x, 2)
        x = x.view(batch_size, -1)
        x = self.relu(self.bn6(self.linear1(x)))
        x = self.dp1(x)
        x = self.relu(self.bn7(self.linear2(x)))
        x = self.dp2(x)
        x = self.linear3(x)
        return x
class PointTransformerSeg(nn.Module):
    """PCT part-segmentation network.

    Four stacked offset self-attention layers over per-point features; each
    point's features are fused with global max/avg pooled descriptors and an
    embedding of the object-category label before the per-point head.
    """
    def __init__(self, part_num=50):
        super(PointTransformerSeg, self).__init__()
        self.part_num = part_num
        self.conv1 = nn.Conv1d(3, 128, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(128, 128, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(128)
        self.sa1 = SALayerSeg(128)
        self.sa2 = SALayerSeg(128)
        self.sa3 = SALayerSeg(128)
        self.sa4 = SALayerSeg(128)
        # 512 = 4 x 128 concatenated attention outputs.
        self.conv_fuse = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False),
                                       nn.BatchNorm1d(1024),
                                       nn.LeakyReLU(negative_slope=0.2))
        # Embed the 16-channel category label vector.
        self.label_conv = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
                                        nn.BatchNorm1d(64),
                                        nn.LeakyReLU(negative_slope=0.2))
        # 1024 * 3 + 64 = per-point + max-pooled + avg-pooled + label embedding.
        self.convs1 = nn.Conv1d(1024 * 3 + 64, 512, 1)
        self.dp1 = nn.Dropout(0.5)
        self.convs2 = nn.Conv1d(512, 256, 1)
        self.convs3 = nn.Conv1d(256, self.part_num, 1)
        self.bns1 = nn.BatchNorm1d(512)
        self.bns2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()
    def forward(self, x, cls_label):
        # x: (B, N, 3) points; cls_label: (B, 16, 1) category vector
        # (length-1 spatial dim implied by the repeat below).
        # Returns (B, part_num, N) per-point part logits.
        x = x.permute(0, 2, 1)
        batch_size, _, N = x.size()
        x = self.relu(self.bn1(self.conv1(x)))  # B, D, N
        x = self.relu(self.bn2(self.conv2(x)))
        x1 = self.sa1(x)
        x2 = self.sa2(x1)
        x3 = self.sa3(x2)
        x4 = self.sa4(x3)
        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = self.conv_fuse(x)
        # Global descriptors broadcast back to every point.
        x_max, _ = torch.max(x, 2)
        x_avg = torch.mean(x, 2)
        x_max_feature = x_max.view(
            batch_size, -1).unsqueeze(-1).repeat(1, 1, N)
        x_avg_feature = x_avg.view(
            batch_size, -1).unsqueeze(-1).repeat(1, 1, N)
        cls_label_feature = self.label_conv(cls_label).repeat(1, 1, N)
        x_global_feature = torch.cat(
            (x_max_feature, x_avg_feature, cls_label_feature), 1)
        x = torch.cat((x, x_global_feature), 1)
        x = self.relu(self.bns1(self.convs1(x)))
        x = self.dp1(x)
        x = self.relu(self.bns2(self.convs2(x)))
        x = self.convs3(x)
        return x
class PartSegLoss(nn.Module):
    """Cross-entropy loss over per-point part logits.

    NOTE(review): `eps` looks like a label-smoothing factor but is stored and
    never used -- confirm whether smoothing was intended.
    """

    def __init__(self, eps=0.2):
        super(PartSegLoss, self).__init__()
        self.eps = eps
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits, y):
        """logits: (B, C, N); y: flattened (B*N,) target labels."""
        n_cls = logits.shape[1]
        # (B, C, N) -> (B*N, C): each point becomes one classification sample.
        flat = logits.permute(0, 2, 1).contiguous().view(-1, n_cls)
        return self.loss(flat, y)
'''
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
'''
import numpy as np
def normalize_data(batch_data):
    """ Normalize the batch data: center each cloud at the origin and scale
        it into the unit sphere.
        Input:
            BxNxC array
        Output:
            BxNxC array
    """
    B, N, C = batch_data.shape
    out = np.zeros((B, N, C))
    for idx in range(B):
        cloud = batch_data[idx] - np.mean(batch_data[idx], axis=0)
        radius = np.max(np.sqrt(np.sum(cloud ** 2, axis=1)))
        out[idx] = cloud / radius
    return out
def shuffle_data(data, labels):
    """ Shuffle data and labels with one shared permutation.
        Input:
          data: B,N,... numpy array
          label: B,... numpy array
        Return:
          shuffled data, label and shuffle indices
    """
    perm = np.random.permutation(len(labels))
    return data[perm, ...], labels[perm], perm
def shuffle_points(batch_data):
    """ Shuffle orders of points in each point cloud -- changes FPS behavior.
        The same shuffling index is used for the entire batch.
        Input:
            BxNxC array
        Output:
            BxNxC array
    """
    perm = np.random.permutation(batch_data.shape[1])
    return batch_data[:, perm, :]
def rotate_point_cloud(batch_data):
    """ Randomly rotate the point clouds to augment the dataset.
        Rotation is per shape, about the up (y) axis.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    out = np.zeros(batch_data.shape, dtype=np.float32)
    for idx, cloud in enumerate(batch_data):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        # Row-vector convention: points @ R.
        rot = np.array([[c, 0, s],
                        [0, 1, 0],
                        [-s, 0, c]])
        out[idx] = np.dot(cloud.reshape((-1, 3)), rot)
    return out
def rotate_point_cloud_z(batch_data):
    """ Randomly rotate the point clouds to augment the dataset.
        Rotation is per shape, about the z axis.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    out = np.zeros(batch_data.shape, dtype=np.float32)
    for idx, cloud in enumerate(batch_data):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        # Row-vector convention: points @ R.
        rot = np.array([[c, s, 0],
                        [-s, c, 0],
                        [0, 0, 1]])
        out[idx] = np.dot(cloud.reshape((-1, 3)), rot)
    return out
def rotate_point_cloud_with_normal(batch_xyz_normal):
    ''' Randomly rotate XYZ and normals about the up (y) axis, in place.
        Input:
            batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
        Output:
            B,N,6, rotated XYZ, normal point cloud (same array, mutated)
    '''
    for idx in range(batch_xyz_normal.shape[0]):
        theta = np.random.uniform() * 2 * np.pi
        c, s = np.cos(theta), np.sin(theta)
        rot = np.array([[c, 0, s],
                        [0, 1, 0],
                        [-s, 0, c]])
        xyz = batch_xyz_normal[idx, :, 0:3]
        nrm = batch_xyz_normal[idx, :, 3:6]
        # The same rotation is applied to coordinates and normals.
        batch_xyz_normal[idx, :, 0:3] = np.dot(xyz.reshape((-1, 3)), rot)
        batch_xyz_normal[idx, :, 3:6] = np.dot(nrm.reshape((-1, 3)), rot)
    return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """ Randomly perturb the point clouds (and their normals) by small rotations.
        Input:
          BxNx6 array, original batch of point clouds and point normals
        Return:
          BxNx6 array, rotated batch of point clouds
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for idx in range(batch_data.shape[0]):
        # Small Gaussian angles around each axis, clipped for stability.
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(ax), -np.sin(ax)],
                       [0, np.sin(ax), np.cos(ax)]])
        Ry = np.array([[np.cos(ay), 0, np.sin(ay)],
                       [0, 1, 0],
                       [-np.sin(ay), 0, np.cos(ay)]])
        Rz = np.array([[np.cos(az), -np.sin(az), 0],
                       [np.sin(az), np.cos(az), 0],
                       [0, 0, 1]])
        R = np.dot(Rz, np.dot(Ry, Rx))
        # Same rotation for coordinates and normals (row-vector convention).
        rotated[idx, :, 0:3] = np.dot(batch_data[idx, :, 0:3].reshape((-1, 3)), R)
        rotated[idx, :, 3:6] = np.dot(batch_data[idx, :, 3:6].reshape((-1, 3)), R)
    return rotated
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """ Rotate the point cloud about the up (y) axis by a fixed angle.
        Input:
          BxNx3 array, original batch of point clouds
          rotation_angle: scalar angle in radians
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    # The angle is the same for every shape, so the matrix is built once.
    c, s = np.cos(rotation_angle), np.sin(rotation_angle)
    rot = np.array([[c, 0, s],
                    [0, 1, 0],
                    [-s, 0, c]])
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for idx in range(batch_data.shape[0]):
        rotated[idx, :, 0:3] = np.dot(batch_data[idx, :, 0:3].reshape((-1, 3)), rot)
    return rotated
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
    """ Rotate the point cloud (and normals) about the up (y) axis by a fixed angle.
        Input:
          BxNx6 array, original batch of point clouds with normal
          rotation_angle: scalar angle in radians
        Return:
          BxNx6 array, rotated batch of point clouds with normal
    """
    # Fixed angle for the whole batch: build the matrix once.
    c, s = np.cos(rotation_angle), np.sin(rotation_angle)
    rot = np.array([[c, 0, s],
                    [0, 1, 0],
                    [-s, 0, c]])
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for idx in range(batch_data.shape[0]):
        rotated[idx, :, 0:3] = np.dot(batch_data[idx, :, 0:3].reshape((-1, 3)), rot)
        rotated[idx, :, 3:6] = np.dot(batch_data[idx, :, 3:6].reshape((-1, 3)), rot)
    return rotated
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """ Randomly perturb the point clouds by small rotations.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, rotated batch of point clouds
    """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for idx in range(batch_data.shape[0]):
        # Small Gaussian angles around each axis, clipped for stability.
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(ax), -np.sin(ax)],
                       [0, np.sin(ax), np.cos(ax)]])
        Ry = np.array([[np.cos(ay), 0, np.sin(ay)],
                       [0, 1, 0],
                       [-np.sin(ay), 0, np.cos(ay)]])
        Rz = np.array([[np.cos(az), -np.sin(az), 0],
                       [np.sin(az), np.cos(az), 0],
                       [0, 0, 1]])
        R = np.dot(Rz, np.dot(Ry, Rx))
        rotated[idx, ...] = np.dot(batch_data[idx, ...].reshape((-1, 3)), R)
    return rotated
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """ Randomly jitter points. jittering is per point.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, jittered batch of point clouds
        Raises:
          ValueError: if clip is not positive.
    """
    B, N, C = batch_data.shape
    # Validate explicitly: `assert` is stripped when Python runs with -O.
    if clip <= 0:
        raise ValueError("clip must be positive, got %r" % (clip,))
    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data
def shift_point_cloud(batch_data, shift_range=0.1):
    """ Randomly shift point cloud. Shift is per point cloud (in place).
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, shifted batch of point clouds (same array, mutated)
    """
    B = batch_data.shape[0]
    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
    # Broadcast one (3,) shift over each cloud instead of a Python loop;
    # element-wise result is identical to the per-batch loop.
    batch_data += shifts[:, None, :]
    return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """ Randomly scale the point cloud. Scale is per point cloud (in place).
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, scaled batch of point clouds (same array, mutated)
    """
    B = batch_data.shape[0]
    scales = np.random.uniform(scale_low, scale_high, B)
    # Broadcast one scalar scale over each cloud instead of a Python loop;
    # element-wise result is identical to the per-batch loop.
    batch_data *= scales[:, None, None]
    return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    ''' Randomly replace a per-cloud fraction of points with the first point.

    batch_pc: BxNx3; mutated in place and returned. "Dropped" points are set
    to the cloud's first point (rather than removed) so the cloud keeps a
    fixed size.
    '''
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio  # 0 ~ max_dropout_ratio
        drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0]
        if len(drop_idx) > 0:
            # Removed a dead reassignment of dropout_ratio here (it was
            # unused and only consumed an extra RNG draw).
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]  # set to the first point
    return batch_pc
from pct import PointTransformerCLS
from ModelNetDataLoader import ModelNetDataLoader
import provider
import argparse
import os
import tqdm
from functools import partial
from dgl.data.utils import download, get_download_dir
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import time
# Disable cuDNN for this script.
# NOTE(review): the reason is not recorded here -- confirm whether this
# workaround is still required, since it slows training down.
torch.backends.cudnn.enabled = False

parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=250)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=32)
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size

# Download and unpack ModelNet40 (resampled, with normals) unless a local
# dataset path was supplied on the command line.
data_filename = 'modelnet40_normal_resampled.zip'
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
    get_download_dir(), 'modelnet40_normal_resampled')

if not os.path.exists(local_path):
    download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip',
             download_path, verify_ssl=False)
    from zipfile import ZipFile
    with ZipFile(download_path) as z:
        z.extractall(path=get_download_dir())

# DataLoader factory with the shared worker/batch settings pre-bound.
CustomDataLoader = partial(
    DataLoader,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True)
def train(net, opt, scheduler, train_loader, dev):
    """Run one training epoch with on-the-fly point-cloud augmentation.

    net: classification model; opt: optimizer; scheduler: LR scheduler,
    stepped once at the end of the epoch; train_loader yields
    (points, label) batches; dev: torch device.
    """
    net.train()

    total_loss = 0
    num_batches = 0
    total_correct = 0
    count = 0
    loss_f = nn.CrossEntropyLoss()
    start_time = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label in tq:
            # Augmentations run in NumPy: point dropout, then scale, jitter
            # and shift applied to the xyz channels only.
            data = data.data.numpy()
            data = provider.random_point_dropout(data)
            data[:, :, 0:3] = provider.random_scale_point_cloud(
                data[:, :, 0:3])
            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
            data = torch.tensor(data)
            label = label[:, 0]

            num_examples = label.shape[0]
            data, label = data.to(dev), label.to(dev).squeeze().long()
            opt.zero_grad()
            logits = net(data)
            loss = loss_f(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)

            # Running epoch statistics shown in the progress bar.
            num_batches += 1
            count += num_examples
            loss = loss.item()
            correct = (preds == label).sum().item()
            total_loss += loss
            total_correct += correct
            tq.set_postfix({
                'AvgLoss': '%.5f' % (total_loss / num_batches),
                'AvgAcc': '%.5f' % (total_correct / count)})
    print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(total_loss /
        num_batches, total_correct / count, time.time() - start_time))
    # One scheduler step per epoch (cosine annealing over num_epochs).
    scheduler.step()
def evaluate(net, test_loader, dev):
    """Compute classification accuracy on `test_loader` and return it."""
    net.eval()
    correct_total = 0
    examples_seen = 0
    eval_start = time.time()
    with torch.no_grad(), tqdm.tqdm(test_loader, ascii=True) as progress:
        for data, label in progress:
            # Keep only the class id column of the label tensor.
            target = label[:, 0]
            batch_examples = target.shape[0]
            batch = data.to(dev)
            target = target.to(dev).squeeze().long()
            logits = net(batch)
            preds = logits.max(1)[1]
            correct_total += (preds == target).sum().item()
            examples_seen += batch_examples
            progress.set_postfix({
                'AvgAcc': '%.5f' % (correct_total / examples_seen)})
    print("[Test] AvgAcc: {:.5}, Time: {:.5}s".format(
        correct_total / examples_seen, time.time() - eval_start))
    return correct_total / examples_seen
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = PointTransformerCLS()
net = net.to(dev)
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

# SGD with cosine annealing over the full run, per the training recipe.
opt = torch.optim.SGD(
    net.parameters(),
    lr=0.01,
    weight_decay=1e-4,
    momentum=0.9
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.num_epochs)

train_dataset = ModelNetDataLoader(local_path, 1024, split='train')
test_dataset = ModelNetDataLoader(local_path, 1024, split='test')
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
# Fix: never drop the tail batch when evaluating — drop_last=True discards
# up to batch_size-1 test samples and biases the reported accuracy.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)

best_test_acc = 0
for epoch in range(args.num_epochs):
    print("Epoch #{}: ".format(epoch))
    train(net, opt, scheduler, train_loader, dev)
    # Evaluate every epoch (the original `(epoch + 1) % 1 == 0` guard was
    # always true, so it is dropped).
    test_acc = evaluate(net, test_loader, dev)
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        if args.save_model_path:
            torch.save(net.state_dict(), args.save_model_path)
    print('Current test acc: %.5f (best: %.5f)' % (
        test_acc, best_test_acc))
    print()
# ===========================================================================
# Part-segmentation training script: PCT on the ShapeNet part dataset.
# ===========================================================================
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import dgl
from functools import partial
import tqdm
import argparse
import time

# Project-local modules: augmentation helpers, the ShapeNet dataset wrapper,
# and the PCT segmentation model plus its loss.
import provider
from ShapeNet import ShapeNet
from pct import PointTransformerSeg, PartSegLoss

# Command-line options for the part-segmentation experiment.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=500)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--tensorboard', action='store_true')
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size
def collate(samples):
    """Collate (graph, category) pairs into one batched graph + category list."""
    graph_list, category_list = zip(*samples)
    return dgl.batch(list(graph_list)), list(category_list)
# DataLoader factory with the shared worker/batch settings baked in.
# NOTE(review): shuffle=True and drop_last=True are defaults here and are
# also applied to the test split further down — confirm that is intended
# for evaluation.
CustomDataLoader = partial(
    DataLoader,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True)
def train(net, opt, scheduler, train_loader, dev):
    """Train the part-segmentation network for one epoch.

    Returns the last batch's point clouds and predictions (for tensorboard
    mesh visualization), the epoch's average loss, average per-point
    accuracy, and wall-clock duration.
    """
    # Sorted category order keeps the one-hot index of each category stable.
    category_list = sorted(list(shapenet.seg_classes.keys()))
    eye_mat = np.eye(16)
    net.train()
    total_loss = 0
    num_batches = 0
    total_correct = 0
    count = 0
    start = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label, cat in tq:
            num_examples = data.shape[0]
            data = data.to(dev, dtype=torch.float)
            # Flatten to one label per point for the cross-entropy loss.
            label = label.to(dev, dtype=torch.long).view(-1)
            opt.zero_grad()
            cat_ind = [category_list.index(c) for c in cat]
            # A one-hot encoding for the object category
            cat_tensor = torch.tensor(eye_mat[cat_ind]).to(
                dev, dtype=torch.float)
            cat_tensor = cat_tensor.view(num_examples, 16, 1)
            logits = net(data, cat_tensor)
            loss = L(logits, label)
            loss.backward()
            opt.step()
            _, preds = logits.max(1)
            # Fix: count points from the flattened label instead of the
            # hard-coded `num_examples * 2048`, so accuracy stays correct if
            # the dataset is built with a different point resolution.
            count += label.shape[0]
            loss = loss.item()
            total_loss += loss
            num_batches += 1
            correct = (preds.view(-1) == label).sum().item()
            total_correct += correct
            AvgLoss = total_loss / num_batches
            AvgAcc = total_correct / count
            tq.set_postfix({
                'AvgLoss': '%.5f' % AvgLoss,
                'AvgAcc': '%.5f' % AvgAcc})
    scheduler.step()
    end = time.time()
    print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
        total_loss / num_batches, total_correct / count, end - start))
    return data, preds, AvgLoss, AvgAcc, end-start
def mIoU(preds, label, cat, cat_miou, seg_classes):
for i in range(preds.shape[0]):
shape_iou = 0
n = len(seg_classes[cat[i]])
for cls in seg_classes[cat[i]]:
pred_set = set(np.where(preds[i, :] == cls)[0])
label_set = set(np.where(label[i, :] == cls)[0])
union = len(pred_set.union(label_set))
inter = len(pred_set.intersection(label_set))
if union == 0:
shape_iou += 1
else:
shape_iou += inter / union
shape_iou /= n
cat_miou[cat[i]][0] += shape_iou
cat_miou[cat[i]][1] += 1
return cat_miou
def evaluate(net, test_loader, dev, per_cat_verbose=False):
    """Evaluate the segmentation net, returning (mIoU, per-category mIoU).

    mIoU weights every shape equally; per-category mIoU averages the
    per-category means, weighting every category equally. With
    per_cat_verbose=True the per-category breakdown is printed too.
    """
    category_list = sorted(list(shapenet.seg_classes.keys()))
    eye_mat = np.eye(16)
    net.eval()
    # Running [iou_sum, shape_count] per category, filled by mIoU().
    cat_miou = {k: [0, 0] for k in shapenet.seg_classes.keys()}
    miou = 0
    count = 0
    per_cat_miou = 0
    per_cat_count = 0
    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as tq:
            for data, label, cat in tq:
                num_examples = data.shape[0]
                data = data.to(dev, dtype=torch.float)
                label = label.to(dev, dtype=torch.long)
                cat_ind = [category_list.index(c) for c in cat]
                # One-hot category encoding expected by the network.
                cat_tensor = torch.tensor(eye_mat[cat_ind]).to(
                    dev, dtype=torch.float)
                cat_tensor = cat_tensor.view(
                    num_examples, 16, 1)
                logits = net(data, cat_tensor)
                _, preds = logits.max(1)
                cat_miou = mIoU(preds.cpu().numpy(),
                                label.view(num_examples, -1).cpu().numpy(),
                                cat, cat_miou, shapenet.seg_classes)
                # Fix: re-derive the totals from scratch every batch.
                # cat_miou is already cumulative, so accumulating these sums
                # across batches (as before) double-counts earlier batches
                # and skews the final reported mIoU.
                miou = 0
                count = 0
                per_cat_miou = 0
                per_cat_count = 0
                for _, v in cat_miou.items():
                    if v[1] > 0:
                        miou += v[0]
                        count += v[1]
                        per_cat_miou += v[0] / v[1]
                        per_cat_count += 1
                tq.set_postfix({
                    'mIoU': '%.5f' % (miou / count),
                    'per Category mIoU': '%.5f' % (per_cat_miou / per_cat_count)})
    print("[Test] mIoU: %.5f, per Category mIoU: %.5f" %
          (miou / count, per_cat_miou / per_cat_count))
    if per_cat_verbose:
        print("-" * 60)
        print("Per-Category mIoU:")
        for k, v in cat_miou.items():
            if v[1] > 0:
                print("%s mIoU=%.5f" % (k, v[0] / v[1]))
            else:
                # Category never appeared in the test split.
                print("%s mIoU=%.5f" % (k, 1))
        print("-" * 60)
    return miou / count, per_cat_miou / per_cat_count
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = PointTransformerSeg()
net = net.to(dev)
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

# SGD with cosine annealing over the full run, per the training recipe.
opt = torch.optim.SGD(
    net.parameters(),
    lr=0.01,
    weight_decay=1e-4,
    momentum=0.9
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.num_epochs)

L = PartSegLoss()
# 2048 points per shape, xyz only (no normal channel).
shapenet = ShapeNet(2048, normal_channel=False)
train_loader = CustomDataLoader(shapenet.trainval())
# Fix: CustomDataLoader defaults to shuffle=True/drop_last=True, which
# shuffles the test split and silently drops up to batch_size-1 test
# samples, skewing the reported mIoU. Override both for evaluation.
test_loader = CustomDataLoader(shapenet.test(), shuffle=False, drop_last=False)
# Tensorboard
if args.tensorboard:
    # NOTE(review): torchvision, datasets and transforms are imported here
    # but never used in the visible code — candidates for removal.
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import datasets, transforms
    writer = SummaryWriter()

    # Select 50 distinct colors for different parts
    color_map = torch.tensor([
        [47, 79, 79], [139, 69, 19], [112, 128, 144], [85, 107, 47], [139, 0, 0],
        [128, 128, 0], [72, 61, 139], [0, 128, 0], [188, 143, 143], [60, 179, 113],
        [205, 133, 63], [0, 139, 139], [70, 130, 180], [205, 92, 92], [154, 205, 50],
        [0, 0, 139], [50, 205, 50], [250, 250, 250], [218, 165, 32], [139, 0, 139],
        [10, 10, 10], [176, 48, 96], [72, 209, 204], [153, 50, 204], [255, 69, 0],
        [255, 145, 0], [0, 0, 205], [255, 255, 0], [0, 255, 0], [233, 150, 122],
        [220, 20, 60], [0, 191, 255], [160, 32, 240], [192, 192, 192], [173, 255, 47],
        [218, 112, 214], [216, 191, 216], [255, 127, 80], [255, 0, 255], [100, 149, 237],
        [128, 128, 128], [221, 160, 221], [144, 238, 144], [123, 104, 238], [255, 160, 122],
        [175, 238, 238], [238, 130, 238], [127, 255, 212], [255, 218, 185], [255, 105, 180],
    ])
# paint each point according to its pred
def paint(batched_points):
    """Look up an RGB color in `color_map` for every point's part id."""
    batch_size, num_points = batched_points.shape  # implicit 2-D shape check
    colors = color_map[batched_points]
    return colors.squeeze(2)
# Main training loop: evaluate on the first epoch and then every 5 epochs,
# checkpointing whenever the test mIoU improves.
best_test_miou = 0
best_test_per_cat_miou = 0
for epoch in range(args.num_epochs):
    print("Epoch #{}: ".format(epoch))
    data, preds, AvgLoss, AvgAcc, training_time = train(
        net, opt, scheduler, train_loader, dev)
    if (epoch + 1) % 5 == 0 or epoch == 0:
        test_miou, test_per_cat_miou = evaluate(
            net, test_loader, dev, True)
        if test_miou > best_test_miou:
            best_test_miou = test_miou
            best_test_per_cat_miou = test_per_cat_miou
            if args.save_model_path:
                torch.save(net.state_dict(), args.save_model_path)
        print('Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)' % (
            test_miou, best_test_miou, test_per_cat_miou, best_test_per_cat_miou))
    # Tensorboard
    if args.tensorboard:
        # NOTE(review): `preds`/`data` come straight from the last training
        # batch and may live on the GPU while `color_map` is a CPU tensor —
        # confirm the indexing in paint() works with CUDA index tensors.
        colored = paint(preds)
        writer.add_mesh('data', vertices=data,
                        colors=colored, global_step=epoch)
        writer.add_scalar('training time for one epoch',
                          training_time, global_step=epoch)
        writer.add_scalar('AvgLoss', AvgLoss, global_step=epoch)
        writer.add_scalar('AvgAcc', AvgAcc, global_step=epoch)
        # test_miou is defined whenever this condition holds (evaluation
        # above runs on the same schedule plus epoch 0).
        if (epoch + 1) % 5 == 0:
            writer.add_scalar('test mIoU', test_miou, global_step=epoch)
            writer.add_scalar('best test mIoU',
                              best_test_miou, global_step=epoch)
    print()
...@@ -125,13 +125,13 @@ net = net.to(dev) ...@@ -125,13 +125,13 @@ net = net.to(dev)
if args.load_model_path: if args.load_model_path:
net.load_state_dict(torch.load(args.load_model_path, map_location=dev)) net.load_state_dict(torch.load(args.load_model_path, map_location=dev))
if args.opt == 'adam': if args.opt == 'sgd':
# The optimizer strategy described in paper: # The optimizer strategy described in paper:
opt = torch.optim.SGD(net.parameters(), lr=0.01, opt = torch.optim.SGD(net.parameters(), lr=0.01,
momentum=0.9, weight_decay=1e-4) momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR( scheduler = torch.optim.lr_scheduler.MultiStepLR(
opt, milestones=[120, 160], gamma=0.1) opt, milestones=[120, 160], gamma=0.1)
elif args.opt == 'sgd': elif args.opt == 'adam':
# The optimizer strategy proposed by # The optimizer strategy proposed by
# https://github.com/qq456cvb/Point-Transformers: # https://github.com/qq456cvb/Point-Transformers:
opt = torch.optim.Adam( opt = torch.optim.Adam(
......
...@@ -168,13 +168,13 @@ net = net.to(dev) ...@@ -168,13 +168,13 @@ net = net.to(dev)
if args.load_model_path: if args.load_model_path:
net.load_state_dict(torch.load(args.load_model_path, map_location=dev)) net.load_state_dict(torch.load(args.load_model_path, map_location=dev))
if args.opt == 'adam': if args.opt == 'sgd':
# The optimizer strategy described in paper: # The optimizer strategy described in paper:
opt = torch.optim.SGD(net.parameters(), lr=0.01, opt = torch.optim.SGD(net.parameters(), lr=0.01,
momentum=0.9, weight_decay=1e-4) momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR( scheduler = torch.optim.lr_scheduler.MultiStepLR(
opt, milestones=[120, 160], gamma=0.1) opt, milestones=[120, 160], gamma=0.1)
elif args.opt == 'sgd': elif args.opt == 'adam':
# The optimizer strategy proposed by # The optimizer strategy proposed by
# https://github.com/qq456cvb/Point-Transformers: # https://github.com/qq456cvb/Point-Transformers:
opt = torch.optim.Adam( opt = torch.optim.Adam(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment