Unverified Commit 5b2d44fa authored by 彭卓清, committed by GitHub

[Model] BiPointNet (#4434)



* initial completion

* add binary version of pointnet and pointnet++

* add binary version of pointnet and pointnet++

* add binary version of pointnet and pointnet++

* add binary version of pointnet and pointnet++

* add binary version of pointnet and pointnet++

* finish README

* fix suggestions

* fix suggestions

* fix suggestions

* solve details

* resolve formatting issues

* add accuracy

* revert the changes
Co-authored-by: Tong He <hetong007@gmail.com>
Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent a49b96ee
@@ -37,6 +37,10 @@ To quickly locate the examples of your interest, search for the tagged keywords
- <a name='ngnn'></a> Song et al. Network In Graph Neural Network. [Paper link](https://arxiv.org/abs/2111.11638).
- Example code: [PyTorch](../examples/pytorch/ogb/ngnn)
- Tags: model-agnostic methodology, link prediction, open graph benchmark.
- <a name='bipointnet'></a> Qin et al. BiPointNet: Binary Neural Network for Point Clouds. [Paper link](https://openreview.net/forum?id=9QLRCVysdlO).
- Example code: [PyTorch](../examples/pytorch/pointcloud/bipointnet)
- Tags: point cloud classification, network binarization.
## 2020
- <a name="eeg-gcnn"></a> Wagh et al. EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. [Paper link](http://proceedings.mlr.press/v136/wagh20a.html).
import numpy as np
import warnings
import os
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
def pc_normalize(pc):
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
def farthest_point_sample(point, npoint):
"""
Farthest point sampler works as follows:
1. Initialize the sample set S with a random point
2. Pick point P not in S, which maximizes the distance d(P, S)
3. Repeat step 2 until |S| = npoint
Input:
xyz: pointcloud data, [N, D]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [npoint, D]
"""
N, D = point.shape
xyz = point[:,:3]
centroids = np.zeros((npoint,))
distance = np.ones((N,)) * 1e10
farthest = np.random.randint(0, N)
for i in range(npoint):
centroids[i] = farthest
centroid = xyz[farthest, :]
dist = np.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = np.argmax(distance, -1)
point = point[centroids.astype(np.int32)]
return point
class ModelNetDataLoader(Dataset):
def __init__(self, root, npoint=1024, split='train', fps=False,
normal_channel=True, cache_size=15000):
"""
Input:
root: the root path to the local data files
npoint: number of points from each cloud
split: which split of the data, 'train' or 'test'
fps: whether to sample points with farthest point sampler
normal_channel: whether to use additional channel
cache_size: the cache size of in-memory point clouds
"""
self.root = root
self.npoints = npoint
self.fps = fps
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel
shape_ids = {}
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
assert (split == 'train' or split == 'test')
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
# list of (shape_name, shape_txt_file_path) tuple
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
in range(len(shape_ids[split]))]
print('The size of %s data is %d'%(split,len(self.datapath)))
self.cache_size = cache_size
self.cache = {}
def __len__(self):
return len(self.datapath)
def _get_item(self, index):
if index in self.cache:
point_set, cls = self.cache[index]
else:
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
if self.fps:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints,:]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
if not self.normal_channel:
point_set = point_set[:, 0:3]
if len(self.cache) < self.cache_size:
self.cache[index] = (point_set, cls)
return point_set, cls
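    def __getitem__(self, index):
        # standard Dataset indexing hook; delegates to the cached loader above
        return self._get_item(index)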
## *BiPointNet: Binary Neural Network for Point Clouds*
Created by [Haotong Qin](https://htqin.github.io/), [Zhongang Cai](https://scholar.google.com/citations?user=WrDKqIAAAAAJ&hl=en), [Mingyuan Zhang](https://scholar.google.com/citations?user=2QLD4fAAAAAJ&hl=en), Yifu Ding, Haiyu Zhao, Shuai Yi, [Xianglong Liu](http://sites.nlsde.buaa.edu.cn/~xlliu/), and [Hao Su](https://cseweb.ucsd.edu/~haosu/) from Beihang University, SenseTime, and UCSD.
![prediction example](https://htqin.github.io/Imgs/ICLR/overview_v1.png)
### Introduction
This example is a DGL implementation of the ICLR 2021 paper *BiPointNet: Binary Neural Network for Point Clouds* [[PDF](https://openreview.net/forum?id=9QLRCVysdlO)]. To alleviate the resource constraints of real-time point cloud applications running on edge devices, the paper presents ***BiPointNet***, the first model binarization approach for efficient deep learning on point clouds. The authors find that the severe performance drop of binarized point cloud models mainly stems from two challenges: aggregation-induced feature homogenization, which degrades information entropy, and scale distortion, which hinders optimization and invalidates scale-sensitive structures. With theoretical justification and in-depth analysis, BiPointNet introduces Entropy-Maximizing Aggregation (EMA) to modulate the distribution before aggregation for maximum information entropy, and Layer-wise Scale Recovery (LSR) to efficiently restore feature representation capacity. Extensive experiments show that BiPointNet outperforms existing binarization methods by convincing margins, reaching a level even comparable with its full-precision counterpart. The techniques are generic and bring significant improvements on various fundamental tasks and mainstream backbones; for example, BiPointNet achieves a 14.7x speedup and 18.9x storage saving on real-world resource-constrained devices. The inference framework used for deployment is daBNN.
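To make the two techniques concrete, below is a minimal PyTorch sketch of the core ideas rather than the exact code shipped with this example (the real implementations are `BinaryQuantize` and `BiLinearLSR` in `basic.py` and `EmaMaxPool` in `bipointnet_cls.py`); the names `Sign`, `BiLinearWithLSR`, `ema_max_pool` and the offset value used here are illustrative only:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Sign(torch.autograd.Function):
    """sign() binarization with a straight-through estimator for the gradient."""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # pass the gradient through, but zero it outside [-1, 1]
        return grad_out * (x.abs() <= 1).float()


class BiLinearWithLSR(nn.Module):
    """Binary linear layer with a learnable layer-wise scale (LSR); illustrative only."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.scale = nn.Parameter(torch.ones(()))  # layer-wise scale recovery

    def forward(self, x):
        bw = Sign.apply(self.weight - self.weight.mean())  # zero-centre, then binarize
        ba = Sign.apply(x)                                  # binarize activations
        return F.linear(ba, bw) * self.scale                # recover the lost scale


def ema_max_pool(x, offset):
    """Max pooling over the point dimension with an entropy-maximizing offset.

    x: (B, C, N) per-point features; the constant offset re-centres the pooled
    distribution so that it keeps more information entropy after aggregation.
    """
    return torch.max(x, dim=2)[0] + offset


if __name__ == "__main__":
    feats = torch.randn(8, 64, 1024)                 # a batch of per-point features
    global_feat = ema_max_pool(feats, offset=-3.2)   # offset value is illustrative
    logits = BiLinearWithLSR(64, 40)(global_feat)    # 40 = number of ModelNet40 classes
    print(logits.shape)                              # torch.Size([8, 40])
```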
### How to Run
```shell script
python train_cls.py --model ${MODEL}
```
Here, `MODEL` has two choices: `bipointnet` and `bipointnet2_ssg`.
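The other options defined in `train_cls.py` (`--batch-size`, `--num-epochs`, `--num-workers`, `--dataset-path`, `--load-model-path`, `--save-model-path`) can be combined with `--model`; for instance (the checkpoint path below is just an example):
```shell script
python train_cls.py --model bipointnet2_ssg --batch-size 32 --num-epochs 200 --save-model-path bipointnet2_cls.pth
```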
### Performance
#### Classification
| Model | Dataset | Metric | Score |
| --------------- | ---------- | -------- | ----- |
| BiPointNet | ModelNet40 | Accuracy | 88.4 |
| BiPointNet2_SSG | ModelNet40 | Accuracy | 83.1 |
Due to implementation differences introduced by the use of DGL, this version even surpasses the accuracy reported in the original paper.
### Citation
If you find our work useful in your research, please consider citing:
```
@inproceedings{Qin:iclr21,
author = {Haotong Qin and Zhongang Cai and Mingyuan Zhang
and Yifu Ding and Haiyu Zhao and Shuai Yi
and Xianglong Liu and Hao Su},
title = {BiPointNet: Binary Neural Network for Point Clouds},
booktitle = {ICLR},
year = {2021}
}
```
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.modules.utils import _single
from torch.autograd import Function
from torch.nn import Parameter
import dgl
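# BinaryQuantize: binarizes the input with sign() in the forward pass; the backward
# pass is a straight-through estimator that passes the gradient through unchanged
# and zeroes it where |input| > 1.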
class BinaryQuantize(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
out = torch.sign(input)
return out
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors
grad_input = grad_output
grad_input[input[0].gt(1)] = 0
grad_input[input[0].lt(-1)] = 0
return grad_input
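# BiLinearLSR: binary linear layer with Layer-wise Scale Recovery (LSR). On the first
# forward pass, `scale` is initialized so that the standard deviation of the binarized
# output matches that of the full-precision output; afterwards it is trained as an
# ordinary parameter.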
class BiLinearLSR(torch.nn.Linear):
def __init__(self, in_features, out_features, bias=False, binary_act=True):
super(BiLinearLSR, self).__init__(in_features, out_features, bias=bias)
self.binary_act = binary_act
# must register a nn.Parameter placeholder for model loading
# self.register_parameter('scale', None) doesn't register None into state_dict
# so it leads to unexpected key error when loading saved model
# hence, init scale with Parameter
# however, Parameter(None) actually has size [0], not [] as a scalar
# hence, init it using the following trick
self.register_parameter('scale', Parameter(torch.Tensor([0.0]).squeeze()))
def reset_scale(self, input):
bw = self.weight
ba = input
bw = bw - bw.mean()
self.scale = Parameter((F.linear(ba, bw).std() / F.linear(torch.sign(ba), torch.sign(bw)).std()).float().to(ba.device))
# corner case when ba is all 0.0
if torch.isnan(self.scale):
self.scale = Parameter((bw.std() / torch.sign(bw).std()).float().to(ba.device))
def forward(self, input):
bw = self.weight
ba = input
bw = bw - bw.mean()
if self.scale.item() == 0.0:
self.reset_scale(input)
bw = BinaryQuantize().apply(bw)
bw = bw * self.scale
if self.binary_act:
ba = BinaryQuantize().apply(ba)
output = F.linear(ba, bw)
return output
class BiLinear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias=True, binary_act=True):
super(BiLinear, self).__init__(in_features, out_features, bias=True)
self.binary_act = binary_act
self.output_ = None
def forward(self, input):
bw = self.weight
ba = input
bw = BinaryQuantize().apply(bw)
if self.binary_act:
ba = BinaryQuantize().apply(ba)
output = F.linear(ba, bw, self.bias)
self.output_ = output
return output
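# BiConv2d: binary 2D convolution. Weights are zero-centred and binarized with
# BinaryQuantize, and activations are binarized as well; no scale recovery is applied here.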
class BiConv2d(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros'):
super(BiConv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, bias, padding_mode)
def forward(self, input):
bw = self.weight
ba = input
bw = bw - bw.mean()
bw = BinaryQuantize().apply(bw)
ba = BinaryQuantize().apply(ba)
if self.padding_mode == 'circular':
expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
return F.conv2d(F.pad(ba, expanded_padding, mode='circular'),
bw, self.bias, self.stride,
_single(0), self.dilation, self.groups)
return F.conv2d(ba, bw, self.bias, self.stride,
self.padding, self.dilation, self.groups)
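# Pairwise squared Euclidean distances between point sets src (B, N, C) and dst (B, M, C),
# computed via the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2 * s.d; returns a (B, N, M) tensor.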
def square_distance(src, dst):
'''
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
def index_points(points, idx):
'''
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
class FixedRadiusNearNeighbors(nn.Module):
'''
    Ball Query - Find the neighbors within a fixed radius
'''
def __init__(self, radius, n_neighbor):
super(FixedRadiusNearNeighbors, self).__init__()
self.radius = radius
self.n_neighbor = n_neighbor
def forward(self, pos, centroids):
'''
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
device = pos.device
B, N, _ = pos.shape
center_pos = index_points(pos, centroids)
_, S, _ = center_pos.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(center_pos, pos)
group_idx[sqrdists > self.radius ** 2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :self.n_neighbor]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, self.n_neighbor])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
class FixedRadiusNNGraph(nn.Module):
'''
Build NN graph
'''
def __init__(self, radius, n_neighbor):
super(FixedRadiusNNGraph, self).__init__()
self.radius = radius
self.n_neighbor = n_neighbor
self.frnn = FixedRadiusNearNeighbors(radius, n_neighbor)
def forward(self, pos, centroids, feat=None):
dev = pos.device
group_idx = self.frnn(pos, centroids)
B, N, _ = pos.shape
glist = []
for i in range(B):
center = torch.zeros((N)).to(dev)
center[centroids[i]] = 1
src = group_idx[i].contiguous().view(-1)
dst = centroids[i].view(-1, 1).repeat(1, self.n_neighbor).view(-1)
unified = torch.cat([src, dst])
uniq, inv_idx = torch.unique(unified, return_inverse=True)
src_idx = inv_idx[:src.shape[0]]
dst_idx = inv_idx[src.shape[0]:]
g = dgl.graph((src_idx, dst_idx))
g.ndata['pos'] = pos[i][uniq]
g.ndata['center'] = center[uniq]
if feat is not None:
g.ndata['feat'] = feat[i][uniq]
glist.append(g)
bg = dgl.batch(glist)
return bg
class RelativePositionMessage(nn.Module):
'''
Compute the input feature from neighbors
'''
def __init__(self, n_neighbor):
super(RelativePositionMessage, self).__init__()
self.n_neighbor = n_neighbor
def forward(self, edges):
pos = edges.src['pos'] - edges.dst['pos']
if 'feat' in edges.src:
res = torch.cat([pos, edges.src['feat']], 1)
else:
res = pos
return {'agg_feat': res}
import torch
import torch.nn as nn
from basic import BiLinearLSR, BiConv2d, FixedRadiusNNGraph, RelativePositionMessage
import torch.nn.functional as F
from dgl.geometry import farthest_point_sampler
class BiPointNetConv(nn.Module):
'''
Feature aggregation
'''
def __init__(self, sizes, batch_size):
super(BiPointNetConv, self).__init__()
self.batch_size = batch_size
self.conv = nn.ModuleList()
self.bn = nn.ModuleList()
for i in range(1, len(sizes)):
self.conv.append(BiConv2d(sizes[i-1], sizes[i], 1))
self.bn.append(nn.BatchNorm2d(sizes[i]))
def forward(self, nodes):
shape = nodes.mailbox['agg_feat'].shape
h = nodes.mailbox['agg_feat'].view(self.batch_size, -1, shape[1], shape[2]).permute(0, 3, 2, 1)
for conv, bn in zip(self.conv, self.bn):
h = conv(h)
h = bn(h)
h = F.relu(h)
h = torch.max(h, 2)[0]
feat_dim = h.shape[1]
h = h.permute(0, 2, 1).reshape(-1, feat_dim)
return {'new_feat': h}
def group_all(self, pos, feat):
'''
Feature aggregation and pooling for the non-sampling layer
'''
if feat is not None:
h = torch.cat([pos, feat], 2)
else:
h = pos
B, N, D = h.shape
_, _, C = pos.shape
new_pos = torch.zeros(B, 1, C)
h = h.permute(0, 2, 1).view(B, -1, N, 1)
for conv, bn in zip(self.conv, self.bn):
h = conv(h)
h = bn(h)
h = F.relu(h)
h = torch.max(h[:, :, :, 0], 2)[0] # [B,D]
return new_pos, h
class BiSAModule(nn.Module):
"""
The Set Abstraction Layer
"""
def __init__(self, npoints, batch_size, radius, mlp_sizes, n_neighbor=64,
group_all=False):
super(BiSAModule, self).__init__()
self.group_all = group_all
if not group_all:
self.npoints = npoints
self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor)
self.message = RelativePositionMessage(n_neighbor)
self.conv = BiPointNetConv(mlp_sizes, batch_size)
self.batch_size = batch_size
def forward(self, pos, feat):
if self.group_all:
return self.conv.group_all(pos, feat)
centroids = farthest_point_sampler(pos, self.npoints)
g = self.frnn_graph(pos, centroids, feat)
g.update_all(self.message, self.conv)
mask = g.ndata['center'] == 1
pos_dim = g.ndata['pos'].shape[-1]
feat_dim = g.ndata['new_feat'].shape[-1]
pos_res = g.ndata['pos'][mask].view(self.batch_size, -1, pos_dim)
feat_res = g.ndata['new_feat'][mask].view(self.batch_size, -1, feat_dim)
return pos_res, feat_res
class BiPointNet2SSGCls(nn.Module):
def __init__(self, output_classes, batch_size, input_dims=3, dropout_prob=0.4):
super(BiPointNet2SSGCls, self).__init__()
self.input_dims = input_dims
self.sa_module1 = BiSAModule(512, batch_size, 0.2, [input_dims, 64, 64, 128])
self.sa_module2 = BiSAModule(128, batch_size, 0.4, [128 + 3, 128, 128, 256])
self.sa_module3 = BiSAModule(None, batch_size, None, [256 + 3, 256, 512, 1024],
group_all=True)
self.mlp1 = BiLinearLSR(1024, 512)
self.bn1 = nn.BatchNorm1d(512)
self.drop1 = nn.Dropout(dropout_prob)
self.mlp2 = BiLinearLSR(512, 256)
self.bn2 = nn.BatchNorm1d(256)
self.drop2 = nn.Dropout(dropout_prob)
self.mlp_out = BiLinearLSR(256, output_classes)
def forward(self, x):
if x.shape[-1] > 3:
pos = x[:, :, :3]
feat = x[:, :, 3:]
else:
pos = x
feat = None
pos, feat = self.sa_module1(pos, feat)
pos, feat = self.sa_module2(pos, feat)
_, h = self.sa_module3(pos, feat)
h = self.mlp1(h)
h = self.bn1(h)
h = F.relu(h)
h = self.drop1(h)
h = self.mlp2(h)
h = self.bn2(h)
h = F.relu(h)
h = self.drop2(h)
out = self.mlp_out(h)
return out
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from basic import BiLinear
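# Offsets used by EmaMaxPool (Entropy-Maximizing Aggregation for max pooling), keyed by
# the number of points N: the max-pooled activations are shifted by offset_map[N] to
# re-centre their distribution after aggregation.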
offset_map = {
1024: -3.2041,
2048: -3.4025,
4096: -3.5836
}
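# Conv1d: a 1x1 convolution over points implemented with a (binary) Linear layer.
# The (B, C, N) input is reshaped to (B*N, C), passed through `Linear`, and reshaped
# back to (B, C_out, N).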
class Conv1d(nn.Module):
def __init__(self, inplane, outplane, Linear):
super().__init__()
self.lin = Linear(inplane, outplane)
def forward(self, x):
B, C, N = x.shape
x = x.permute(0, 2, 1).contiguous().view(-1, C)
x = self.lin(x).view(B, N, -1).permute(0, 2, 1).contiguous()
return x
class EmaMaxPool(nn.Module):
def __init__(self, kernel_size, affine=True, Linear=BiLinear, use_bn=True):
super(EmaMaxPool, self).__init__()
self.kernel_size = kernel_size
self.bn3 = nn.BatchNorm1d(1024, affine=affine)
self.use_bn = use_bn
def forward(self, x):
batchsize, D, N = x.size()
if self.use_bn:
x = torch.max(x, 2, keepdim=True)[0] + offset_map[N]
else:
x = torch.max(x, 2, keepdim=True)[0] - 0.3
return x
class BiPointNetCls(nn.Module):
def __init__(self, output_classes, input_dims=3, conv1_dim=64,
use_transform=True, Linear=BiLinear):
super(BiPointNetCls, self).__init__()
self.input_dims = input_dims
self.conv1 = nn.ModuleList()
self.conv1.append(Conv1d(input_dims, conv1_dim, Linear=Linear))
self.conv1.append(Conv1d(conv1_dim, conv1_dim, Linear=Linear))
self.conv1.append(Conv1d(conv1_dim, conv1_dim, Linear=Linear))
self.bn1 = nn.ModuleList()
self.bn1.append(nn.BatchNorm1d(conv1_dim))
self.bn1.append(nn.BatchNorm1d(conv1_dim))
self.bn1.append(nn.BatchNorm1d(conv1_dim))
self.conv2 = nn.ModuleList()
self.conv2.append(Conv1d(conv1_dim, conv1_dim * 2, Linear=Linear))
self.conv2.append(Conv1d(conv1_dim * 2, conv1_dim * 16, Linear=Linear))
self.bn2 = nn.ModuleList()
self.bn2.append(nn.BatchNorm1d(conv1_dim * 2))
self.bn2.append(nn.BatchNorm1d(conv1_dim * 16))
self.maxpool = EmaMaxPool(conv1_dim * 16, Linear=Linear, use_bn=True)
self.pool_feat_len = conv1_dim * 16
self.mlp3 = nn.ModuleList()
self.mlp3.append(Linear(conv1_dim * 16, conv1_dim * 8))
self.mlp3.append(Linear(conv1_dim * 8, conv1_dim * 4))
self.bn3 = nn.ModuleList()
self.bn3.append(nn.BatchNorm1d(conv1_dim * 8))
self.bn3.append(nn.BatchNorm1d(conv1_dim * 4))
self.dropout = nn.Dropout(0.3)
self.mlp_out = Linear(conv1_dim * 4, output_classes)
self.use_transform = use_transform
if use_transform:
self.transform1 = TransformNet(input_dims)
self.trans_bn1 = nn.BatchNorm1d(input_dims)
self.transform2 = TransformNet(conv1_dim)
self.trans_bn2 = nn.BatchNorm1d(conv1_dim)
def forward(self, x):
batch_size = x.shape[0]
h = x.permute(0, 2, 1)
if self.use_transform:
trans = self.transform1(h)
h = h.transpose(2, 1)
h = torch.bmm(h, trans)
h = h.transpose(2, 1)
h = F.relu(self.trans_bn1(h))
for conv, bn in zip(self.conv1, self.bn1):
h = conv(h)
h = bn(h)
h = F.relu(h)
if self.use_transform:
trans = self.transform2(h)
h = h.transpose(2, 1)
h = torch.bmm(h, trans)
h = h.transpose(2, 1)
h = F.relu(self.trans_bn2(h))
for conv, bn in zip(self.conv2, self.bn2):
h = conv(h)
h = bn(h)
h = F.relu(h)
h = self.maxpool(h).view(-1, self.pool_feat_len)
for mlp, bn in zip(self.mlp3, self.bn3):
h = mlp(h)
h = bn(h)
h = F.relu(h)
h = self.dropout(h)
out = self.mlp_out(h)
return out
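# TransformNet (T-Net): predicts an input_dims x input_dims alignment matrix; an identity
# matrix is added to the output so that the learned transform starts close to identity.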
class TransformNet(nn.Module):
def __init__(self, input_dims=3, conv1_dim=64, Linear=BiLinear):
super(TransformNet, self).__init__()
self.conv = nn.ModuleList()
self.conv.append(Conv1d(input_dims, conv1_dim, Linear=Linear))
self.conv.append(Conv1d(conv1_dim, conv1_dim * 2, Linear=Linear))
self.conv.append(Conv1d(conv1_dim * 2, conv1_dim * 16, Linear=Linear))
self.bn = nn.ModuleList()
self.bn.append(nn.BatchNorm1d(conv1_dim))
self.bn.append(nn.BatchNorm1d(conv1_dim * 2))
self.bn.append(nn.BatchNorm1d(conv1_dim * 16))
# self.maxpool = nn.MaxPool1d(conv1_dim * 16)
self.maxpool = EmaMaxPool(conv1_dim * 16, Linear=Linear, use_bn=True)
self.pool_feat_len = conv1_dim * 16
self.mlp2 = nn.ModuleList()
self.mlp2.append(Linear(conv1_dim * 16, conv1_dim * 8))
self.mlp2.append(Linear(conv1_dim * 8, conv1_dim * 4))
self.bn2 = nn.ModuleList()
self.bn2.append(nn.BatchNorm1d(conv1_dim * 8))
self.bn2.append(nn.BatchNorm1d(conv1_dim * 4))
self.input_dims = input_dims
self.mlp_out = Linear(conv1_dim * 4, input_dims * input_dims)
def forward(self, h):
batch_size = h.shape[0]
for conv, bn in zip(self.conv, self.bn):
h = conv(h)
h = bn(h)
h = F.relu(h)
h = self.maxpool(h).view(-1, self.pool_feat_len)
for mlp, bn in zip(self.mlp2, self.bn2):
h = mlp(h)
h = bn(h)
h = F.relu(h)
out = self.mlp_out(h)
iden = Variable(torch.from_numpy(np.eye(self.input_dims).flatten().astype(np.float32)))
iden = iden.view(1, self.input_dims * self.input_dims).repeat(batch_size, 1)
if out.is_cuda:
iden = iden.cuda()
out = out + iden
out = out.view(-1, self.input_dims, self.input_dims)
return out
from bipointnet_cls import BiPointNetCls
from bipointnet2 import BiPointNet2SSGCls
from ModelNetDataLoader import ModelNetDataLoader
import provider
import argparse
import os
import urllib
import tqdm
from functools import partial
from dgl.data.utils import download, get_download_dir
import dgl
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch
torch.backends.cudnn.enabled = False
# from dataset import ModelNet
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='bipointnet')
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=200)
parser.add_argument('--num-workers', type=int, default=0)
parser.add_argument('--batch-size', type=int, default=32)
args = parser.parse_args()
num_workers = args.num_workers
batch_size = args.batch_size
data_filename = 'modelnet40_normal_resampled.zip'
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
get_download_dir(), 'modelnet40_normal_resampled')
if not os.path.exists(local_path):
download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip',
download_path, verify_ssl=False)
from zipfile import ZipFile
with ZipFile(download_path) as z:
z.extractall(path=get_download_dir())
CustomDataLoader = partial(
DataLoader,
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True)
def train(net, opt, scheduler, train_loader, dev):
net.train()
total_loss = 0
num_batches = 0
total_correct = 0
count = 0
loss_f = nn.CrossEntropyLoss()
with tqdm.tqdm(train_loader, ascii=True) as tq:
for data, label in tq:
data = data.data.numpy()
data = provider.random_point_dropout(data)
data[:, :, 0:3] = provider.random_scale_point_cloud(
data[:, :, 0:3])
data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
data = torch.tensor(data)
label = label[:, 0]
num_examples = label.shape[0]
data, label = data.to(dev), label.to(dev).squeeze().long()
opt.zero_grad()
logits = net(data)
loss = loss_f(logits, label)
loss.backward()
opt.step()
_, preds = logits.max(1)
num_batches += 1
count += num_examples
loss = loss.item()
correct = (preds == label).sum().item()
total_loss += loss
total_correct += correct
tq.set_postfix({
'AvgLoss': '%.5f' % (total_loss / num_batches),
'AvgAcc': '%.5f' % (total_correct / count)})
scheduler.step()
def evaluate(net, test_loader, dev):
net.eval()
total_correct = 0
count = 0
with torch.no_grad():
with tqdm.tqdm(test_loader, ascii=True) as tq:
for data, label in tq:
label = label[:, 0]
num_examples = label.shape[0]
data, label = data.to(dev), label.to(dev).squeeze().long()
logits = net(data)
_, preds = logits.max(1)
correct = (preds == label).sum().item()
total_correct += correct
count += num_examples
tq.set_postfix({
'AvgAcc': '%.5f' % (total_correct / count)})
return total_correct / count
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.model == 'bipointnet':
net = BiPointNetCls(40, input_dims=6)
elif args.model == 'bipointnet2_ssg':
net = BiPointNet2SSGCls(40, batch_size, input_dims=6)
net = net.to(dev)
if args.load_model_path:
net.load_state_dict(torch.load(args.load_model_path, map_location=dev))
opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)
train_dataset = ModelNetDataLoader(local_path, 1024, split='train')
test_dataset = ModelNetDataLoader(local_path, 1024, split='test')
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)
best_test_acc = 0
for epoch in range(args.num_epochs):
train(net, opt, scheduler, train_loader, dev)
if (epoch + 1) % 1 == 0:
print('Epoch #%d Testing' % epoch)
test_acc = evaluate(net, test_loader, dev)
if test_acc > best_test_acc:
best_test_acc = test_acc
if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path)
print('Current test acc: %.5f (best: %.5f)' % (
test_acc, best_test_acc))