# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmdet3d.registry import MODELS


def test_secfpn():
    """Check SECONDFPN construction and its config-length validation.

    First builds a valid two-level neck and checks that the first layer of
    each per-level deblock follows the i-th entries of ``in_channels``,
    ``out_channels`` and ``upsample_strides``. Then checks that configs whose
    per-level lists have mismatched lengths are rejected with
    ``AssertionError`` at build time.
    """
    neck_cfg = dict(
        type='SECONDFPN',
        in_channels=[2, 3],
        upsample_strides=[1, 2],
        out_channels=[4, 6],
    )
    neck = MODELS.build(neck_cfg)
    # Check the build result before dereferencing any of its submodules.
    assert neck is not None

    # deblocks[i][0] is the first layer of level i's block; it exposes
    # in_channels / out_channels / stride, which must mirror the config lists.
    per_level = zip(neck_cfg['in_channels'], neck_cfg['out_channels'],
                    neck_cfg['upsample_strides'])
    for level, (in_ch, out_ch, stride) in enumerate(per_level):
        first_layer = neck.deblocks[level][0]
        assert first_layer.in_channels == in_ch
        assert first_layer.out_channels == out_ch
        assert first_layer.stride == (stride, stride)

    # Invalid: upsample_strides has 3 entries, in/out_channels only 2.
    neck_cfg = dict(
        type='SECONDFPN',
        in_channels=[2, 2],
        upsample_strides=[1, 2, 4],
        out_channels=[2, 2],
    )
    with pytest.raises(AssertionError):
        MODELS.build(neck_cfg)

    # Invalid: out_channels has 2 entries, in_channels/upsample_strides have 3.
    neck_cfg = dict(
        type='SECONDFPN',
        in_channels=[2, 2, 4],
        upsample_strides=[1, 2, 4],
        out_channels=[2, 2],
    )
    with pytest.raises(AssertionError):
        MODELS.build(neck_cfg)


def test_centerpoint_fpn():
    """Compare CenterPoint-style and original SECONDFPN output shapes.

    A SECOND backbone turns a (2, 2, 32, 32) tensor into three feature
    levels. The same levels are then fed through two SECONDFPN necks that
    differ in their upsample strides and upsampling options. The CenterPoint
    config uses strides half as large as the original, so its fused map is
    expected at half the spatial size (8x8 vs 16x16). Both necks emit 6
    channels, i.e. sum(out_channels).
    """
    second_cfg = dict(
        type='SECOND',
        in_channels=2,
        out_channels=[2, 2, 2],
        layer_nums=[3, 5, 5],
        layer_strides=[2, 2, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        conv_cfg=dict(type='Conv2d', bias=False))
    second = MODELS.build(second_cfg)

    # CenterPoint usage of the FPN: a fractional stride (0.5) on the first
    # level, bias-free deconvs, and use_conv_for_no_stride=True (presumably a
    # plain conv instead of a deconv for stride <= 1 levels; see SECONDFPN).
    centerpoint_fpn_cfg = dict(
        type='SECONDFPN',
        in_channels=[2, 2, 2],
        out_channels=[2, 2, 2],
        upsample_strides=[0.5, 1, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        upsample_cfg=dict(type='deconv', bias=False),
        use_conv_for_no_stride=True)

    # Original usage of the FPN: integer upsample strides only.
    fpn_cfg = dict(
        type='SECONDFPN',
        in_channels=[2, 2, 2],
        upsample_strides=[1, 2, 4],
        out_channels=[2, 2, 2])

    second_fpn = MODELS.build(fpn_cfg)
    centerpoint_second_fpn = MODELS.build(centerpoint_fpn_cfg)

    # Named to avoid shadowing the builtin ``input``.
    dummy_input = torch.rand([2, 2, 32, 32])

    # Only output shapes are checked, so no autograd graph is needed.
    with torch.no_grad():
        sec_output = second(dummy_input)
        # One feature map per FPN input level.
        assert len(sec_output) == len(fpn_cfg['in_channels'])
        centerpoint_output = centerpoint_second_fpn(sec_output)
        second_output = second_fpn(sec_output)

    assert centerpoint_output[0].shape == torch.Size([2, 6, 8, 8])
    assert second_output[0].shape == torch.Size([2, 6, 16, 16])