# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch

from mmdet3d.registry import MODELS


def test_multi_backbone():
    """Check ``MultiBackbone`` construction and forward on a CUDA device.

    The test is skipped when CUDA is unavailable. It covers:

    * a config whose ``backbones`` entry is a list of four backbone configs;
    * a config whose ``backbones`` entry is a single dict used with
      ``num_streams=2`` and explicit ``aggregation_mlp_channels``;
    * three invalid configs that are each expected to raise
      ``AssertionError`` when built.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')

    def _pointnet2_cfg():
        # Return a fresh dict on every call so no two streams (and no two
        # model configs) share mutable config objects.
        return dict(
            type='PointNet2SASSG',
            in_channels=4,
            num_points=(256, 128, 64, 32),
            radius=(0.2, 0.4, 0.8, 1.2),
            num_samples=(64, 32, 16, 16),
            sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
                         (128, 128, 256)),
            fp_channels=((256, 256), (256, 256)),
            norm_cfg=dict(type='BN2d'))

    # test list config
    cfg_list = dict(
        type='MultiBackbone',
        num_streams=4,
        suffixes=['net0', 'net1', 'net2', 'net3'],
        backbones=[_pointnet2_cfg() for _ in range(4)])

    model = MODELS.build(cfg_list)
    model.cuda()

    assert len(model.backbone_list) == 4

    xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', dtype=np.float32)
    xyz = torch.from_numpy(xyz).view(1, -1, 6).cuda()  # (B, N, 6)
    # Only the first four channels are fed, matching ``in_channels=4``.
    points = xyz[:, :, :4]

    # test forward
    ret_dict = model(points)

    assert ret_dict['hd_feature'].shape == torch.Size([1, 256, 128])
    assert ret_dict['fp_xyz_net0'][-1].shape == torch.Size([1, 128, 3])
    assert ret_dict['fp_features_net0'][-1].shape == torch.Size([1, 256, 128])

    # test dict config
    cfg_dict = dict(
        type='MultiBackbone',
        num_streams=2,
        suffixes=['net0', 'net1'],
        aggregation_mlp_channels=[512, 128],
        backbones=_pointnet2_cfg())

    model = MODELS.build(cfg_dict)
    model.cuda()

    assert len(model.backbone_list) == 2

    # test forward
    ret_dict = model(points)

    assert ret_dict['hd_feature'].shape == torch.Size([1, 128, 128])
    assert ret_dict['fp_xyz_net0'][-1].shape == torch.Size([1, 128, 3])
    assert ret_dict['fp_features_net0'][-1].shape == torch.Size([1, 256, 128])

    # Each invalid config below is a modified *copy* of a valid one, so every
    # case differs from a working config in exactly one field and an earlier
    # case cannot mask a later one.

    # Length of backbone configs list should be equal to num_streams
    bad_num_streams = dict(cfg_list, num_streams=3)
    with pytest.raises(AssertionError):
        MODELS.build(bad_num_streams)

    # Length of suffixes list should be equal to num_streams
    bad_suffixes = dict(cfg_dict, suffixes=['net0', 'net1', 'net2'])
    with pytest.raises(AssertionError):
        MODELS.build(bad_suffixes)

    # Type of 'backbones' should be Dict or List[Dict].
    bad_backbones = dict(cfg_dict, backbones='PointNet2SASSG')
    with pytest.raises(AssertionError):
        MODELS.build(bad_backbones)