# test_second_fpn.py
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmdet3d.models.builder import build_backbone, build_neck


def test_secfpn():
    """Check SECONDFPN construction and its config validation.

    Builds a two-level SECONDFPN and verifies that the first layer of each
    deblock carries the configured input/output channels and stride, then
    checks that configs whose per-level lists disagree in length are
    rejected with an ``AssertionError``.
    """
    neck_cfg = dict(
        type='SECONDFPN',
        in_channels=[2, 3],
        upsample_strides=[1, 2],
        out_channels=[4, 6],
    )
    neck = build_neck(neck_cfg)
    # Check the build result before dereferencing its submodules; placed
    # after the attribute accesses this assertion could never fail usefully.
    assert neck is not None
    assert neck.deblocks[0][0].in_channels == 2
    assert neck.deblocks[1][0].in_channels == 3
    assert neck.deblocks[0][0].out_channels == 4
    assert neck.deblocks[1][0].out_channels == 6
    assert neck.deblocks[0][0].stride == (1, 1)
    assert neck.deblocks[1][0].stride == (2, 2)

    # upsample_strides has three entries while in/out channels have two.
    neck_cfg = dict(
        type='SECONDFPN',
        in_channels=[2, 2],
        upsample_strides=[1, 2, 4],
        out_channels=[2, 2],
    )
    with pytest.raises(AssertionError):
        build_neck(neck_cfg)

    # out_channels has two entries while in_channels and upsample_strides
    # have three.
    neck_cfg = dict(
        type='SECONDFPN',
        in_channels=[2, 2, 4],
        upsample_strides=[1, 2, 4],
        out_channels=[2, 2],
    )
    with pytest.raises(AssertionError):
        build_neck(neck_cfg)


def test_centerpoint_fpn():
    """Compare CenterPoint-style and original SECONDFPN output shapes.

    A SECOND backbone produces three feature levels that feed two necks:

    * the CenterPoint configuration, which uses a fractional upsample
      stride (0.5) together with ``use_conv_for_no_stride=True``;
    * the original configuration, which uses strides [1, 2, 4].

    Both necks are fed the same backbone output, and the shape of each
    neck's first output is checked.
    """
    second_cfg = dict(
        type='SECOND',
        in_channels=2,
        out_channels=[2, 2, 2],
        layer_nums=[3, 5, 5],
        layer_strides=[2, 2, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        conv_cfg=dict(type='Conv2d', bias=False))

    second = build_backbone(second_cfg)

    # centerpoint usage of fpn
    centerpoint_fpn_cfg = dict(
        type='SECONDFPN',
        in_channels=[2, 2, 2],
        out_channels=[2, 2, 2],
        upsample_strides=[0.5, 1, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        upsample_cfg=dict(type='deconv', bias=False),
        use_conv_for_no_stride=True)

    # original usage of fpn
    fpn_cfg = dict(
        type='SECONDFPN',
        in_channels=[2, 2, 2],
        upsample_strides=[1, 2, 4],
        out_channels=[2, 2, 2])

    second_fpn = build_neck(fpn_cfg)

    centerpoint_second_fpn = build_neck(centerpoint_fpn_cfg)

    # Named `inputs` rather than `input` to avoid shadowing the builtin.
    inputs = torch.rand([2, 2, 32, 32])
    sec_output = second(inputs)
    centerpoint_output = centerpoint_second_fpn(sec_output)
    second_output = second_fpn(sec_output)

    # Assuming each backbone stage downsamples by its layer stride, the
    # levels are 16x16, 8x8 and 4x4. The CenterPoint strides [0.5, 1, 2]
    # bring every level to 8x8, while the original strides [1, 2, 4] bring
    # them to 16x16. In both cases the channels are 3 levels * 2 = 6.
    assert centerpoint_output[0].shape == torch.Size([2, 6, 8, 8])
    assert second_output[0].shape == torch.Size([2, 6, 16, 16])