Unverified Commit e7cb18c1 authored by Bhargav Kathivarapu's avatar Bhargav Kathivarapu Committed by GitHub
Browse files

Add overdrive to functional (#569)



* Add overdrive to functional

* Minor change to overdrive

* Minor change to overdrive

* minor flake8 changes

* changes to make overdrive generic
Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent 954d5121
...@@ -133,6 +133,11 @@ Functions to perform common audio operations. ...@@ -133,6 +133,11 @@ Functions to perform common audio operations.
.. autofunction:: dcshift .. autofunction:: dcshift
:hidden:`overdrive`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: overdrive
:hidden:`mask_along_axis` :hidden:`mask_along_axis`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -156,4 +161,4 @@ Functions to perform common audio operations. ...@@ -156,4 +161,4 @@ Functions to perform common audio operations.
:hidden:`sliding_window_cmn` :hidden:`sliding_window_cmn`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: sliding_window_cmn .. autofunction:: sliding_window_cmn
\ No newline at end of file
...@@ -72,6 +72,10 @@ class TestFunctional(unittest.TestCase): ...@@ -72,6 +72,10 @@ class TestFunctional(unittest.TestCase):
waveform = torch.rand(2, 100) - 0.5 waveform = torch.rand(2, 100) - 0.5
_test_batch(F.dcshift, waveform, shift=0.5, limiter_gain=0.05) _test_batch(F.dcshift, waveform, shift=0.5, limiter_gain=0.05)
def test_overdrive(self):
waveform = torch.rand(2, 100) - 0.5
_test_batch(F.overdrive, waveform, gain=45, colour=30)
def test_sliding_window_cmn(self): def test_sliding_window_cmn(self):
waveform = torch.randn(2, 1024) - 0.5 waveform = torch.randn(2, 1024) - 0.5
_test_batch(F.sliding_window_cmn, waveform, center=True, norm_vars=True) _test_batch(F.sliding_window_cmn, waveform, center=True, norm_vars=True)
......
...@@ -355,6 +355,25 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -355,6 +355,25 @@ class TestFunctionalFiltering(unittest.TestCase):
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5) torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_overdrive(self):
"""
Test overdrive effect, compare to SoX implementation
"""
gain = 30
colour = 40
noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("overdrive", [gain, colour])
sox_output_waveform, sr = E.sox_build_flow_effects()
waveform, _ = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.overdrive(waveform, gain, colour)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
def test_equalizer(self): def test_equalizer(self):
......
...@@ -462,6 +462,17 @@ class _FunctionalTestMixin: ...@@ -462,6 +462,17 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_overdrive(self):
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
gain = 30.
colour = 50.
return F.overdrive(tensor, gain, colour)
self._assert_consistency(func, waveform)
class _TransformsTestMixin: class _TransformsTestMixin:
"""Implements test for Transforms that are performed for different devices""" """Implements test for Transforms that are performed for different devices"""
......
...@@ -33,6 +33,7 @@ __all__ = [ ...@@ -33,6 +33,7 @@ __all__ = [
"biquad", "biquad",
"contrast", "contrast",
"dcshift", "dcshift",
"overdrive",
'mask_along_axis', 'mask_along_axis',
'mask_along_axis_iid', 'mask_along_axis_iid',
'sliding_window_cmn', 'sliding_window_cmn',
...@@ -1239,6 +1240,61 @@ def dcshift( ...@@ -1239,6 +1240,61 @@ def dcshift(
return output_waveform return output_waveform
def overdrive(
waveform: Tensor,
gain: float = 20,
colour: float = 20
) -> Tensor:
r"""Apply a overdrive effect to the audio. Similar to SoX implementation.
This effect applies a non linear distortion to the audio signal.
Args:
waveform (Tensor): audio waveform of dimension of `(..., time)`
gain (float): desired gain at the boost (or attenuation) in dB
Allowed range of values are 0 to 100
colour (float): controls the amount of even harmonic content in the over-driven output
Allowed range of values are 0 to 100
Returns:
Tensor: Waveform of dimension of `(..., time)`
References:
http://sox.sourceforge.net/sox.html
"""
actual_shape = waveform.shape
device, dtype = waveform.device, waveform.dtype
# convert to 2D (..,time)
waveform = waveform.view(-1, actual_shape[-1])
gain = _dB2Linear(gain)
colour = colour / 200
last_in = torch.zeros(waveform.shape[:-1], dtype=dtype, device=device)
last_out = torch.zeros(waveform.shape[:-1], dtype=dtype, device=device)
temp = waveform * gain + colour
mask1 = temp < -1
temp[mask1] = torch.tensor(-2.0 / 3.0, dtype=dtype, device=device)
# Wrapping the constant with Tensor is required for Torchscript
mask2 = temp > 1
temp[mask2] = torch.tensor(2.0 / 3.0, dtype=dtype, device=device)
mask3 = (~mask1 & ~mask2)
temp[mask3] = temp[mask3] - (temp[mask3]**3) * (1. / 3)
output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device)
# TODO: Implement a torch CPP extension
for i in range(waveform.shape[-1]):
last_out = temp[:, i] - last_in + 0.995 * last_out
last_in = temp[:, i]
output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75
return output_waveform.clamp(min=-1, max=1).view(actual_shape)
def mask_along_axis_iid( def mask_along_axis_iid(
specgrams: Tensor, specgrams: Tensor,
mask_param: int, mask_param: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment