"vscode:/vscode.git/clone" did not exist on "d53e3ace6e4fe533177b8290ef191ede3e6cc2f6"
Unverified Commit 84efe00e authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

[Feature] Add pointnet2 msg and refactor pointnets (#82)



* add op fps with distance

* add op fps with distance

* modify F-DFS unittest

* modify sa module

* modify sa module

* SA Module support D-FPS and F-FPS

* modify sa module

* update points sa module

* modify point_sa_module

* modify point sa module

* reconstruct FPS

* reconstruct FPS

* modify docstring

* add pointnet2-sa-msg backbone

* modify pointnet2_sa_msg

* fix merge conflicts

* format tests/test_backbones.py

* [Refactor]: Add registry for PointNet2Modules

* modify h3dnet for base pointnet

* fix docstring tweaks

* fix bugs for config unittest
Co-authored-by: default avatarZwwWayne <wayne.zw@outlook.com>
parent dde4b02c
......@@ -61,6 +61,8 @@ def test_config_build_detector():
head_config = config_mod.model['roi_head']
if head_config.type == 'PartAggregationROIHead':
check_parta2_roi_head(head_config, detector.roi_head)
elif head_config.type == 'H3DRoIHead':
check_h3d_roi_head(head_config, detector.roi_head)
else:
_check_roi_head(head_config, detector.roi_head)
# else:
......@@ -235,3 +237,41 @@ def _check_parta2_bbox_head(bbox_cfg, bbox_head):
assert bbox_cfg.seg_in_channels == bbox_head.seg_conv[0][0].in_channels
assert bbox_cfg.part_in_channels == bbox_head.part_conv[0][
0].in_channels
def check_h3d_roi_head(config, head):
assert config['type'] == head.__class__.__name__
# check seg_roi_extractor
primitive_z_cfg = config.primitive_list[0]
primitive_z_extractor = head.primitive_z
_check_primitive_extractor(primitive_z_cfg, primitive_z_extractor)
primitive_xy_cfg = config.primitive_list[1]
primitive_xy_extractor = head.primitive_xy
_check_primitive_extractor(primitive_xy_cfg, primitive_xy_extractor)
primitive_line_cfg = config.primitive_list[2]
primitive_line_extractor = head.primitive_line
_check_primitive_extractor(primitive_line_cfg, primitive_line_extractor)
# check bbox head infos
bbox_cfg = config.bbox_head
bbox_head = head.bbox_head
_check_h3d_bbox_head(bbox_cfg, bbox_head)
def _check_primitive_extractor(config, primitive_extractor):
assert config['type'] == primitive_extractor.__class__.__name__
assert (config.num_dims == primitive_extractor.num_dims)
assert (config.num_classes == primitive_extractor.num_classes)
def _check_h3d_bbox_head(bbox_cfg, bbox_head):
assert bbox_cfg['type'] == bbox_head.__class__.__name__
assert bbox_cfg.num_proposal * \
6 == bbox_head.surface_center_matcher.num_point[0]
assert bbox_cfg.num_proposal * \
12 == bbox_head.line_center_matcher.num_point[0]
assert bbox_cfg.suface_matching_cfg.mlp_channels[-1] * \
18 == bbox_head.bbox_pred[0].in_channels
......@@ -483,6 +483,7 @@ def test_primitive_head():
reduction='none',
loss_dst_weight=10.0)),
vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=64,
radius=0.3,
num_sample=16,
......@@ -576,7 +577,7 @@ def test_h3d_head():
pytest.skip('test requires GPU and torch+cuda')
_setup_seed(0)
h3d_head_cfg = _get_roi_head_cfg('h3dnet/h3dnet_8x8_scannet-3d-18class.py')
h3d_head_cfg = _get_roi_head_cfg('h3dnet/h3dnet_8x3_scannet-3d-18class.py')
self = build_head(h3d_head_cfg).cuda()
# prepare roi outputs
......@@ -585,7 +586,7 @@ def test_h3d_head():
fp_indices = [torch.randint(0, 128, [1, 1024]).cuda()]
aggregated_points = torch.rand([1, 256, 3], dtype=torch.float32).cuda()
aggregated_features = torch.rand([1, 128, 256], dtype=torch.float32).cuda()
rpn_proposals = torch.cat([
proposal_list = torch.cat([
torch.rand([1, 256, 3], dtype=torch.float32).cuda() * 4 - 2,
torch.rand([1, 256, 3], dtype=torch.float32).cuda() * 4,
torch.zeros([1, 256, 1]).cuda()
......@@ -599,7 +600,7 @@ def test_h3d_head():
aggregated_features=aggregated_features,
seed_points=fp_xyz[0],
seed_indices=fp_indices[0],
rpn_proposals=rpn_proposals)
proposal_list=proposal_list)
# prepare gt label
from mmdet3d.core.bbox import DepthInstance3DBoxes
......@@ -643,7 +644,6 @@ def test_h3d_head():
# train forward
ret_dict = self.forward_train(
input_dict,
'vote',
points=points,
gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d,
......@@ -656,5 +656,5 @@ def test_h3d_head():
assert ret_dict['center_loss_z'] >= 0
assert ret_dict['size_loss_z'] >= 0
assert ret_dict['sem_loss_z'] >= 0
assert ret_dict['objectness_loss_opt'] >= 0
assert ret_dict['objectness_loss_optimized'] >= 0
assert ret_dict['primitive_sem_matching_loss'] >= 0
......@@ -16,7 +16,7 @@ def test_sparse_encoder():
128)),
encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1,
1)),
option='basicblock')
block_type='basicblock')
sparse_encoder = build_middle_encoder(sparse_encoder_cfg).cuda()
voxel_features = torch.rand([207842, 5]).cuda()
......
......@@ -137,16 +137,17 @@ def test_pointnet_sa_module_msg():
def test_pointnet_sa_module():
if not torch.cuda.is_available():
pytest.skip()
from mmdet3d.ops import PointSAModule
self = PointSAModule(
from mmdet3d.ops import build_sa_module
sa_cfg = dict(
type='PointSAModule',
num_point=16,
radius=0.2,
num_sample=8,
mlp_channels=[12, 32],
norm_cfg=dict(type='BN2d'),
use_xyz=True,
pool_mod='max').cuda()
pool_mod='max')
self = build_sa_module(sa_cfg).cuda()
assert self.mlps[0].layer0.conv.in_channels == 15
assert self.mlps[0].layer0.conv.out_channels == 32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment