test_hswish.py 514 Bytes
Newer Older
louzana's avatar
louzana committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from torch.nn.functional import relu6

from mmcv.cnn.bricks import HSwish


def test_hswish():
    """Check HSwish's ``inplace`` flag handling and its forward output.

    The reference value is computed directly in the test as
    ``x * relu6(x + 3) / 6`` and compared against the module's output for
    both shape and exact element-wise equality.
    """
    # test inplace: the constructor argument must show up on the module's
    # ``act`` attribute, and must be off when no argument is given.
    act = HSwish(inplace=True)
    assert act.act.inplace
    act = HSwish()
    assert not act.act.inplace

    # Named ``x`` rather than ``input`` so the builtin is not shadowed.
    x = torch.randn(1, 3, 64, 64)
    expected_output = x * relu6(x + 3) / 6
    # Forward pass uses the default-constructed (non-inplace) instance.
    output = act(x)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.equal(output, expected_output)