Unverified Commit 84efe00e authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

[Feature] Add pointnet2 msg and refactor pointnets (#82)



* add op fps with distance

* add op fps with distance

* modify F-FPS unittest

* modify sa module

* modify sa module

* SA Module support D-FPS and F-FPS

* modify sa module

* update points sa module

* modify point_sa_module

* modify point sa module

* reconstruct FPS

* reconstruct FPS

* modify docstring

* add pointnet2-sa-msg backbone

* modify pointnet2_sa_msg

* fix merge conflicts

* format tests/test_backbones.py

* [Refactor]: Add registry for PointNet2Modules

* modify h3dnet for base pointnet

* fix docstring tweaks

* fix bugs for config unittest
Co-authored-by: default avatarZwwWayne <wayne.zw@outlook.com>
parent dde4b02c
......@@ -19,6 +19,7 @@ primitive_z_cfg = dict(
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=1024,
radius=0.3,
num_sample=16,
......@@ -76,6 +77,7 @@ primitive_xy_cfg = dict(
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=1024,
radius=0.3,
num_sample=16,
......@@ -133,6 +135,7 @@ primitive_line_cfg = dict(
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=1024,
radius=0.3,
num_sample=16,
......@@ -169,50 +172,6 @@ primitive_line_cfg = dict(
num_point_line=10,
line_thresh=0.2))
proposal_module_cfg = dict(
suface_matching_cfg=dict(
num_point=256 * 6,
radius=0.5,
num_sample=32,
mlp_channels=[128 + 6, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
line_matching_cfg=dict(
num_point=256 * 12,
radius=0.5,
num_sample=32,
mlp_channels=[128 + 12, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
primitive_refine_channels=[128, 128, 128],
upper_thresh=100.0,
surface_thresh=0.5,
line_thresh=0.5,
train_cfg=dict(
far_threshold=0.6,
near_threshold=0.3,
mask_surface_threshold=0.3,
label_surface_threshold=0.3,
mask_line_threshold=0.3,
label_line_threshold=0.3),
cues_objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
cues_semantic_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
proposal_objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='none',
loss_weight=5.0),
primitive_center_loss=dict(
type='MSELoss', reduction='none', loss_weight=1.0))
model = dict(
type='H3DNet',
backbone=dict(
......@@ -232,7 +191,11 @@ model = dict(
(128, 128, 256)),
fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'),
pool_mod='max')),
sa_cfg=dict(
type='PointSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=True))),
rpn_head=dict(
type='VoteHead',
vote_moudule_cfg=dict(
......@@ -249,6 +212,7 @@ model = dict(
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=256,
radius=0.3,
num_sample=16,
......@@ -286,8 +250,27 @@ model = dict(
type='H3DBboxHead',
gt_per_seed=3,
num_proposal=256,
proposal_module_cfg=proposal_module_cfg,
suface_matching_cfg=dict(
type='PointSAModule',
num_point=256 * 6,
radius=0.5,
num_sample=32,
mlp_channels=[128 + 6, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
line_matching_cfg=dict(
type='PointSAModule',
num_point=256 * 12,
radius=0.5,
num_sample=32,
mlp_channels=[128 + 12, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128),
primitive_refine_channels=[128, 128, 128],
upper_thresh=100.0,
surface_thresh=0.5,
line_thresh=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
objectness_loss=dict(
......@@ -310,13 +293,39 @@ model = dict(
size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
semantic_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1))))
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
cues_objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
cues_semantic_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
proposal_objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='none',
loss_weight=5.0),
primitive_center_loss=dict(
type='MSELoss', reduction='none', loss_weight=1.0))))
# model training and testing settings
train_cfg = dict(
rpn=dict(pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'),
rpn_proposal=dict(use_nms=False),
rcnn=dict(pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'))
rcnn=dict(
pos_distance_thr=0.3,
neg_distance_thr=0.6,
sample_mod='vote',
far_threshold=0.6,
near_threshold=0.3,
mask_surface_threshold=0.3,
label_surface_threshold=0.3,
mask_line_threshold=0.3,
label_line_threshold=0.3))
test_cfg = dict(
rpn=dict(
......
......@@ -10,7 +10,11 @@ model = dict(
(128, 128, 256)),
fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'),
pool_mod='max'),
sa_cfg=dict(
type='PointSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=True)),
bbox_head=dict(
type='VoteHead',
vote_moudule_cfg=dict(
......@@ -27,6 +31,7 @@ model = dict(
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=256,
radius=0.3,
num_sample=16,
......
......@@ -16,4 +16,4 @@ We implement H3DNet and provide the result and checkpoints on ScanNet datasets.
### ScanNet
| Backbone | Lr schd | Mem (GB) | Inf time (fps) | AP@0.25 |AP@0.5| Download |
| :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: |
| [MultiBackbone](./h3dnet_scannet-3d-18class.py) | 3x |7.9||66.43|48.01|[model](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/votenet/votenet_8x8_scannet-3d-18class/votenet_8x8_scannet-3d-18class_20200620_230238-2cea9c3a.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/votenet/votenet_8x8_scannet-3d-18class/votenet_8x8_scannet-3d-18class_20200620_230238.log.json)|
| [MultiBackbone](./h3dnet_scannet-3d-18class.py) | 3x |7.9||66.43|48.01||
......@@ -58,7 +58,11 @@ model = dict(
(128, 128, 256)), # Out channels of each mlp in SA module
fp_channels=((256, 256), (256, 256)), # Out channels of each mlp in FP module
norm_cfg=dict(type='BN2d'), # Config of normalization layer
pool_mod='max'), # Pool method ('max' or 'avg') for SA modules
sa_cfg=dict( # Config of point set abstraction (SA) module
type='PointSAModule', # type of SA module
pool_mod='max', # Pool method ('max' or 'avg') for SA modules
use_xyz=True, # Whether to use xyz as features during feature gathering
normalize_xyz=True)), # Whether to use normalized xyz as feature during feature gathering
bbox_head=dict(
type='VoteHead', # The type of bbox head, refer to mmdet3d.models.dense_heads for more details
num_classes=18, # Number of classes for classification
......@@ -99,6 +103,7 @@ model = dict(
reduction='none', # Specifies the reduction to apply to the output
loss_dst_weight=10.0)), # Destination loss weight of the voting branch
vote_aggregation_cfg=dict( # Config to vote aggregation branch
type='PointSAModule', # type of vote aggregation module
num_point=256, # Number of points for the set abstraction layer in vote aggregation branch
radius=0.3, # Radius for the set abstraction layer in vote aggregation branch
num_sample=16, # Number of samples for the set abstraction layer in vote aggregation branch
......
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG
from .pointnet2_sa_ssg import PointNet2SASSG
from .second import SECOND
__all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'PointNet2SASSG', 'MultiBackbone'
'SECOND', 'PointNet2SASSG', 'PointNet2SAMSG', 'MultiBackbone'
]
from abc import ABCMeta
from mmcv.runner import load_checkpoint
from torch import nn as nn
class BasePointNet(nn.Module, metaclass=ABCMeta):
    """Base class shared by all PointNet-style backbones."""

    def __init__(self):
        super(BasePointNet, self).__init__()

    def init_weights(self, pretrained=None):
        """Initialize the weights of PointNet backbone.

        Args:
            pretrained (str, optional): Path of a pretrained checkpoint.
                When ``None``, no weight loading is performed.
        """
        # Conv layers keep their default initialization to follow the
        # original implementation; only checkpoint loading is handled here.
        if isinstance(pretrained, str):
            from mmdet3d.utils import get_root_logger
            load_checkpoint(
                self, pretrained, strict=False, logger=get_root_logger())

    @staticmethod
    def _split_point_feats(points):
        """Split coordinates and features of input points.

        Args:
            points (torch.Tensor): Point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            torch.Tensor: Coordinates of input points, shape (B, N, 3).
            torch.Tensor: Features of input points transposed to
                (B, input_feature_dim, N), or ``None`` when the input
                carries no extra feature channels.
        """
        coords = points[..., :3].contiguous()
        # Anything beyond xyz is treated as per-point features; transpose
        # to channel-first layout expected by downstream conv modules.
        if points.size(-1) > 3:
            point_feats = points[..., 3:].transpose(1, 2).contiguous()
        else:
            point_feats = None
        return coords, point_feats
......@@ -74,6 +74,7 @@ class MultiBackbone(nn.Module):
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
bias=True,
inplace=True))
def init_weights(self, pretrained=None):
......
import torch
from mmcv.cnn import ConvModule
from torch import nn as nn
from mmdet3d.ops import build_sa_module
from mmdet.models import BACKBONES
from .base_pointnet import BasePointNet
@BACKBONES.register_module()
class PointNet2SAMSG(BasePointNet):
    """PointNet2 with Multi-scale grouping.

    Stacks several multi-scale-grouping set abstraction (SA) modules, each
    followed by a 1x1 Conv1d that fuses the concatenated multi-scale
    features into ``aggregation_channels``.

    Args:
        in_channels (int): Input channels of point cloud (xyz plus
            feature channels).
        num_points (tuple[int]): The number of points which each SA
            module samples.
        radii (tuple[tuple[float]]): Sampling radii of each scale in each
            SA module.
        num_samples (tuple[tuple[int]]): The number of samples for ball
            query of each scale in each SA module.
        sa_channels (tuple[tuple[tuple[int]]]): Out channels of each mlp
            of each scale in each SA module.
        aggregation_channels (tuple[int]): Out channels of aggregation
            multi-scale grouping features.
        fps_mods (tuple[str] | tuple[tuple[str]]): FPS mod of each SA
            module ('D-FPS', 'F-FPS' or 'FS'); a tuple entry applies
            several mods inside one module.
        fps_sample_range_lists (tuple[int] | tuple[tuple[int]]): Number of
            points each FPS mod is applied to in each SA module
            (-1 presumably means all remaining points — confirm with
            Points_Sampler).
        out_indices (Sequence[int]): Output from which stages.
        norm_cfg (dict): Config of normalization layer.
        sa_cfg (dict): Config of set abstraction module, which may contain
            the following keys and values:

            - pool_mod (str): Pool method ('max' or 'avg') for SA modules.
            - use_xyz (bool): Whether to use xyz as a part of features.
            - normalize_xyz (bool): Whether to normalize xyz with radii in
              each SA module.
    """

    def __init__(self,
                 in_channels,
                 num_points=(2048, 1024, 512, 256),
                 radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
                 num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
                 sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
                              ((64, 64, 128), (64, 64, 128), (64, 96, 128)),
                              ((128, 128, 256), (128, 192, 256), (128, 256,
                                                                  256))),
                 aggregation_channels=(64, 128, 256),
                 fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
                 fps_sample_range_lists=((-1), (-1), (512, -1)),
                 out_indices=(2, ),
                 norm_cfg=dict(type='BN2d'),
                 sa_cfg=dict(
                     type='PointSAModuleMSG',
                     pool_mod='max',
                     use_xyz=True,
                     normalize_xyz=True)):
        super().__init__()
        self.num_sa = len(sa_channels)
        self.out_indices = out_indices
        # Every requested output stage must exist.
        assert max(out_indices) < self.num_sa
        assert len(num_points) == len(radii) == len(num_samples) == len(
            sa_channels) == len(aggregation_channels)
        self.SA_modules = nn.ModuleList()
        self.aggregation_mlps = nn.ModuleList()
        sa_in_channel = in_channels - 3  # number of channels without xyz
        skip_channel_list = [sa_in_channel]
        for sa_index in range(self.num_sa):
            cur_sa_mlps = list(sa_channels[sa_index])
            sa_out_channel = 0
            # Prepend the shared input channel count to each scale's mlp
            # spec; the module's output is the concatenation over scales.
            for radius_index in range(len(radii[sa_index])):
                cur_sa_mlps[radius_index] = [sa_in_channel] + list(
                    cur_sa_mlps[radius_index])
                sa_out_channel += cur_sa_mlps[radius_index][-1]
            # Normalize per-module FPS settings to lists so that single
            # values and tuples are handled uniformly downstream.
            if isinstance(fps_mods[sa_index], tuple):
                cur_fps_mod = list(fps_mods[sa_index])
            else:
                cur_fps_mod = list([fps_mods[sa_index]])
            if isinstance(fps_sample_range_lists[sa_index], tuple):
                cur_fps_sample_range_list = list(
                    fps_sample_range_lists[sa_index])
            else:
                cur_fps_sample_range_list = list(
                    [fps_sample_range_lists[sa_index]])
            self.SA_modules.append(
                build_sa_module(
                    num_point=num_points[sa_index],
                    radii=radii[sa_index],
                    sample_nums=num_samples[sa_index],
                    mlp_channels=cur_sa_mlps,
                    fps_mod=cur_fps_mod,
                    fps_sample_range_list=cur_fps_sample_range_list,
                    norm_cfg=norm_cfg,
                    cfg=sa_cfg,
                    bias=True))
            skip_channel_list.append(sa_out_channel)
            # 1x1 Conv1d fusing the concatenated multi-scale features.
            self.aggregation_mlps.append(
                ConvModule(
                    sa_out_channel,
                    aggregation_channels[sa_index],
                    conv_cfg=dict(type='Conv1d'),
                    norm_cfg=dict(type='BN1d'),
                    kernel_size=1,
                    bias=True))
            # Next stage consumes the aggregated feature width.
            sa_in_channel = aggregation_channels[sa_index]

    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): Point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict[str, torch.Tensor]: Outputs of the SA modules selected
            by ``out_indices``.

                - sa_xyz (torch.Tensor): The coordinates of sa features.
                - sa_features (torch.Tensor): The features from the
                  last Set Aggregation Layers.
                - sa_indices (torch.Tensor): Indices of the input points.
        """
        xyz, features = self._split_point_feats(points)
        batch, num_points = xyz.shape[:2]
        # Identity index per point; updated stage by stage so outputs can
        # be traced back to the original input points.
        indices = xyz.new_tensor(range(num_points)).unsqueeze(0).repeat(
            batch, 1).long()
        sa_xyz = [xyz]
        sa_features = [features]
        sa_indices = [indices]
        out_sa_xyz = []
        out_sa_features = []
        out_sa_indices = []
        for i in range(self.num_sa):
            cur_xyz, cur_features, cur_indices = self.SA_modules[i](
                sa_xyz[i], sa_features[i])
            cur_features = self.aggregation_mlps[i](cur_features)
            sa_xyz.append(cur_xyz)
            sa_features.append(cur_features)
            # Map this stage's sampled indices back to indices of the
            # original input points.
            sa_indices.append(
                torch.gather(sa_indices[-1], 1, cur_indices.long()))
            if i in self.out_indices:
                out_sa_xyz.append(sa_xyz[-1])
                out_sa_features.append(sa_features[-1])
                out_sa_indices.append(sa_indices[-1])
        return dict(
            sa_xyz=out_sa_xyz,
            sa_features=out_sa_features,
            sa_indices=out_sa_indices)
import torch
from mmcv.runner import load_checkpoint
from torch import nn as nn
from mmdet3d.ops import PointFPModule, PointSAModule
from mmdet3d.ops import PointFPModule, build_sa_module
from mmdet.models import BACKBONES
from .base_pointnet import BasePointNet
@BACKBONES.register_module()
class PointNet2SASSG(nn.Module):
class PointNet2SASSG(BasePointNet):
"""PointNet2 with Single-scale grouping.
Args:
......@@ -20,10 +20,13 @@ class PointNet2SASSG(nn.Module):
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
fp_channels (tuple[tuple[int]]): Out channels of each mlp in FP module.
norm_cfg (dict): Config of normalization layer.
pool_mod (str): Pool method ('max' or 'avg') for SA modules.
use_xyz (bool): Whether to use xyz as a part of features.
normalize_xyz (bool): Whether to normalize xyz with radii in
each SA module.
sa_cfg (dict): Config of set abstraction module, which may contain
the following keys and values:
- pool_mod (str): Pool method ('max' or 'avg') for SA modules.
- use_xyz (bool): Whether to use xyz as a part of features.
- normalize_xyz (bool): Whether to normalize xyz with radii in
each SA module.
"""
def __init__(self,
......@@ -35,18 +38,18 @@ class PointNet2SASSG(nn.Module):
(128, 128, 256)),
fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'),
pool_mod='max',
use_xyz=True,
normalize_xyz=True):
sa_cfg=dict(
type='PointSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=True)):
super().__init__()
self.num_sa = len(sa_channels)
self.num_fp = len(fp_channels)
assert len(num_points) == len(radius) == len(num_samples) == len(
sa_channels)
assert len(sa_channels) >= len(fp_channels)
assert pool_mod in ['max', 'avg']
self.SA_modules = nn.ModuleList()
sa_in_channel = in_channels - 3 # number of channels without xyz
......@@ -58,15 +61,13 @@ class PointNet2SASSG(nn.Module):
sa_out_channel = cur_sa_mlps[-1]
self.SA_modules.append(
PointSAModule(
build_sa_module(
num_point=num_points[sa_index],
radius=radius[sa_index],
num_sample=num_samples[sa_index],
mlp_channels=cur_sa_mlps,
norm_cfg=norm_cfg,
use_xyz=use_xyz,
pool_mod=pool_mod,
normalize_xyz=normalize_xyz))
cfg=sa_cfg))
skip_channel_list.append(sa_out_channel)
sa_in_channel = sa_out_channel
......@@ -82,35 +83,6 @@ class PointNet2SASSG(nn.Module):
fp_source_channel = cur_fp_mlps[-1]
fp_target_channel = skip_channel_list.pop()
def init_weights(self, pretrained=None):
"""Initialize the weights of PointNet backbone."""
# Do not initialize the conv layers
# to follow the original implementation
if isinstance(pretrained, str):
from mmdet3d.utils import get_root_logger
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
@staticmethod
def _split_point_feats(points):
"""Split coordinates and features of input points.
Args:
points (torch.Tensor): Point coordinates with features,
with shape (B, N, 3 + input_feature_dim).
Returns:
torch.Tensor: Coordinates of input points.
torch.Tensor: Features of input points.
"""
xyz = points[..., 0:3].contiguous()
if points.size(-1) > 3:
features = points[..., 3:].transpose(1, 2).contiguous()
else:
features = None
return xyz, features
def forward(self, points):
"""Forward pass.
......
......@@ -8,7 +8,7 @@ from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import PointSAModule, furthest_point_sample
from mmdet3d.ops import build_sa_module, furthest_point_sample
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
......@@ -78,7 +78,7 @@ class VoteHead(nn.Module):
self.num_dir_bins = self.bbox_coder.num_dir_bins
self.vote_module = VoteModule(**vote_moudule_cfg)
self.vote_aggregation = PointSAModule(**vote_aggregation_cfg)
self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
conv_pred_list = list()
......
......@@ -7,7 +7,7 @@ from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import PointSAModule
from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
......@@ -60,7 +60,6 @@ class H3DBboxHead(nn.Module):
bbox_coder,
train_cfg=None,
test_cfg=None,
proposal_module_cfg=None,
gt_per_seed=1,
num_proposal=256,
feat_channels=(128, 128),
......@@ -114,9 +113,9 @@ class H3DBboxHead(nn.Module):
line_matching_cfg['mlp_channels'][-1]
# surface center matching
self.surface_center_matcher = PointSAModule(**suface_matching_cfg)
self.surface_center_matcher = build_sa_module(suface_matching_cfg)
# line center matching
self.line_center_matcher = PointSAModule(**line_matching_cfg)
self.line_center_matcher = build_sa_module(line_matching_cfg)
# Compute the matching scores
matching_feat_dims = suface_matching_cfg['mlp_channels'][-1]
......
......@@ -131,7 +131,7 @@ class H3DRoIHead(Base3DRoIHead):
Returns:
dict: Bbox results of one frame.
"""
sample_mod = self.test.sample_mod
sample_mod = self.test_cfg.sample_mod
assert sample_mod in ['vote', 'seed', 'random']
result_z = self.primitive_z(feats_dict, sample_mod)
......
......@@ -5,7 +5,7 @@ from torch.nn import functional as F
from mmdet3d.models.builder import build_loss
from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import PointSAModule, furthest_point_sample
from mmdet3d.ops import build_sa_module, furthest_point_sample
from mmdet.core import multi_apply
from mmdet.models import HEADS
......@@ -88,7 +88,7 @@ class PrimitiveHead(nn.Module):
vote_moudule_cfg['conv_channels'][-1] // 2, 2, 1)
self.vote_module = VoteModule(**vote_moudule_cfg)
self.vote_aggregation = PointSAModule(**vote_aggregation_cfg)
self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
conv_pred_list = list()
......
......@@ -9,7 +9,8 @@ from .group_points import (GroupAll, QueryAndGroup, group_points,
grouping_operation)
from .interpolate import three_interpolate, three_nn
from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
from .pointnet_modules import PointFPModule, PointSAModule, PointSAModuleMSG
from .pointnet_modules import (PointFPModule, PointSAModule, PointSAModuleMSG,
build_sa_module)
from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_batch,
points_in_boxes_cpu, points_in_boxes_gpu)
from .sparse_block import (SparseBasicBlock, SparseBottleneck,
......@@ -29,5 +30,5 @@ __all__ = [
'gather_points', 'grouping_operation', 'group_points', 'GroupAll',
'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule',
'points_in_boxes_batch', 'get_compiler_version',
'get_compiling_cuda_version', 'Points_Sampler'
'get_compiling_cuda_version', 'Points_Sampler', 'build_sa_module'
]
from .builder import build_sa_module
from .point_fp_module import PointFPModule
from .point_sa_module import PointSAModule, PointSAModuleMSG
__all__ = ['PointSAModuleMSG', 'PointSAModule', 'PointFPModule']
__all__ = [
'build_sa_module', 'PointSAModuleMSG', 'PointSAModule', 'PointFPModule'
]
from .registry import SA_MODULES
def build_sa_module(cfg, *args, **kwargs):
    """Build PointNet2 set abstraction (SA) module.

    Args:
        cfg (None or dict): The SA module config, which should contain:

            - type (str): Module type.
            - module args: Args needed to instantiate an SA module.
        args (argument list): Arguments passed to the ``__init__``
            method of the corresponding module.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding SA module.

    Returns:
        nn.Module: Created SA module.
    """
    # A missing config falls back to the single-scale PointSAModule.
    if cfg is None:
        module_cfg = dict(type='PointSAModule')
    elif not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    elif 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    else:
        # Copy so that popping 'type' does not mutate the caller's config.
        module_cfg = cfg.copy()

    module_type = module_cfg.pop('type')
    if module_type not in SA_MODULES:
        raise KeyError(f'Unrecognized module type {module_type}')
    module_cls = SA_MODULES.get(module_type)
    # Remaining config entries are forwarded as keyword arguments.
    return module_cls(*args, **kwargs, **module_cfg)
......@@ -5,8 +5,10 @@ from torch.nn import functional as F
from typing import List
from mmdet3d.ops import GroupAll, Points_Sampler, QueryAndGroup, gather_points
from .registry import SA_MODULES
@SA_MODULES.register_module()
class PointSAModuleMSG(nn.Module):
"""Point set abstraction module with multi-scale grouping used in
Pointnets.
......@@ -167,6 +169,7 @@ class PointSAModuleMSG(nn.Module):
return new_xyz, torch.cat(new_features_list, dim=1), indices
@SA_MODULES.register_module()
class PointSAModule(PointSAModuleMSG):
"""Point set abstraction module used in Pointnets.
......
from mmcv.utils import Registry

# Registry for PointNet2 set abstraction (SA) modules. Implementations
# register themselves via `@SA_MODULES.register_module()` and are
# instantiated from config dicts by `build_sa_module`.
SA_MODULES = Registry('point_sa_module')
......@@ -61,8 +61,7 @@ def test_multi_backbone():
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)),
fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'),
pool_mod='max'),
norm_cfg=dict(type='BN2d')),
dict(
type='PointNet2SASSG',
in_channels=4,
......@@ -72,8 +71,7 @@ def test_multi_backbone():
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)),
fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'),
pool_mod='max'),
norm_cfg=dict(type='BN2d')),
dict(
type='PointNet2SASSG',
in_channels=4,
......@@ -83,8 +81,7 @@ def test_multi_backbone():
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)),
fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'),
pool_mod='max'),
norm_cfg=dict(type='BN2d')),
dict(
type='PointNet2SASSG',
in_channels=4,
......@@ -94,8 +91,7 @@ def test_multi_backbone():
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)),
fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'),
pool_mod='max')
norm_cfg=dict(type='BN2d'))
])
self = build_backbone(cfg_list)
......@@ -127,8 +123,7 @@ def test_multi_backbone():
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)),
fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'),
pool_mod='max'))
norm_cfg=dict(type='BN2d')))
self = build_backbone(cfg_dict)
self.cuda()
......@@ -156,3 +151,68 @@ def test_multi_backbone():
with pytest.raises(AssertionError):
cfg_dict['backbones'] = 'PointNet2SASSG'
build_backbone(cfg_dict)
def test_pointnet2_sa_msg():
    """Smoke test for the multi-scale-grouping PointNet2 backbone."""
    # The backbone relies on compiled CUDA ops; skip on CPU-only hosts.
    if not torch.cuda.is_available():
        pytest.skip()
    backbone_cfg = dict(
        type='PointNet2SAMSG',
        in_channels=4,
        num_points=(256, 64, (32, 32)),
        radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
        num_samples=((8, 8, 16), (8, 8, 16), (8, 8, 8)),
        sa_channels=(((8, 8, 16), (8, 8, 16), (8, 8, 16)),
                     ((16, 16, 32), (16, 16, 32), (16, 24, 32)),
                     ((32, 32, 64), (32, 24, 64), (32, 64, 64))),
        aggregation_channels=(16, 32, 64),
        fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
        fps_sample_range_lists=((-1), (-1), (64, -1)),
        norm_cfg=dict(type='BN2d'),
        sa_cfg=dict(
            type='PointSAModuleMSG',
            pool_mod='max',
            use_xyz=True,
            normalize_xyz=False))
    backbone = build_backbone(backbone_cfg)
    backbone.cuda()
    # Channel layout of the constructed SA MLPs matches the config.
    assert backbone.SA_modules[0].mlps[0].layer0.conv.in_channels == 4
    assert backbone.SA_modules[0].mlps[0].layer0.conv.out_channels == 8
    assert backbone.SA_modules[0].mlps[1].layer1.conv.out_channels == 8
    assert backbone.SA_modules[2].mlps[2].layer2.conv.out_channels == 64
    points = np.fromfile(
        'tests/data/sunrgbd/points/000001.bin', dtype=np.float32)
    points = torch.from_numpy(points).view(1, -1, 6).cuda()  # (B, N, 6)
    # Forward pass consumes xyz plus a single extra feature channel.
    ret_dict = backbone(points[:, :, :4])
    last_xyz = ret_dict['sa_xyz'][-1]
    last_features = ret_dict['sa_features'][-1]
    last_indices = ret_dict['sa_indices'][-1]
    assert last_xyz.shape == torch.Size([1, 64, 3])
    assert last_features.shape == torch.Size([1, 64, 64])
    assert last_indices.shape == torch.Size([1, 64])
    # Every out_indices entry must be smaller than the number of SA modules.
    with pytest.raises(AssertionError):
        build_backbone(
            dict(
                type='PointNet2SAMSG',
                in_channels=4,
                num_points=(256, 64, (32, 32)),
                radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
                num_samples=((8, 8, 16), (8, 8, 16), (8, 8, 8)),
                sa_channels=(((8, 8, 16), (8, 8, 16), (8, 8, 16)),
                             ((16, 16, 32), (16, 16, 32), (16, 24, 32)),
                             ((32, 32, 64), (32, 24, 64), (32, 64, 64))),
                aggregation_channels=(16, 32, 64),
                fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
                fps_sample_range_lists=((-1), (-1), (64, -1)),
                out_indices=(2, 3),
                norm_cfg=dict(type='BN2d'),
                sa_cfg=dict(
                    type='PointSAModuleMSG',
                    pool_mod='max',
                    use_xyz=True,
                    normalize_xyz=False)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment