Unverified commit 84efe00e authored by encore-zhou, committed by GitHub

[Feature] Add pointnet2 msg and refactor pointnets (#82)



* add op fps with distance

* add op fps with distance

* modify F-FPS unittest

* modify sa module

* modify sa module

* SA Module support D-FPS and F-FPS

* modify sa module

* update points sa module

* modify point_sa_module

* modify point sa module

* reconstruct FPS

* reconstruct FPS

* modify docstring

* add pointnet2-sa-msg backbone

* modify pointnet2_sa_msg

* fix merge conflicts

* format tests/test_backbones.py

* [Refactor]: Add registry for PointNet2Modules

* modify h3dnet for base pointnet

* fix docstring tweaks

* fix bugs for config unittest
Co-authored-by: ZwwWayne <wayne.zw@outlook.com>
parent dde4b02c
...@@ -19,6 +19,7 @@ primitive_z_cfg = dict( ...@@ -19,6 +19,7 @@ primitive_z_cfg = dict(
reduction='none', reduction='none',
loss_dst_weight=10.0)), loss_dst_weight=10.0)),
vote_aggregation_cfg=dict( vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=1024, num_point=1024,
radius=0.3, radius=0.3,
num_sample=16, num_sample=16,
...@@ -76,6 +77,7 @@ primitive_xy_cfg = dict( ...@@ -76,6 +77,7 @@ primitive_xy_cfg = dict(
reduction='none', reduction='none',
loss_dst_weight=10.0)), loss_dst_weight=10.0)),
vote_aggregation_cfg=dict( vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=1024, num_point=1024,
radius=0.3, radius=0.3,
num_sample=16, num_sample=16,
...@@ -133,6 +135,7 @@ primitive_line_cfg = dict( ...@@ -133,6 +135,7 @@ primitive_line_cfg = dict(
reduction='none', reduction='none',
loss_dst_weight=10.0)), loss_dst_weight=10.0)),
vote_aggregation_cfg=dict( vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=1024, num_point=1024,
radius=0.3, radius=0.3,
num_sample=16, num_sample=16,
...@@ -169,50 +172,6 @@ primitive_line_cfg = dict( ...@@ -169,50 +172,6 @@ primitive_line_cfg = dict(
num_point_line=10, num_point_line=10,
line_thresh=0.2)) line_thresh=0.2))
proposal_module_cfg = dict(
suface_matching_cfg=dict(
num_point=256 * 6,
radius=0.5,
num_sample=32,
mlp_channels=[128 + 6, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
line_matching_cfg=dict(
num_point=256 * 12,
radius=0.5,
num_sample=32,
mlp_channels=[128 + 12, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
primitive_refine_channels=[128, 128, 128],
upper_thresh=100.0,
surface_thresh=0.5,
line_thresh=0.5,
train_cfg=dict(
far_threshold=0.6,
near_threshold=0.3,
mask_surface_threshold=0.3,
label_surface_threshold=0.3,
mask_line_threshold=0.3,
label_line_threshold=0.3),
cues_objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
cues_semantic_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
proposal_objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='none',
loss_weight=5.0),
primitive_center_loss=dict(
type='MSELoss', reduction='none', loss_weight=1.0))
model = dict( model = dict(
type='H3DNet', type='H3DNet',
backbone=dict( backbone=dict(
...@@ -232,7 +191,11 @@ model = dict( ...@@ -232,7 +191,11 @@ model = dict(
(128, 128, 256)), (128, 128, 256)),
fp_channels=((256, 256), (256, 256)), fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'), norm_cfg=dict(type='BN2d'),
pool_mod='max')), sa_cfg=dict(
type='PointSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=True))),
rpn_head=dict( rpn_head=dict(
type='VoteHead', type='VoteHead',
vote_moudule_cfg=dict( vote_moudule_cfg=dict(
...@@ -249,6 +212,7 @@ model = dict( ...@@ -249,6 +212,7 @@ model = dict(
reduction='none', reduction='none',
loss_dst_weight=10.0)), loss_dst_weight=10.0)),
vote_aggregation_cfg=dict( vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=256, num_point=256,
radius=0.3, radius=0.3,
num_sample=16, num_sample=16,
...@@ -286,8 +250,27 @@ model = dict( ...@@ -286,8 +250,27 @@ model = dict(
type='H3DBboxHead', type='H3DBboxHead',
gt_per_seed=3, gt_per_seed=3,
num_proposal=256, num_proposal=256,
proposal_module_cfg=proposal_module_cfg, suface_matching_cfg=dict(
type='PointSAModule',
num_point=256 * 6,
radius=0.5,
num_sample=32,
mlp_channels=[128 + 6, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
line_matching_cfg=dict(
type='PointSAModule',
num_point=256 * 12,
radius=0.5,
num_sample=32,
mlp_channels=[128 + 12, 128, 64, 32],
use_xyz=True,
normalize_xyz=True),
feat_channels=(128, 128), feat_channels=(128, 128),
primitive_refine_channels=[128, 128, 128],
upper_thresh=100.0,
surface_thresh=0.5,
line_thresh=0.5,
conv_cfg=dict(type='Conv1d'), conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg=dict(type='BN1d'),
objectness_loss=dict( objectness_loss=dict(
...@@ -310,13 +293,39 @@ model = dict( ...@@ -310,13 +293,39 @@ model = dict(
size_res_loss=dict( size_res_loss=dict(
type='SmoothL1Loss', reduction='sum', loss_weight=10.0), type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
semantic_loss=dict( semantic_loss=dict(
type='CrossEntropyLoss', reduction='sum', loss_weight=0.1)))) type='CrossEntropyLoss', reduction='sum', loss_weight=0.1),
cues_objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
cues_semantic_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.3, 0.7],
reduction='mean',
loss_weight=5.0),
proposal_objectness_loss=dict(
type='CrossEntropyLoss',
class_weight=[0.2, 0.8],
reduction='none',
loss_weight=5.0),
primitive_center_loss=dict(
type='MSELoss', reduction='none', loss_weight=1.0))))
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict(pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'), rpn=dict(pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'),
rpn_proposal=dict(use_nms=False), rpn_proposal=dict(use_nms=False),
rcnn=dict(pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote')) rcnn=dict(
pos_distance_thr=0.3,
neg_distance_thr=0.6,
sample_mod='vote',
far_threshold=0.6,
near_threshold=0.3,
mask_surface_threshold=0.3,
label_surface_threshold=0.3,
mask_line_threshold=0.3,
label_line_threshold=0.3))
test_cfg = dict( test_cfg = dict(
rpn=dict( rpn=dict(
......
...@@ -10,7 +10,11 @@ model = dict( ...@@ -10,7 +10,11 @@ model = dict(
(128, 128, 256)), (128, 128, 256)),
fp_channels=((256, 256), (256, 256)), fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'), norm_cfg=dict(type='BN2d'),
pool_mod='max'), sa_cfg=dict(
type='PointSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=True)),
bbox_head=dict( bbox_head=dict(
type='VoteHead', type='VoteHead',
vote_moudule_cfg=dict( vote_moudule_cfg=dict(
...@@ -27,6 +31,7 @@ model = dict( ...@@ -27,6 +31,7 @@ model = dict(
reduction='none', reduction='none',
loss_dst_weight=10.0)), loss_dst_weight=10.0)),
vote_aggregation_cfg=dict( vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=256, num_point=256,
radius=0.3, radius=0.3,
num_sample=16, num_sample=16,
......
...@@ -16,4 +16,4 @@ We implement H3DNet and provide the result and checkpoints on ScanNet datasets. ...@@ -16,4 +16,4 @@ We implement H3DNet and provide the result and checkpoints on ScanNet datasets.
### ScanNet ### ScanNet
| Backbone | Lr schd | Mem (GB) | Inf time (fps) | AP@0.25 |AP@0.5| Download | | Backbone | Lr schd | Mem (GB) | Inf time (fps) | AP@0.25 |AP@0.5| Download |
| :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | | :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: |
| [MultiBackbone](./h3dnet_scannet-3d-18class.py) | 3x |7.9||66.43|48.01|[model](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/votenet/votenet_8x8_scannet-3d-18class/votenet_8x8_scannet-3d-18class_20200620_230238-2cea9c3a.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/votenet/votenet_8x8_scannet-3d-18class/votenet_8x8_scannet-3d-18class_20200620_230238.log.json)| | [MultiBackbone](./h3dnet_scannet-3d-18class.py) | 3x |7.9||66.43|48.01||
...@@ -58,7 +58,11 @@ model = dict( ...@@ -58,7 +58,11 @@ model = dict(
(128, 128, 256)), # Out channels of each mlp in SA module (128, 128, 256)), # Out channels of each mlp in SA module
fp_channels=((256, 256), (256, 256)), # Out channels of each mlp in FP module fp_channels=((256, 256), (256, 256)), # Out channels of each mlp in FP module
norm_cfg=dict(type='BN2d'), # Config of normalization layer norm_cfg=dict(type='BN2d'), # Config of normalization layer
pool_mod='max'), # Pool method ('max' or 'avg') for SA modules sa_cfg=dict( # Config of point set abstraction (SA) module
type='PointSAModule', # type of SA module
pool_mod='max', # Pool method ('max' or 'avg') for SA modules
use_xyz=True, # Whether to use xyz as features during feature gathering
normalize_xyz=True)), # Whether to use normalized xyz as feature during feature gathering
bbox_head=dict( bbox_head=dict(
type='VoteHead', # The type of bbox head, refer to mmdet3d.models.dense_heads for more details type='VoteHead', # The type of bbox head, refer to mmdet3d.models.dense_heads for more details
num_classes=18, # Number of classes for classification num_classes=18, # Number of classes for classification
...@@ -99,6 +103,7 @@ model = dict( ...@@ -99,6 +103,7 @@ model = dict(
reduction='none', # Specifies the reduction to apply to the output reduction='none', # Specifies the reduction to apply to the output
loss_dst_weight=10.0)), # Destination loss weight of the voting branch loss_dst_weight=10.0)), # Destination loss weight of the voting branch
vote_aggregation_cfg=dict( # Config to vote aggregation branch vote_aggregation_cfg=dict( # Config to vote aggregation branch
type='PointSAModule', # type of vote aggregation module
num_point=256, # Number of points for the set abstraction layer in vote aggregation branch num_point=256, # Number of points for the set abstraction layer in vote aggregation branch
radius=0.3, # Radius for the set abstraction layer in vote aggregation branch radius=0.3, # Radius for the set abstraction layer in vote aggregation branch
num_sample=16, # Number of samples for the set abstraction layer in vote aggregation branch num_sample=16, # Number of samples for the set abstraction layer in vote aggregation branch
......
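Since the SA behaviour is now driven entirely by sa_cfg, a user config that inherits the documented base config (via mmcv's _base_ mechanism) only needs to override those keys. The snippet below is an illustrative sketch, not part of this commit, and the _base_ path is hypothetical:

# Hypothetical user config inheriting the ScanNet base config documented above.
_base_ = ['./votenet_8x8_scannet-3d-18class.py']  # path is illustrative

model = dict(
    backbone=dict(
        sa_cfg=dict(
            type='PointSAModule',
            pool_mod='avg',          # switch pooling from 'max' to 'avg'
            use_xyz=True,
            normalize_xyz=False)))   # keep raw xyz offsets in the grouped features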
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
from .multi_backbone import MultiBackbone from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG
from .pointnet2_sa_ssg import PointNet2SASSG from .pointnet2_sa_ssg import PointNet2SASSG
from .second import SECOND from .second import SECOND
__all__ = [ __all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'PointNet2SASSG', 'MultiBackbone' 'SECOND', 'PointNet2SASSG', 'PointNet2SAMSG', 'MultiBackbone'
] ]
from abc import ABCMeta
from mmcv.runner import load_checkpoint
from torch import nn as nn
class BasePointNet(nn.Module, metaclass=ABCMeta):
"""Base class for PointNet."""
def __init__(self):
super(BasePointNet, self).__init__()
def init_weights(self, pretrained=None):
"""Initialize the weights of PointNet backbone."""
# Do not initialize the conv layers
# to follow the original implementation
if isinstance(pretrained, str):
from mmdet3d.utils import get_root_logger
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
@staticmethod
def _split_point_feats(points):
"""Split coordinates and features of input points.
Args:
points (torch.Tensor): Point coordinates with features,
with shape (B, N, 3 + input_feature_dim).
Returns:
torch.Tensor: Coordinates of input points.
torch.Tensor: Features of input points.
"""
xyz = points[..., 0:3].contiguous()
if points.size(-1) > 3:
features = points[..., 3:].transpose(1, 2).contiguous()
else:
features = None
return xyz, features
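For reference, a standalone sketch of what _split_point_feats does to a (B, N, 3 + C) tensor; the shapes are illustrative:

import torch

points = torch.rand(2, 1024, 6)                           # 2 clouds, xyz plus 3 extra channels
xyz = points[..., 0:3].contiguous()                        # (2, 1024, 3) coordinates
features = points[..., 3:].transpose(1, 2).contiguous()   # (2, 3, 1024), channels-first for the conv ops
assert xyz.shape == (2, 1024, 3) and features.shape == (2, 3, 1024)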
...@@ -74,6 +74,7 @@ class MultiBackbone(nn.Module): ...@@ -74,6 +74,7 @@ class MultiBackbone(nn.Module):
conv_cfg=conv_cfg, conv_cfg=conv_cfg,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
act_cfg=act_cfg, act_cfg=act_cfg,
bias=True,
inplace=True)) inplace=True))
def init_weights(self, pretrained=None): def init_weights(self, pretrained=None):
......
import torch
from mmcv.cnn import ConvModule
from torch import nn as nn
from mmdet3d.ops import build_sa_module
from mmdet.models import BACKBONES
from .base_pointnet import BasePointNet
@BACKBONES.register_module()
class PointNet2SAMSG(BasePointNet):
"""PointNet2 with Multi-scale grouping.
Args:
in_channels (int): Input channels of point cloud.
num_points (tuple[int]): The number of points which each SA
module samples.
radii (tuple[float]): Sampling radii of each SA module.
num_samples (tuple[int]): The number of samples for ball
query in each SA module.
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
aggregation_channels (tuple[int]): Out channels of aggregation
multi-scale grouping features.
fps_mods (tuple): FPS mode ('D-FPS', 'F-FPS' or 'FS') used by each SA
module; a nested tuple applies several modes within one module.
fps_sample_range_lists (tuple): Point ranges over which each FPS mode
in the corresponding SA module samples (-1 means all remaining points).
out_indices (Sequence[int]): Output from which stages.
norm_cfg (dict): Config of normalization layer.
sa_cfg (dict): Config of set abstraction module, which may contain
the following keys and values:
- pool_mod (str): Pool method ('max' or 'avg') for SA modules.
- use_xyz (bool): Whether to use xyz as a part of features.
- normalize_xyz (bool): Whether to normalize xyz with radii in
each SA module.
"""
def __init__(self,
in_channels,
num_points=(2048, 1024, 512, 256),
radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
((64, 64, 128), (64, 64, 128), (64, 96, 128)),
((128, 128, 256), (128, 192, 256), (128, 256,
256))),
aggregation_channels=(64, 128, 256),
fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
fps_sample_range_lists=((-1), (-1), (512, -1)),
out_indices=(2, ),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=True)):
super().__init__()
self.num_sa = len(sa_channels)
self.out_indices = out_indices
assert max(out_indices) < self.num_sa
assert len(num_points) == len(radii) == len(num_samples) == len(
sa_channels) == len(aggregation_channels)
self.SA_modules = nn.ModuleList()
self.aggregation_mlps = nn.ModuleList()
sa_in_channel = in_channels - 3 # number of channels without xyz
skip_channel_list = [sa_in_channel]
for sa_index in range(self.num_sa):
cur_sa_mlps = list(sa_channels[sa_index])
sa_out_channel = 0
for radius_index in range(len(radii[sa_index])):
cur_sa_mlps[radius_index] = [sa_in_channel] + list(
cur_sa_mlps[radius_index])
sa_out_channel += cur_sa_mlps[radius_index][-1]
if isinstance(fps_mods[sa_index], tuple):
cur_fps_mod = list(fps_mods[sa_index])
else:
cur_fps_mod = list([fps_mods[sa_index]])
if isinstance(fps_sample_range_lists[sa_index], tuple):
cur_fps_sample_range_list = list(
fps_sample_range_lists[sa_index])
else:
cur_fps_sample_range_list = list(
[fps_sample_range_lists[sa_index]])
self.SA_modules.append(
build_sa_module(
num_point=num_points[sa_index],
radii=radii[sa_index],
sample_nums=num_samples[sa_index],
mlp_channels=cur_sa_mlps,
fps_mod=cur_fps_mod,
fps_sample_range_list=cur_fps_sample_range_list,
norm_cfg=norm_cfg,
cfg=sa_cfg,
bias=True))
skip_channel_list.append(sa_out_channel)
self.aggregation_mlps.append(
ConvModule(
sa_out_channel,
aggregation_channels[sa_index],
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
kernel_size=1,
bias=True))
sa_in_channel = aggregation_channels[sa_index]
def forward(self, points):
"""Forward pass.
Args:
points (torch.Tensor): point coordinates with features,
with shape (B, N, 3 + input_feature_dim).
Returns:
dict[str, torch.Tensor]: Outputs of the last SA module.
- sa_xyz (torch.Tensor): The coordinates of sa features.
- sa_features (torch.Tensor): The features from the
last Set Aggregation Layers.
- sa_indices (torch.Tensor): Indices of the input points.
"""
xyz, features = self._split_point_feats(points)
batch, num_points = xyz.shape[:2]
indices = xyz.new_tensor(range(num_points)).unsqueeze(0).repeat(
batch, 1).long()
sa_xyz = [xyz]
sa_features = [features]
sa_indices = [indices]
out_sa_xyz = []
out_sa_features = []
out_sa_indices = []
for i in range(self.num_sa):
cur_xyz, cur_features, cur_indices = self.SA_modules[i](
sa_xyz[i], sa_features[i])
cur_features = self.aggregation_mlps[i](cur_features)
sa_xyz.append(cur_xyz)
sa_features.append(cur_features)
sa_indices.append(
torch.gather(sa_indices[-1], 1, cur_indices.long()))
if i in self.out_indices:
out_sa_xyz.append(sa_xyz[-1])
out_sa_features.append(sa_features[-1])
out_sa_indices.append(sa_indices[-1])
return dict(
sa_xyz=out_sa_xyz,
sa_features=out_sa_features,
sa_indices=out_sa_indices)
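The repeated torch.gather in the loop above composes the per-stage sampling indices so that sa_indices always refers back to the original input points; a toy sketch of that bookkeeping with illustrative tensors:

import torch

orig = torch.arange(8).unsqueeze(0)        # (1, 8): indices of the original points
stage1 = torch.tensor([[6, 2, 5]])         # SA stage 1 keeps points at positions 6, 2, 5
idx1 = torch.gather(orig, 1, stage1)       # -> tensor([[6, 2, 5]])
stage2 = torch.tensor([[1, 0]])            # SA stage 2 keeps positions 1 and 0 of those three
idx2 = torch.gather(idx1, 1, stage2)       # -> tensor([[2, 6]]): original point ids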
import torch import torch
from mmcv.runner import load_checkpoint
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PointFPModule, PointSAModule from mmdet3d.ops import PointFPModule, build_sa_module
from mmdet.models import BACKBONES from mmdet.models import BACKBONES
from .base_pointnet import BasePointNet
@BACKBONES.register_module() @BACKBONES.register_module()
class PointNet2SASSG(nn.Module): class PointNet2SASSG(BasePointNet):
"""PointNet2 with Single-scale grouping. """PointNet2 with Single-scale grouping.
Args: Args:
...@@ -20,10 +20,13 @@ class PointNet2SASSG(nn.Module): ...@@ -20,10 +20,13 @@ class PointNet2SASSG(nn.Module):
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module. sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
fp_channels (tuple[tuple[int]]): Out channels of each mlp in FP module. fp_channels (tuple[tuple[int]]): Out channels of each mlp in FP module.
norm_cfg (dict): Config of normalization layer. norm_cfg (dict): Config of normalization layer.
pool_mod (str): Pool method ('max' or 'avg') for SA modules. sa_cfg (dict): Config of set abstraction module, which may contain
use_xyz (bool): Whether to use xyz as a part of features. the following keys and values:
normalize_xyz (bool): Whether to normalize xyz with radii in
each SA module. - pool_mod (str): Pool method ('max' or 'avg') for SA modules.
- use_xyz (bool): Whether to use xyz as a part of features.
- normalize_xyz (bool): Whether to normalize xyz with radii in
each SA module.
""" """
def __init__(self, def __init__(self,
...@@ -35,18 +38,18 @@ class PointNet2SASSG(nn.Module): ...@@ -35,18 +38,18 @@ class PointNet2SASSG(nn.Module):
(128, 128, 256)), (128, 128, 256)),
fp_channels=((256, 256), (256, 256)), fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'), norm_cfg=dict(type='BN2d'),
pool_mod='max', sa_cfg=dict(
use_xyz=True, type='PointSAModule',
normalize_xyz=True): pool_mod='max',
use_xyz=True,
normalize_xyz=True)):
super().__init__() super().__init__()
self.num_sa = len(sa_channels) self.num_sa = len(sa_channels)
self.num_fp = len(fp_channels) self.num_fp = len(fp_channels)
assert len(num_points) == len(radius) == len(num_samples) == len( assert len(num_points) == len(radius) == len(num_samples) == len(
sa_channels) sa_channels)
assert len(sa_channels) >= len(fp_channels) assert len(sa_channels) >= len(fp_channels)
assert pool_mod in ['max', 'avg']
self.SA_modules = nn.ModuleList() self.SA_modules = nn.ModuleList()
sa_in_channel = in_channels - 3 # number of channels without xyz sa_in_channel = in_channels - 3 # number of channels without xyz
...@@ -58,15 +61,13 @@ class PointNet2SASSG(nn.Module): ...@@ -58,15 +61,13 @@ class PointNet2SASSG(nn.Module):
sa_out_channel = cur_sa_mlps[-1] sa_out_channel = cur_sa_mlps[-1]
self.SA_modules.append( self.SA_modules.append(
PointSAModule( build_sa_module(
num_point=num_points[sa_index], num_point=num_points[sa_index],
radius=radius[sa_index], radius=radius[sa_index],
num_sample=num_samples[sa_index], num_sample=num_samples[sa_index],
mlp_channels=cur_sa_mlps, mlp_channels=cur_sa_mlps,
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
use_xyz=use_xyz, cfg=sa_cfg))
pool_mod=pool_mod,
normalize_xyz=normalize_xyz))
skip_channel_list.append(sa_out_channel) skip_channel_list.append(sa_out_channel)
sa_in_channel = sa_out_channel sa_in_channel = sa_out_channel
...@@ -82,35 +83,6 @@ class PointNet2SASSG(nn.Module): ...@@ -82,35 +83,6 @@ class PointNet2SASSG(nn.Module):
fp_source_channel = cur_fp_mlps[-1] fp_source_channel = cur_fp_mlps[-1]
fp_target_channel = skip_channel_list.pop() fp_target_channel = skip_channel_list.pop()
def init_weights(self, pretrained=None):
"""Initialize the weights of PointNet backbone."""
# Do not initialize the conv layers
# to follow the original implementation
if isinstance(pretrained, str):
from mmdet3d.utils import get_root_logger
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
@staticmethod
def _split_point_feats(points):
"""Split coordinates and features of input points.
Args:
points (torch.Tensor): Point coordinates with features,
with shape (B, N, 3 + input_feature_dim).
Returns:
torch.Tensor: Coordinates of input points.
torch.Tensor: Features of input points.
"""
xyz = points[..., 0:3].contiguous()
if points.size(-1) > 3:
features = points[..., 3:].transpose(1, 2).contiguous()
else:
features = None
return xyz, features
def forward(self, points): def forward(self, points):
"""Forward pass. """Forward pass.
......
...@@ -8,7 +8,7 @@ from mmdet3d.core.post_processing import aligned_3d_nms ...@@ -8,7 +8,7 @@ from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance from mmdet3d.models.losses import chamfer_distance
from mmdet3d.models.model_utils import VoteModule from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import PointSAModule, furthest_point_sample from mmdet3d.ops import build_sa_module, furthest_point_sample
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS from mmdet.models import HEADS
...@@ -78,7 +78,7 @@ class VoteHead(nn.Module): ...@@ -78,7 +78,7 @@ class VoteHead(nn.Module):
self.num_dir_bins = self.bbox_coder.num_dir_bins self.num_dir_bins = self.bbox_coder.num_dir_bins
self.vote_module = VoteModule(**vote_moudule_cfg) self.vote_module = VoteModule(**vote_moudule_cfg)
self.vote_aggregation = PointSAModule(**vote_aggregation_cfg) self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
prev_channel = vote_aggregation_cfg['mlp_channels'][-1] prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
conv_pred_list = list() conv_pred_list = list()
......
...@@ -7,7 +7,7 @@ from mmdet3d.core.bbox import DepthInstance3DBoxes ...@@ -7,7 +7,7 @@ from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.post_processing import aligned_3d_nms from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import PointSAModule from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS from mmdet.models import HEADS
...@@ -60,7 +60,6 @@ class H3DBboxHead(nn.Module): ...@@ -60,7 +60,6 @@ class H3DBboxHead(nn.Module):
bbox_coder, bbox_coder,
train_cfg=None, train_cfg=None,
test_cfg=None, test_cfg=None,
proposal_module_cfg=None,
gt_per_seed=1, gt_per_seed=1,
num_proposal=256, num_proposal=256,
feat_channels=(128, 128), feat_channels=(128, 128),
...@@ -114,9 +113,9 @@ class H3DBboxHead(nn.Module): ...@@ -114,9 +113,9 @@ class H3DBboxHead(nn.Module):
line_matching_cfg['mlp_channels'][-1] line_matching_cfg['mlp_channels'][-1]
# surface center matching # surface center matching
self.surface_center_matcher = PointSAModule(**suface_matching_cfg) self.surface_center_matcher = build_sa_module(suface_matching_cfg)
# line center matching # line center matching
self.line_center_matcher = PointSAModule(**line_matching_cfg) self.line_center_matcher = build_sa_module(line_matching_cfg)
# Compute the matching scores # Compute the matching scores
matching_feat_dims = suface_matching_cfg['mlp_channels'][-1] matching_feat_dims = suface_matching_cfg['mlp_channels'][-1]
......
...@@ -131,7 +131,7 @@ class H3DRoIHead(Base3DRoIHead): ...@@ -131,7 +131,7 @@ class H3DRoIHead(Base3DRoIHead):
Returns: Returns:
dict: Bbox results of one frame. dict: Bbox results of one frame.
""" """
sample_mod = self.test.sample_mod sample_mod = self.test_cfg.sample_mod
assert sample_mod in ['vote', 'seed', 'random'] assert sample_mod in ['vote', 'seed', 'random']
result_z = self.primitive_z(feats_dict, sample_mod) result_z = self.primitive_z(feats_dict, sample_mod)
......
...@@ -5,7 +5,7 @@ from torch.nn import functional as F ...@@ -5,7 +5,7 @@ from torch.nn import functional as F
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import build_loss
from mmdet3d.models.model_utils import VoteModule from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import PointSAModule, furthest_point_sample from mmdet3d.ops import build_sa_module, furthest_point_sample
from mmdet.core import multi_apply from mmdet.core import multi_apply
from mmdet.models import HEADS from mmdet.models import HEADS
...@@ -88,7 +88,7 @@ class PrimitiveHead(nn.Module): ...@@ -88,7 +88,7 @@ class PrimitiveHead(nn.Module):
vote_moudule_cfg['conv_channels'][-1] // 2, 2, 1) vote_moudule_cfg['conv_channels'][-1] // 2, 2, 1)
self.vote_module = VoteModule(**vote_moudule_cfg) self.vote_module = VoteModule(**vote_moudule_cfg)
self.vote_aggregation = PointSAModule(**vote_aggregation_cfg) self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
prev_channel = vote_aggregation_cfg['mlp_channels'][-1] prev_channel = vote_aggregation_cfg['mlp_channels'][-1]
conv_pred_list = list() conv_pred_list = list()
......
...@@ -9,7 +9,8 @@ from .group_points import (GroupAll, QueryAndGroup, group_points, ...@@ -9,7 +9,8 @@ from .group_points import (GroupAll, QueryAndGroup, group_points,
grouping_operation) grouping_operation)
from .interpolate import three_interpolate, three_nn from .interpolate import three_interpolate, three_nn
from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
from .pointnet_modules import PointFPModule, PointSAModule, PointSAModuleMSG from .pointnet_modules import (PointFPModule, PointSAModule, PointSAModuleMSG,
build_sa_module)
from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_batch, from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_batch,
points_in_boxes_cpu, points_in_boxes_gpu) points_in_boxes_cpu, points_in_boxes_gpu)
from .sparse_block import (SparseBasicBlock, SparseBottleneck, from .sparse_block import (SparseBasicBlock, SparseBottleneck,
...@@ -29,5 +30,5 @@ __all__ = [ ...@@ -29,5 +30,5 @@ __all__ = [
'gather_points', 'grouping_operation', 'group_points', 'GroupAll', 'gather_points', 'grouping_operation', 'group_points', 'GroupAll',
'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule', 'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule',
'points_in_boxes_batch', 'get_compiler_version', 'points_in_boxes_batch', 'get_compiler_version',
'get_compiling_cuda_version', 'Points_Sampler' 'get_compiling_cuda_version', 'Points_Sampler', 'build_sa_module'
] ]
from .builder import build_sa_module
from .point_fp_module import PointFPModule from .point_fp_module import PointFPModule
from .point_sa_module import PointSAModule, PointSAModuleMSG from .point_sa_module import PointSAModule, PointSAModuleMSG
__all__ = ['PointSAModuleMSG', 'PointSAModule', 'PointFPModule'] __all__ = [
'build_sa_module', 'PointSAModuleMSG', 'PointSAModule', 'PointFPModule'
]
from .registry import SA_MODULES
def build_sa_module(cfg, *args, **kwargs):
"""Build PointNet2 set abstraction (SA) module.
Args:
cfg (None or dict): The SA module config, which should contain:
- type (str): Module type.
- module args: Args needed to instantiate an SA module.
args (argument list): Arguments passed to the `__init__`
method of the corresponding module.
kwargs (keyword arguments): Keyword arguments passed to the `__init__`
method of the corresponding SA module.
Returns:
nn.Module: Created SA module.
"""
if cfg is None:
cfg_ = dict(type='PointSAModule')
else:
if not isinstance(cfg, dict):
raise TypeError('cfg must be a dict')
if 'type' not in cfg:
raise KeyError('the cfg dict must contain the key "type"')
cfg_ = cfg.copy()
module_type = cfg_.pop('type')
if module_type not in SA_MODULES:
raise KeyError(f'Unrecognized module type {module_type}')
else:
sa_module = SA_MODULES.get(module_type)
module = sa_module(*args, **kwargs, **cfg_)
return module
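A brief usage sketch of the new builder; the config values below are illustrative and merely follow the style of the configs touched in this commit:

from mmdet3d.ops import build_sa_module

# Build a single-scale SA module; everything except 'type' is forwarded to __init__.
sa_module = build_sa_module(
    dict(
        type='PointSAModule',
        num_point=256,
        radius=0.3,
        num_sample=16,
        mlp_channels=[256, 128, 128, 128],
        use_xyz=True,
        normalize_xyz=True))

# With cfg=None the builder falls back to a plain PointSAModule,
# so all arguments must be supplied as kwargs instead.
default_module = build_sa_module(
    None, num_point=256, radius=0.3, num_sample=16, mlp_channels=[256, 128, 128])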
...@@ -5,8 +5,10 @@ from torch.nn import functional as F ...@@ -5,8 +5,10 @@ from torch.nn import functional as F
from typing import List from typing import List
from mmdet3d.ops import GroupAll, Points_Sampler, QueryAndGroup, gather_points from mmdet3d.ops import GroupAll, Points_Sampler, QueryAndGroup, gather_points
from .registry import SA_MODULES
@SA_MODULES.register_module()
class PointSAModuleMSG(nn.Module): class PointSAModuleMSG(nn.Module):
"""Point set abstraction module with multi-scale grouping used in """Point set abstraction module with multi-scale grouping used in
Pointnets. Pointnets.
...@@ -167,6 +169,7 @@ class PointSAModuleMSG(nn.Module): ...@@ -167,6 +169,7 @@ class PointSAModuleMSG(nn.Module):
return new_xyz, torch.cat(new_features_list, dim=1), indices return new_xyz, torch.cat(new_features_list, dim=1), indices
@SA_MODULES.register_module()
class PointSAModule(PointSAModuleMSG): class PointSAModule(PointSAModuleMSG):
"""Point set abstraction module used in Pointnets. """Point set abstraction module used in Pointnets.
......
from mmcv.utils import Registry
SA_MODULES = Registry('point_sa_module')
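With this registry, third-party SA variants can be registered and then selected through sa_cfg or vote_aggregation_cfg; the class below is purely hypothetical and only illustrates the registration pattern:

from torch import nn as nn

from mmdet3d.ops.pointnet_modules.registry import SA_MODULES


@SA_MODULES.register_module()
class DummySAModule(nn.Module):  # hypothetical, not part of this commit
    """Toy SA module showing the interface expected by build_sa_module."""

    def __init__(self, num_point, mlp_channels, **kwargs):
        super().__init__()
        self.num_point = num_point
        self.mlp_channels = mlp_channels

    def forward(self, points_xyz, features=None):
        raise NotImplementedError


# After registration it can be built from a config, e.g.
# build_sa_module(dict(type='DummySAModule', num_point=256, mlp_channels=[128, 128]))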
...@@ -61,8 +61,7 @@ def test_multi_backbone(): ...@@ -61,8 +61,7 @@ def test_multi_backbone():
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256), sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)), (128, 128, 256)),
fp_channels=((256, 256), (256, 256)), fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'), norm_cfg=dict(type='BN2d')),
pool_mod='max'),
dict( dict(
type='PointNet2SASSG', type='PointNet2SASSG',
in_channels=4, in_channels=4,
...@@ -72,8 +71,7 @@ def test_multi_backbone(): ...@@ -72,8 +71,7 @@ def test_multi_backbone():
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256), sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)), (128, 128, 256)),
fp_channels=((256, 256), (256, 256)), fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'), norm_cfg=dict(type='BN2d')),
pool_mod='max'),
dict( dict(
type='PointNet2SASSG', type='PointNet2SASSG',
in_channels=4, in_channels=4,
...@@ -83,8 +81,7 @@ def test_multi_backbone(): ...@@ -83,8 +81,7 @@ def test_multi_backbone():
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256), sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)), (128, 128, 256)),
fp_channels=((256, 256), (256, 256)), fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'), norm_cfg=dict(type='BN2d')),
pool_mod='max'),
dict( dict(
type='PointNet2SASSG', type='PointNet2SASSG',
in_channels=4, in_channels=4,
...@@ -94,8 +91,7 @@ def test_multi_backbone(): ...@@ -94,8 +91,7 @@ def test_multi_backbone():
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256), sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)), (128, 128, 256)),
fp_channels=((256, 256), (256, 256)), fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'), norm_cfg=dict(type='BN2d'))
pool_mod='max')
]) ])
self = build_backbone(cfg_list) self = build_backbone(cfg_list)
...@@ -127,8 +123,7 @@ def test_multi_backbone(): ...@@ -127,8 +123,7 @@ def test_multi_backbone():
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256), sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)), (128, 128, 256)),
fp_channels=((256, 256), (256, 256)), fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'), norm_cfg=dict(type='BN2d')))
pool_mod='max'))
self = build_backbone(cfg_dict) self = build_backbone(cfg_dict)
self.cuda() self.cuda()
...@@ -156,3 +151,68 @@ def test_multi_backbone(): ...@@ -156,3 +151,68 @@ def test_multi_backbone():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
cfg_dict['backbones'] = 'PointNet2SASSG' cfg_dict['backbones'] = 'PointNet2SASSG'
build_backbone(cfg_dict) build_backbone(cfg_dict)
def test_pointnet2_sa_msg():
if not torch.cuda.is_available():
pytest.skip()
cfg = dict(
type='PointNet2SAMSG',
in_channels=4,
num_points=(256, 64, (32, 32)),
radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
num_samples=((8, 8, 16), (8, 8, 16), (8, 8, 8)),
sa_channels=(((8, 8, 16), (8, 8, 16),
(8, 8, 16)), ((16, 16, 32), (16, 16, 32), (16, 24, 32)),
((32, 32, 64), (32, 24, 64), (32, 64, 64))),
aggregation_channels=(16, 32, 64),
fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
fps_sample_range_lists=((-1), (-1), (64, -1)),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=False))
self = build_backbone(cfg)
self.cuda()
assert self.SA_modules[0].mlps[0].layer0.conv.in_channels == 4
assert self.SA_modules[0].mlps[0].layer0.conv.out_channels == 8
assert self.SA_modules[0].mlps[1].layer1.conv.out_channels == 8
assert self.SA_modules[2].mlps[2].layer2.conv.out_channels == 64
xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', dtype=np.float32)
xyz = torch.from_numpy(xyz).view(1, -1, 6).cuda() # (B, N, 6)
# test forward
ret_dict = self(xyz[:, :, :4])
sa_xyz = ret_dict['sa_xyz'][-1]
sa_features = ret_dict['sa_features'][-1]
sa_indices = ret_dict['sa_indices'][-1]
assert sa_xyz.shape == torch.Size([1, 64, 3])
assert sa_features.shape == torch.Size([1, 64, 64])
assert sa_indices.shape == torch.Size([1, 64])
# out_indices should be smaller than the length of SA Modules.
with pytest.raises(AssertionError):
build_backbone(
dict(
type='PointNet2SAMSG',
in_channels=4,
num_points=(256, 64, (32, 32)),
radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
num_samples=((8, 8, 16), (8, 8, 16), (8, 8, 8)),
sa_channels=(((8, 8, 16), (8, 8, 16), (8, 8, 16)),
((16, 16, 32), (16, 16, 32), (16, 24, 32)),
((32, 32, 64), (32, 24, 64), (32, 64, 64))),
aggregation_channels=(16, 32, 64),
fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
fps_sample_range_lists=((-1), (-1), (64, -1)),
out_indices=(2, 3),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=False)))