test_scale.py 2.3 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import pytest
Kai Chen's avatar
Kai Chen committed
3
4
import torch

5
from mmcv.cnn.bricks import LayerScale, Scale
Kai Chen's avatar
Kai Chen committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23


def test_scale():
    """Check ``Scale`` construction and forward shape.

    Covers both the default constructor and an explicitly given value.
    """
    input_shape = (1, 3, 64, 64)
    # (constructor args, expected value of the ``scale`` attribute):
    # default construction first, then an explicitly supplied value.
    cases = [((), 1.), ((10., ), 10.)]
    for init_args, expected in cases:
        layer = Scale(*init_args)
        assert layer.scale.data == expected
        assert layer.scale.dtype == torch.float
        inputs = torch.rand(*input_shape)
        result = layer(inputs)
        # only the shape of the forward result is checked here
        assert result.shape == input_shape
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78


def test_layer_scale():
    """Exercise ``LayerScale`` construction, forward output and in-place mode.

    The forward cases compare the output against ``input * 1e-5`` for both
    ``channels_last`` and ``channels_first`` layouts with different ranks.
    """
    # constructing with data_format='BNC' is expected to raise AssertionError
    bad_cfg = dict(dim=10, data_format='BNC')
    with pytest.raises(AssertionError):
        LayerScale(**bad_cfg)

    # test init: the weight should equal a vector of ones scaled by 1e-5
    layer = LayerScale(dim=10)
    expected_weight = torch.ones(10, requires_grad=True) * 1e-5
    assert torch.equal(layer.weight, expected_weight)

    # test forward (inplace=False): one (data_format, input shape) per case
    forward_cases = [
        ('channels_last', (4, 49, 256)),  # channels_last
        ('channels_last', (4, 7, 49, 256)),  # channels_last 2d
        ('channels_first', (4, 256, 7, 7)),  # channels_first
        ('channels_first', (4, 256, 7, 7, 7)),  # channels_first 3D
    ]
    for data_format, shape in forward_cases:
        layer = LayerScale(dim=256, inplace=False, data_format=data_format)
        inputs = torch.randn(shape)
        result = layer(inputs)
        assert tuple(result.size()) == shape
        assert torch.equal(inputs * 1e-5, result)

    # test inplace True: the returned tensor must be the input object itself
    layer = LayerScale(dim=256, inplace=True, data_format='channels_first')
    inputs = torch.randn((4, 256, 7, 7))
    result = layer(inputs)
    assert tuple(result.size()) == (4, 256, 7, 7)
    assert inputs is result