"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "30f305810702f47525a28e6a58d52414ecb79d0f"
Unverified Commit ec43b671 authored by yamengxi's avatar yamengxi Committed by GitHub
Browse files

[Enhance]enhance hsigmoid (#657)

* enhance hsigmoid

* delete int
parent 1290bdd1
......@@ -6,16 +6,28 @@ from .registry import ACTIVATION_LAYERS
@ACTIVATION_LAYERS.register_module()
class HSigmoid(nn.Module):
"""Hard Sigmoid Module. Apply the hard sigmoid function:
Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
Args:
bias (float): Bias of the input feature map. Default: 1.0.
divisor (float): Divisor of the input feature map. Default: 2.0.
min_value (float): Lower bound value. Default: 0.0.
max_value (float): Upper bound value. Default: 1.0.
Returns:
Tensor: The output tensor.
"""
def __init__(self):
def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
super(HSigmoid, self).__init__()
self.bias = bias
self.divisor = divisor
assert self.divisor != 0
self.min_value = min_value
self.max_value = max_value
def forward(self, x):
x = (x + 1) / 2
x = (x + self.bias) / self.divisor
return x.clamp_(0, 1)
return x.clamp_(self.min_value, self.max_value)
import pytest
import torch
from mmcv.cnn.bricks import HSigmoid
def test_hsigmoid():
# test assertion divisor can not be zero
with pytest.raises(AssertionError):
HSigmoid(divisor=0)
# test with default parameters
act = HSigmoid()
input_shape = torch.Size([1, 3, 64, 64])
input = torch.randn(input_shape)
......@@ -15,3 +21,16 @@ def test_hsigmoid():
assert output.shape == expected_output.shape
# test output value
assert torch.equal(output, expected_output)
# test with designated parameters
act = HSigmoid(3, 6, 0, 1)
input_shape = torch.Size([1, 3, 64, 64])
input = torch.randn(input_shape)
output = act(input)
expected_output = torch.min(
torch.max((input + 3) / 6, torch.zeros(input_shape)),
torch.ones(input_shape))
# test output shape
assert output.shape == expected_output.shape
# test output value
assert torch.equal(output, expected_output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment