import torch

from mmdet3d.registry import MODELS


def test_dla_net():
    """Build the DLANet backbone used in SMOKE and check its output pyramid.

    A depth-34 ``DLANet`` with GroupNorm (``num_groups=32``) is built from a
    config dict through the ``MODELS`` registry, initialised with
    ``init_weights()``, and run on a random batch of four 3-channel 32x32
    images. The forward pass must return six feature maps; their expected
    shapes are listed level by level below.
    """
    # Build the backbone from a dict config via the registry.
    cfg = dict(
        type='DLANet',
        depth=34,
        in_channels=3,
        norm_cfg=dict(type='GN', num_groups=32))

    img = torch.randn((4, 3, 32, 32))
    dla_net = MODELS.build(cfg)
    dla_net.init_weights()

    results = dla_net(img)

    # Expected (channels, height, width) per output level, finest first:
    # channels double from 16 to 512 while the 32x32 input is halved at
    # every level down to 1x1.
    expected_shapes = [
        (16, 32, 32),
        (32, 16, 16),
        (64, 8, 8),
        (128, 4, 4),
        (256, 2, 2),
        (512, 1, 1),
    ]
    assert len(results) == len(expected_shapes)
    for level, (feat, (channels, height, width)) in enumerate(
            zip(results, expected_shapes)):
        expected = torch.Size([4, channels, height, width])
        assert feat.shape == expected, (
            f'level {level}: expected {tuple(expected)}, '
            f'got {tuple(feat.shape)}')