Commit 9db4bdf1 authored by Cole Li, committed by Facebook GitHub Bot
Browse files

Implement exp sigmoid (#3056)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3056

Task #2 from https://github.com/pytorch/audio/issues/2835

Reviewed By: mthrok

Differential Revision: D42854156

fbshipit-source-id: e1b3bd992c91fedc55f30a814e16efd7c51e0c80
parent a49edea5
......@@ -47,3 +47,10 @@ class AutogradTestImpl(TestBaseMixin):
waveform = torch.rand(3, 1, 2, 10, device=self.device, dtype=self.dtype, requires_grad=True)
filters = torch.rand(3, 2, device=self.device, dtype=self.dtype, requires_grad=True)
assert gradcheck(F.filter_waveform, (waveform, filters))
def test_exp_sigmoid_input(self):
    """Run ``gradcheck`` on ``F.exp_sigmoid`` with respect to its input tensor."""
    # Only the tensor argument is created with ``requires_grad=True``;
    # the remaining arguments are plain Python floats.
    samples = torch.linspace(-5, 5, 20, device=self.device, dtype=self.dtype, requires_grad=True)
    scalar_args = (10.0, 2.0, 1e-7)  # exponent, max_value, threshold
    assert gradcheck(F.exp_sigmoid, (samples, *scalar_args))
......@@ -40,3 +40,26 @@ def freq_ir(magnitudes):
ir = np.fft.fftshift(np.fft.irfft(magnitudes), axes=-1)
window = np.hanning(ir.shape[-1])
return (ir * window).astype(magnitudes.dtype)
def exp_sigmoid(
    input: np.ndarray,
    exponent: float = 10.0,
    max_value: float = 2.0,
    threshold: float = 1e-7,
) -> np.ndarray:
    """Exponential Sigmoid pointwise nonlinearity (NumPy reference version).

    Evaluates, element by element:

        ``max_value`` * sigmoid(``input``) ** log(``exponent``) + ``threshold``

    For ``exponent`` > 1 and positive ``max_value`` the result lies between
    ``threshold`` and ``max_value`` + ``threshold``.

    Args:
        input (np.ndarray): Input array. The exponential and the logarithm are
            both evaluated in ``input.dtype``.
        exponent (float, optional): Its natural logarithm is the power applied
            to the sigmoid; controls the slope of the output.
        max_value (float, optional): Scale applied to the powered sigmoid.
        threshold (float, optional): Constant offset added to the output.

    Returns:
        np.ndarray: Exponential Sigmoid output. Shape: same as input.
    """
    # Logistic function written out explicitly, computed in the input's dtype.
    sigmoid = 1 / (1 + np.exp(-input, dtype=input.dtype))
    # The power is log(exponent), also computed in the input's dtype.
    power = np.log(exponent, dtype=input.dtype)
    return max_value * sigmoid**power + threshold
......@@ -8,7 +8,12 @@ import torchaudio.prototype.functional as F
from parameterized import param, parameterized
from torchaudio_unittest.common_utils import nested_params, skipIfNoModule, skipIfNoRIR, TestBaseMixin
from .dsp_utils import freq_ir as freq_ir_np, oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np
from .dsp_utils import (
exp_sigmoid as exp_sigmoid_np,
freq_ir as freq_ir_np,
oscillator_bank as oscillator_bank_np,
sinc_ir as sinc_ir_np,
)
def _prod(l):
......@@ -381,6 +386,37 @@ class FunctionalTestImpl(TestBaseMixin):
self.assertEqual(mix[-9:], ref2[-9:])
# the middle portion is where the two filters affect
@parameterized.expand(
    [
        # fmt: off
        ((-10, 10, 100), (10.0, 2.0, 1e-7)),
        ((-1, -1, 1), (5.0, 2.4, 1e-7)),  # This is single sample
        ((0, 3, 10), (1, 1, 1e-12)),
        # fmt: on
    ]
)
def test_exp_sigmoid_input_diff(self, linspace_input_values, exp_sigmoid_parameters):
    """Compare ``F.exp_sigmoid`` with the NumPy reference ``exp_sigmoid_np``.

    ``linspace_input_values`` is a ``(start, end, steps)`` tuple passed to ``torch.linspace``;
    the third entry is the number of points, not a step size.
    ``exp_sigmoid_parameters`` is the ``(exponent, max_value, threshold)`` tuple
    passed to both implementations.
    """
    start, end, steps = linspace_input_values
    x = torch.linspace(start, end, steps, dtype=self.dtype, device=self.device)
    exponent, max_value, threshold = exp_sigmoid_parameters

    actual = F.exp_sigmoid(x, exponent, max_value, threshold)
    # The reference runs on a CPU NumPy copy of the same samples.
    reference = exp_sigmoid_np(x.cpu().numpy(), exponent, max_value, threshold)
    self.assertEqual(actual, torch.tensor(reference))
class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
......
from ._dsp import (
adsr_envelope,
exp_sigmoid,
extend_pitch,
filter_waveform,
frequency_impulse_response,
......@@ -12,6 +13,7 @@ from .functional import barkscale_fbanks
__all__ = [
"adsr_envelope",
"exp_sigmoid",
"barkscale_fbanks",
"extend_pitch",
"filter_waveform",
......
......@@ -404,3 +404,32 @@ def filter_waveform(waveform: torch.Tensor, kernels: torch.Tensor, delay_compens
end = num_crops - start
result = restored[..., start:-end]
return result
def exp_sigmoid(
    input: torch.Tensor,
    exponent: float = 10.0,
    max_value: float = 2.0,
    threshold: float = 1e-7,
) -> torch.Tensor:
    """Exponential Sigmoid pointwise nonlinearity.

    Evaluates, element by element:

        ``max_value`` * sigmoid(``input``) ** log(``exponent``) + ``threshold``

    For ``exponent`` > 1 and positive ``max_value`` the result lies between
    ``threshold`` and ``max_value`` + ``threshold``.

    .. devices:: CPU CUDA

    Args:
        input (Tensor): Input Tensor
        exponent (float, optional): Its natural logarithm is the power applied
            to the sigmoid; controls the slope of the output
        max_value (float, optional): Scale applied to the powered sigmoid
        threshold (float, optional): Constant offset added to the output

    Returns:
        Tensor: Exponential Sigmoid output. Shape: same as input
    """
    squashed = torch.nn.functional.sigmoid(input)
    # The scalar arguments are wrapped in tensors on the input's device and dtype.
    power = torch.log(torch.tensor(exponent, device=input.device, dtype=input.dtype))
    offset = torch.tensor(threshold, device=input.device, dtype=input.dtype)
    scaled = max_value * torch.pow(squashed, power)
    return scaled + offset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment