Unverified Commit 91e59231 authored by Bhargav Kathivarapu's avatar Bhargav Kathivarapu Committed by GitHub
Browse files

Add dcshift to functional (#558)

* Add dcshift to functional

* Doc string change and remove inplace clamp

* Minor Fix to dcshit and separate sox test refactoring

* Minor change to limiter_gain type

* adding dcshift to __all__ in functional
parent fc2537e7
......@@ -128,6 +128,11 @@ Functions to perform common audio operations.
.. autofunction:: contrast
:hidden:`dcshift`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: dcshift
:hidden:`mask_along_axis`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -68,6 +68,10 @@ class TestFunctional(unittest.TestCase):
waveform = torch.rand(2, 100) - 0.5
_test_batch(F.contrast, waveform, enhancement_amount=80.)
def test_dcshift(self):
waveform = torch.rand(2, 100) - 0.5
_test_batch(F.dcshift, waveform, shift=0.5, limiter_gain=0.05)
class TestTransforms(unittest.TestCase):
"""Test suite for classes defined in `transforms` module"""
......
......@@ -318,6 +318,43 @@ class TestFunctionalFiltering(unittest.TestCase):
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_dcshift_with_limiter(self):
"""
Test dcshift effect, compare to SoX implementation
"""
shift = 0.5
limiter_gain = 0.05
noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("dcshift", [shift, limiter_gain])
sox_output_waveform, sr = E.sox_build_flow_effects()
waveform, _ = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.dcshift(waveform, shift, limiter_gain)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_dcshift_without_limiter(self):
"""
Test dcshift effect, compare to SoX implementation
"""
shift = 0.6
noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("dcshift", [shift])
sox_output_waveform, sr = E.sox_build_flow_effects()
waveform, _ = torchaudio.load(noise_filepath, normalization=True)
output_waveform = F.dcshift(waveform, shift)
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_equalizer(self):
......
......@@ -451,6 +451,18 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_dcshift(self):
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
shift = 0.5
limiter_gain = 0.05
return F.dcshift(tensor, shift, limiter_gain)
self._assert_consistency(func, waveform)
class _TransformsTestMixin:
"""Implements test for Transforms that are performed for different devices"""
device = None
......
......@@ -32,6 +32,7 @@ __all__ = [
"riaa_biquad",
"biquad",
"contrast",
"dcshift",
'mask_along_axis',
'mask_along_axis_iid',
'sliding_window_cmn',
......@@ -1194,6 +1195,50 @@ def contrast(
return output_waveform
def dcshift(
waveform: Tensor,
shift: float,
limiter_gain: Optional[float] = None
) -> Tensor:
r"""Apply a DC shift to the audio. Similar to SoX implementation.
This can be useful to remove a DC offset
(caused perhaps by a hardware problem in the recording chain) from the audio
Args:
waveform (Tensor): audio waveform of dimension of `(..., time)`
shift (float): indicates the amount to shift the audio
Allowed range of values for shift : -2.0 to +2.0
limiter_gain (float): It is used only on peaks to prevent clipping
It should have a value much less than 1 (e.g. 0.05 or 0.02)
Returns:
Tensor: Waveform of dimension of `(..., time)`
References:
http://sox.sourceforge.net/sox.html
"""
output_waveform = waveform
limiter_threshold = 0.
if limiter_gain is not None:
limiter_threshold = 1.0 - (abs(shift) - limiter_gain)
if limiter_gain is not None and shift > 0:
mask = waveform > limiter_threshold
temp = (waveform[mask] - limiter_threshold) * limiter_gain / (1 - limiter_threshold)
output_waveform[mask] = (temp + limiter_threshold + shift).clamp(max=limiter_threshold)
output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1)
elif limiter_gain is not None and shift < 0:
mask = waveform < -limiter_threshold
temp = (waveform[mask] + limiter_threshold) * limiter_gain / (1 - limiter_threshold)
output_waveform[mask] = (temp - limiter_threshold + shift).clamp(min=-limiter_threshold)
output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1)
else:
output_waveform = (waveform + shift).clamp(min=-1, max=1)
return output_waveform
def mask_along_axis_iid(
specgrams: Tensor,
mask_param: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment