Commit d8ac46df authored by takuoko's avatar takuoko Committed by Zaida Zhou
Browse files

[Enhancement] Support SiLU with torch < 1.7.0

parent dfef1529
......@@ -15,7 +15,23 @@ for module in [
MODELS.register_module(module=module)
if digit_version(torch.__version__) >= digit_version('1.7.0'):
MODELS.register_module(module=nn.SiLU)
MODELS.register_module(module=nn.SiLU, name='SiLU')
else:
class SiLU(nn.Module):
    """Sigmoid Weighted Linear Unit (SiLU), also known as Swish.

    Fallback implementation of ``torch.nn.SiLU`` for torch < 1.7.0,
    computing ``x * sigmoid(x)``.

    Args:
        inplace (bool): Whether to perform the operation in-place on the
            input tensor. Default: False.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, inputs) -> torch.Tensor:
        # The sigmoid gate is evaluated before any in-place mutation of
        # ``inputs``, so both branches compute the same value.
        if self.inplace:
            return inputs.mul_(torch.sigmoid(inputs))
        else:
            return inputs * torch.sigmoid(inputs)
MODELS.register_module(module=SiLU, name='SiLU')
@MODELS.register_module(name='Clip')
......
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F
from mmcv.cnn.bricks import build_activation_layer
from mmcv.utils import digit_version
def test_silu():
    """Test the activation registered under the name 'SiLU'.

    No version skip: for torch < 1.7.0 a fallback implementation is
    registered under the same name, so the test must run on all
    versions. The expected value is computed directly as
    ``x * sigmoid(x)`` instead of ``F.silu`` (unavailable before 1.7.0).
    """
    # out-of-place: output is a new tensor with the same shape
    act = build_activation_layer(dict(type='SiLU'))
    inp = torch.randn(1, 3, 64, 64)
    expected_output = inp * torch.sigmoid(inp)
    output = act(inp)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value (allclose, not equal: fallback and nn.SiLU may
    # differ by float rounding)
    assert torch.allclose(output, expected_output)

    # test inplace: the input tensor itself is mutated and returned
    act = build_activation_layer(dict(type='SiLU', inplace=True))
    assert act.inplace
    inp = torch.randn(1, 3, 64, 64)
    # compute the expectation before the in-place op clobbers ``inp``
    expected_output = inp * torch.sigmoid(inp)
    output = act(inp)
    # test output shape
    assert output.shape == expected_output.shape
    # test output value
    assert torch.allclose(output, expected_output)
    assert torch.allclose(inp, expected_output)
    assert inp is output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment