Unverified Commit 3e36ab83 authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

Feature h3dbackbone (#52)

* add multi backbone

* update multi backbone

* update multi backbone

* update multi backbone

* modify multi backbone

* modify multi_backbone params

* modify docstring

* update multi_backbone unittest

* modify docstring
parent 27d0001e
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_ssg import PointNet2SASSG
from .second import SECOND

__all__ = [
    'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
    'SECOND', 'PointNet2SASSG', 'MultiBackbone'
]
import copy
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint
from torch import nn as nn
from mmdet.models import BACKBONES, build_backbone
@BACKBONES.register_module()
class MultiBackbone(nn.Module):
    """MultiBackbone with different configs.

    Runs ``num_streams`` point-cloud backbones on the same input points,
    renames each backbone's output dict with a per-stream suffix and
    aggregates the final per-stream features with a shared MLP.

    Args:
        num_streams (int): The number of backbones.
        backbones (list or dict): A list of backbone configs, or a single
            config that is replicated ``num_streams`` times.
        aggregation_mlp_channels (list[int], optional): Specify the mlp
            layers for feature aggregation. The concatenated input channel
            count is prepended automatically. Defaults to None, which uses
            ``[C, C // 2, C // num_streams]`` where ``C`` is the sum of the
            backbones' last fp channels.
        conv_cfg (dict): Config dict of convolutional layers.
        norm_cfg (dict): Config dict of normalization layers.
        act_cfg (dict): Config dict of activation layers.
        suffixes (list): A list of suffixes to rename the return dict
            for each backbone.
    """

    def __init__(self,
                 num_streams,
                 backbones,
                 aggregation_mlp_channels=None,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
                 act_cfg=dict(type='ReLU'),
                 suffixes=('net0', 'net1'),
                 **kwargs):
        super().__init__()
        assert isinstance(backbones, (dict, list))
        if isinstance(backbones, dict):
            # Replicate a single config for every stream; deepcopy so the
            # streams do not share (and later mutate) the same dict.
            backbones = [copy.deepcopy(backbones) for _ in range(num_streams)]
        assert len(backbones) == num_streams
        assert len(suffixes) == num_streams

        self.backbone_list = nn.ModuleList()
        # Rename the ret_dict with different suffixes.
        self.suffixes = suffixes

        # Sum of the last fp-layer channels of every stream: this is the
        # channel count of the concatenated feature fed to the aggregation.
        out_channels = 0
        for backbone_cfg in backbones:
            out_channels += backbone_cfg['fp_channels'][-1][-1]
            self.backbone_list.append(build_backbone(backbone_cfg))

        # Feature aggregation layers.
        if aggregation_mlp_channels is None:
            aggregation_mlp_channels = [
                out_channels, out_channels // 2,
                out_channels // len(self.backbone_list)
            ]
        else:
            # Copy before prepending so the caller's config list is not
            # mutated (the original in-place insert corrupted the config
            # when the same cfg was used to build twice).
            aggregation_mlp_channels = [out_channels
                                        ] + list(aggregation_mlp_channels)

        self.aggregation_layers = nn.Sequential()
        for i in range(len(aggregation_mlp_channels) - 1):
            self.aggregation_layers.add_module(
                f'layer{i}',
                ConvModule(
                    aggregation_mlp_channels[i],
                    aggregation_mlp_channels[i + 1],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    inplace=True))

    def init_weights(self, pretrained=None):
        """Initialize the weights of PointNet++ backbone.

        Args:
            pretrained (str, optional): Path of a checkpoint to load.
                Non-string values are ignored. Defaults to None.
        """
        # Do not initialize the conv layers
        # to follow the original implementation
        if isinstance(pretrained, str):
            from mmdet3d.utils import get_root_logger
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)

    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict[str, list[torch.Tensor]]: Outputs from multiple backbones.

                - fp_xyz[suffix] (list[torch.Tensor]): The coordinates of
                  each fp features.
                - fp_features[suffix] (list[torch.Tensor]): The features
                  from each Feature Propagate Layers.
                - fp_indices[suffix] (list[torch.Tensor]): Indices of the
                  input points.
                - hd_feature (torch.Tensor): The aggregation feature
                  from multiple backbones.
        """
        ret = {}
        fp_features = []
        for ind in range(len(self.backbone_list)):
            cur_ret = self.backbone_list[ind](points)
            cur_suffix = self.suffixes[ind]
            fp_features.append(cur_ret['fp_features'][-1])
            if cur_suffix != '':
                # Snapshot the keys before renaming: popping/inserting while
                # iterating the live dict view can revisit the freshly
                # renamed keys (re-suffixing them) or break iteration.
                for k in list(cur_ret.keys()):
                    cur_ret[k + '_' + cur_suffix] = cur_ret.pop(k)
            ret.update(cur_ret)

        # Combine the features here
        hd_feature = torch.cat(fp_features, dim=1)
        hd_feature = self.aggregation_layers(hd_feature)
        ret['hd_feature'] = hd_feature
        return ret
...@@ -40,3 +40,119 @@ def test_pointnet2_sa_ssg():
    assert fp_xyz[2].shape == torch.Size([1, 100, 3])
    assert fp_features[2].shape == torch.Size([1, 16, 100])
    assert fp_indices[2].shape == torch.Size([1, 100])
def test_multi_backbone():
    """Smoke-test MultiBackbone built from a list config and a dict config."""
    if not torch.cuda.is_available():
        pytest.skip()

    def pointnet2_cfg():
        # Return a fresh PointNet2SASSG stream config; building new dict
        # objects on every call keeps the four streams independent.
        return dict(
            type='PointNet2SASSG',
            in_channels=4,
            num_points=(256, 128, 64, 32),
            radius=(0.2, 0.4, 0.8, 1.2),
            num_samples=(64, 32, 16, 16),
            sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
                         (128, 128, 256)),
            fp_channels=((256, 256), (256, 256)),
            norm_cfg=dict(type='BN2d'),
            pool_mod='max')

    # test list config
    cfg_list = dict(
        type='MultiBackbone',
        num_streams=4,
        suffixes=['net0', 'net1', 'net2', 'net3'],
        backbones=[pointnet2_cfg() for _ in range(4)])

    model = build_backbone(cfg_list)
    model.cuda()
    assert len(model.backbone_list) == 4

    xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', dtype=np.float32)
    xyz = torch.from_numpy(xyz).view(1, -1, 6).cuda()  # (B, N, 6)

    # test forward
    ret_dict = model(xyz[:, :, :4])
    assert ret_dict['hd_feature'].shape == torch.Size([1, 256, 128])
    assert ret_dict['fp_xyz_net0'][-1].shape == torch.Size([1, 128, 3])
    assert ret_dict['fp_features_net0'][-1].shape == torch.Size([1, 256, 128])

    # test dict config
    cfg_dict = dict(
        type='MultiBackbone',
        num_streams=2,
        suffixes=['net0', 'net1'],
        aggregation_mlp_channels=[512, 128],
        backbones=pointnet2_cfg())

    model = build_backbone(cfg_dict)
    model.cuda()
    assert len(model.backbone_list) == 2

    # test forward
    ret_dict = model(xyz[:, :, :4])
    assert ret_dict['hd_feature'].shape == torch.Size([1, 128, 128])
    assert ret_dict['fp_xyz_net0'][-1].shape == torch.Size([1, 128, 3])
    assert ret_dict['fp_features_net0'][-1].shape == torch.Size([1, 256, 128])

    # Length of backbone configs list should be equal to num_streams
    with pytest.raises(AssertionError):
        cfg_list['num_streams'] = 3
        build_backbone(cfg_list)

    # Length of suffixes list should be equal to num_streams
    with pytest.raises(AssertionError):
        cfg_dict['suffixes'] = ['net0', 'net1', 'net2']
        build_backbone(cfg_dict)

    # Type of 'backbones' should be Dict or List[Dict].
    with pytest.raises(AssertionError):
        cfg_dict['backbones'] = 'PointNet2SASSG'
        build_backbone(cfg_dict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment