"...controlnet_pytorch.git" did not exist on "e2696ece051bc73e8c85fdf4d941704ee38817fd"
Unverified Commit 84efe00e authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

[Feature] Add pointnet2 msg and refactor pointnets (#82)



* add op fps with distance

* add op fps with distance

* modify F-DFS unittest

* modify sa module

* modify sa module

* SA Module support D-FPS and F-FPS

* modify sa module

* update points sa module

* modify point_sa_module

* modify point sa module

* reconstruct FPS

* reconstruct FPS

* modify docstring

* add pointnet2-sa-msg backbone

* modify pointnet2_sa_msg

* fix merge conflicts

* format tests/test_backbones.py

* [Refactor]: Add registry for PointNet2Modules

* modify h3dnet for base pointnet

* fix docstring tweaks

* fix bugs for config unittest
Co-authored-by: default avatarZwwWayne <wayne.zw@outlook.com>
parent dde4b02c
...@@ -61,6 +61,8 @@ def test_config_build_detector(): ...@@ -61,6 +61,8 @@ def test_config_build_detector():
head_config = config_mod.model['roi_head'] head_config = config_mod.model['roi_head']
if head_config.type == 'PartAggregationROIHead': if head_config.type == 'PartAggregationROIHead':
check_parta2_roi_head(head_config, detector.roi_head) check_parta2_roi_head(head_config, detector.roi_head)
elif head_config.type == 'H3DRoIHead':
check_h3d_roi_head(head_config, detector.roi_head)
else: else:
_check_roi_head(head_config, detector.roi_head) _check_roi_head(head_config, detector.roi_head)
# else: # else:
...@@ -235,3 +237,41 @@ def _check_parta2_bbox_head(bbox_cfg, bbox_head): ...@@ -235,3 +237,41 @@ def _check_parta2_bbox_head(bbox_cfg, bbox_head):
assert bbox_cfg.seg_in_channels == bbox_head.seg_conv[0][0].in_channels assert bbox_cfg.seg_in_channels == bbox_head.seg_conv[0][0].in_channels
assert bbox_cfg.part_in_channels == bbox_head.part_conv[0][ assert bbox_cfg.part_in_channels == bbox_head.part_conv[0][
0].in_channels 0].in_channels
def check_h3d_roi_head(config, head):
assert config['type'] == head.__class__.__name__
# check seg_roi_extractor
primitive_z_cfg = config.primitive_list[0]
primitive_z_extractor = head.primitive_z
_check_primitive_extractor(primitive_z_cfg, primitive_z_extractor)
primitive_xy_cfg = config.primitive_list[1]
primitive_xy_extractor = head.primitive_xy
_check_primitive_extractor(primitive_xy_cfg, primitive_xy_extractor)
primitive_line_cfg = config.primitive_list[2]
primitive_line_extractor = head.primitive_line
_check_primitive_extractor(primitive_line_cfg, primitive_line_extractor)
# check bbox head infos
bbox_cfg = config.bbox_head
bbox_head = head.bbox_head
_check_h3d_bbox_head(bbox_cfg, bbox_head)
def _check_primitive_extractor(config, primitive_extractor):
assert config['type'] == primitive_extractor.__class__.__name__
assert (config.num_dims == primitive_extractor.num_dims)
assert (config.num_classes == primitive_extractor.num_classes)
def _check_h3d_bbox_head(bbox_cfg, bbox_head):
assert bbox_cfg['type'] == bbox_head.__class__.__name__
assert bbox_cfg.num_proposal * \
6 == bbox_head.surface_center_matcher.num_point[0]
assert bbox_cfg.num_proposal * \
12 == bbox_head.line_center_matcher.num_point[0]
assert bbox_cfg.suface_matching_cfg.mlp_channels[-1] * \
18 == bbox_head.bbox_pred[0].in_channels
...@@ -483,6 +483,7 @@ def test_primitive_head(): ...@@ -483,6 +483,7 @@ def test_primitive_head():
reduction='none', reduction='none',
loss_dst_weight=10.0)), loss_dst_weight=10.0)),
vote_aggregation_cfg=dict( vote_aggregation_cfg=dict(
type='PointSAModule',
num_point=64, num_point=64,
radius=0.3, radius=0.3,
num_sample=16, num_sample=16,
...@@ -576,7 +577,7 @@ def test_h3d_head(): ...@@ -576,7 +577,7 @@ def test_h3d_head():
pytest.skip('test requires GPU and torch+cuda') pytest.skip('test requires GPU and torch+cuda')
_setup_seed(0) _setup_seed(0)
h3d_head_cfg = _get_roi_head_cfg('h3dnet/h3dnet_8x8_scannet-3d-18class.py') h3d_head_cfg = _get_roi_head_cfg('h3dnet/h3dnet_8x3_scannet-3d-18class.py')
self = build_head(h3d_head_cfg).cuda() self = build_head(h3d_head_cfg).cuda()
# prepare roi outputs # prepare roi outputs
...@@ -585,7 +586,7 @@ def test_h3d_head(): ...@@ -585,7 +586,7 @@ def test_h3d_head():
fp_indices = [torch.randint(0, 128, [1, 1024]).cuda()] fp_indices = [torch.randint(0, 128, [1, 1024]).cuda()]
aggregated_points = torch.rand([1, 256, 3], dtype=torch.float32).cuda() aggregated_points = torch.rand([1, 256, 3], dtype=torch.float32).cuda()
aggregated_features = torch.rand([1, 128, 256], dtype=torch.float32).cuda() aggregated_features = torch.rand([1, 128, 256], dtype=torch.float32).cuda()
rpn_proposals = torch.cat([ proposal_list = torch.cat([
torch.rand([1, 256, 3], dtype=torch.float32).cuda() * 4 - 2, torch.rand([1, 256, 3], dtype=torch.float32).cuda() * 4 - 2,
torch.rand([1, 256, 3], dtype=torch.float32).cuda() * 4, torch.rand([1, 256, 3], dtype=torch.float32).cuda() * 4,
torch.zeros([1, 256, 1]).cuda() torch.zeros([1, 256, 1]).cuda()
...@@ -599,7 +600,7 @@ def test_h3d_head(): ...@@ -599,7 +600,7 @@ def test_h3d_head():
aggregated_features=aggregated_features, aggregated_features=aggregated_features,
seed_points=fp_xyz[0], seed_points=fp_xyz[0],
seed_indices=fp_indices[0], seed_indices=fp_indices[0],
rpn_proposals=rpn_proposals) proposal_list=proposal_list)
# prepare gt label # prepare gt label
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
...@@ -643,7 +644,6 @@ def test_h3d_head(): ...@@ -643,7 +644,6 @@ def test_h3d_head():
# train forward # train forward
ret_dict = self.forward_train( ret_dict = self.forward_train(
input_dict, input_dict,
'vote',
points=points, points=points,
gt_bboxes_3d=gt_bboxes_3d, gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d, gt_labels_3d=gt_labels_3d,
...@@ -656,5 +656,5 @@ def test_h3d_head(): ...@@ -656,5 +656,5 @@ def test_h3d_head():
assert ret_dict['center_loss_z'] >= 0 assert ret_dict['center_loss_z'] >= 0
assert ret_dict['size_loss_z'] >= 0 assert ret_dict['size_loss_z'] >= 0
assert ret_dict['sem_loss_z'] >= 0 assert ret_dict['sem_loss_z'] >= 0
assert ret_dict['objectness_loss_opt'] >= 0 assert ret_dict['objectness_loss_optimized'] >= 0
assert ret_dict['primitive_sem_matching_loss'] >= 0 assert ret_dict['primitive_sem_matching_loss'] >= 0
...@@ -16,7 +16,7 @@ def test_sparse_encoder(): ...@@ -16,7 +16,7 @@ def test_sparse_encoder():
128)), 128)),
encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1,
1)), 1)),
option='basicblock') block_type='basicblock')
sparse_encoder = build_middle_encoder(sparse_encoder_cfg).cuda() sparse_encoder = build_middle_encoder(sparse_encoder_cfg).cuda()
voxel_features = torch.rand([207842, 5]).cuda() voxel_features = torch.rand([207842, 5]).cuda()
......
...@@ -137,16 +137,17 @@ def test_pointnet_sa_module_msg(): ...@@ -137,16 +137,17 @@ def test_pointnet_sa_module_msg():
def test_pointnet_sa_module(): def test_pointnet_sa_module():
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip() pytest.skip()
from mmdet3d.ops import PointSAModule from mmdet3d.ops import build_sa_module
sa_cfg = dict(
self = PointSAModule( type='PointSAModule',
num_point=16, num_point=16,
radius=0.2, radius=0.2,
num_sample=8, num_sample=8,
mlp_channels=[12, 32], mlp_channels=[12, 32],
norm_cfg=dict(type='BN2d'), norm_cfg=dict(type='BN2d'),
use_xyz=True, use_xyz=True,
pool_mod='max').cuda() pool_mod='max')
self = build_sa_module(sa_cfg).cuda()
assert self.mlps[0].layer0.conv.in_channels == 15 assert self.mlps[0].layer0.conv.in_channels == 15
assert self.mlps[0].layer0.conv.out_channels == 32 assert self.mlps[0].layer0.conv.out_channels == 32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment