test_swish.py 378 Bytes
Newer Older
Xiaojie Li's avatar
Xiaojie Li committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from torch.nn.functional import sigmoid

from mmcv.cnn.bricks import Swish


def test_swish():
    """Check that ``Swish`` reproduces ``x * sigmoid(x)`` on a random tensor.

    A randomly generated tensor of shape ``(1, 3, 64, 64)`` is passed through
    the activation, and the result is compared against the reference
    expression for both shape and exact element-wise equality.
    """
    act = Swish()
    # Named ``inputs`` rather than ``input`` so the builtin is not shadowed.
    inputs = torch.randn(1, 3, 64, 64)
    expected_output = inputs * sigmoid(inputs)
    output = act(inputs)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.equal(output, expected_output)