"...parallel-hashmap/benchmark/js/jquery.flot.browser.js" did not exist on "e8309f27790ee465f7567a5e6800b30cfe31686d"
test_silu.py 892 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.cnn.bricks import build_activation_layer


def test_silu():
    """Check the ``SiLU`` layer produced by ``build_activation_layer``.

    Two configurations are exercised, each compared against the reference
    value ``x * sigmoid(x)``:

    * the default layer: output shape and values must match the reference;
    * the ``inplace=True`` layer: additionally, the input tensor itself must
      hold the result and be the very object that is returned.
    """
    act = build_activation_layer(dict(type='SiLU'))
    inputs = torch.randn(1, 3, 64, 64)
    expected_output = inputs * torch.sigmoid(inputs)
    output = act(inputs)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.allclose(output, expected_output)

    # test inplace
    act = build_activation_layer(dict(type='SiLU', inplace=True))
    assert act.inplace
    inputs = torch.randn(1, 3, 64, 64)
    # The reference is computed before calling the layer: the assertions
    # below expect ``inputs`` to be overwritten with the activation result.
    expected_output = inputs * torch.sigmoid(inputs)
    output = act(inputs)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.allclose(output, expected_output)
    # the input tensor must now contain the activated values ...
    assert torch.allclose(inputs, expected_output)
    # ... and the returned tensor must be the same object as the input
    assert inputs is output