Unverified Commit 11fb22aa authored by Tomás Osório, committed by GitHub

Add Vol Transformation (#468)

* Add Vol with gain_type amplitude

* add gain in db and add tests

* add gain_type "power" and tests

* add functional DB_to_amplitude

* simplify

* remove functional

* improve docstring

* add to documentation
parent fcc9a51a
...@@ -144,6 +144,7 @@ Transforms expect and return the following dimensions.
* `MuLawDecode`: (channel, time) -> (channel, time)
* `Resample`: (channel, time) -> (channel, time)
* `Fade`: (channel, time) -> (channel, time)
* `Vol`: (channel, time) -> (channel, time)

Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. optional batching and channel dimensions.
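A minimal sketch of that convention, assuming the `complex_norm` and `angle` helpers from `torchaudio.functional` available around the time of this PR, with an arbitrary spectrogram shape chosen purely for illustration:

```python
import torch
import torchaudio.functional as F

# A "complex" spectrogram stored as a real tensor with a trailing dimension of 2:
# here (channel, freq, time, 2), where the last axis holds the real/imaginary parts.
specgram = torch.randn(1, 257, 100, 2)

magnitude = F.complex_norm(specgram)  # -> (channel, freq, time)
phase = F.angle(specgram)             # -> (channel, freq, time)
```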
...
...@@ -121,3 +121,10 @@ Transforms are common audio transforms. They can be chained together using :clas
.. autoclass:: TimeMasking

  .. automethod:: forward

:hidden:`Vol`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Vol

  .. automethod:: forward
...@@ -244,6 +244,20 @@ class Test_SoxEffectsChain(unittest.TestCase):
        # check if effect worked
        self.assertTrue(x.allclose(fade(x_orig), rtol=1e-4, atol=1e-4))

    def test_vol(self):
        x_orig, _ = torchaudio.load(self.test_filepath)

        for gain, gain_type in ((1.1, "amplitude"), (2, "db"), (2, "power")):
            E = torchaudio.sox_effects.SoxEffectsChain()
            E.set_input_file(self.test_filepath)
            E.append_effect_to_chain("vol", [gain, gain_type])
            x, sr = E.sox_build_flow_effects()

            vol = torchaudio.transforms.Vol(gain, gain_type)
            z = vol(x_orig)

            # check if effect worked
            self.assertTrue(x.allclose(z, rtol=1e-4, atol=1e-4))

if __name__ == '__main__':
    unittest.main()
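Outside the test harness, the same comparison can be sketched as a standalone script. This is only a sketch under assumptions: the sox backend is available, `torchaudio.initialize_sox()` / `torchaudio.shutdown_sox()` are the session hooks of the torchaudio release this PR targets, and `some_audio.wav` is a hypothetical input file.

```python
import torch
import torchaudio

torchaudio.initialize_sox()  # assumed session hook: start sox before building effect chains

# Apply a +2 dB volume change with the sox "vol" effect.
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file("some_audio.wav")          # hypothetical input file
E.append_effect_to_chain("vol", [2, "db"])
x_sox, sample_rate = E.sox_build_flow_effects()

# Apply the same change with the new transform and compare the results.
waveform, _ = torchaudio.load("some_audio.wav")
x_transform = torchaudio.transforms.Vol(gain=2, gain_type="db")(waveform)
print(torch.allclose(x_sox, x_transform, rtol=1e-4, atol=1e-4))

torchaudio.shutdown_sox()  # assumed session hook
```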
...@@ -609,6 +609,23 @@ class Tester(unittest.TestCase):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)

    def test_scriptmodule_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        _test_script_module(transforms.Vol, waveform, 1.1)

    def test_batch_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Vol(gain=1.1)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

class TestLibrosaConsistency(unittest.TestCase):
    test_dirpath = None
...
...@@ -774,3 +774,43 @@ class TimeMasking(_AxisMasking):
    def __init__(self, time_mask_param, iid_masks=False):
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)

class Vol(torch.nn.Module):
    r"""Adjust the volume of a waveform.

    Args:
        gain (float): Interpreted according to the given gain_type:
            If ``gain_type`` = ``amplitude``, ``gain`` is a positive amplitude ratio.
            If ``gain_type`` = ``power``, ``gain`` is a power (voltage squared).
            If ``gain_type`` = ``db``, ``gain`` is in decibels.
        gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``"amplitude"``)
    """

    def __init__(self, gain, gain_type='amplitude'):
        super(Vol, self).__init__()
        self.gain = gain
        self.gain_type = gain_type

        if gain_type in ['amplitude', 'power'] and gain < 0:
            raise ValueError("If gain_type = amplitude or power, gain must be non-negative.")

    def forward(self, waveform):
        # type: (Tensor) -> Tensor
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

        Returns:
            torch.Tensor: Tensor of audio of dimension (..., time).
        """
        if self.gain_type == "amplitude":
            # scale the samples directly by the amplitude ratio
            waveform = waveform * self.gain

        if self.gain_type == "db":
            # gain is already in decibels; apply it with the dB gain helper
            waveform = F.gain(waveform, self.gain)

        if self.gain_type == "power":
            # convert the power ratio to decibels: db = 10 * log10(power)
            waveform = F.gain(waveform, 10 * math.log10(self.gain))

        # clamp to the valid [-1, 1] sample range so large gains saturate rather than overflow
        return torch.clamp(waveform, -1, 1)
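A usage sketch of the new transform (the file path is hypothetical; the conversions follow the standard relations amplitude = 10 ** (db / 20) and db = 10 * log10(power) that the implementation above relies on):

```python
import torchaudio

waveform, sample_rate = torchaudio.load("some_audio.wav")  # hypothetical file

# Scale the samples directly by an amplitude ratio of 1.1.
louder = torchaudio.transforms.Vol(gain=1.1, gain_type="amplitude")(waveform)

# A 2x power gain equals 10 * log10(2) ≈ 3.01 dB, i.e. an amplitude ratio of sqrt(2).
doubled_power = torchaudio.transforms.Vol(gain=2, gain_type="power")(waveform)

# A gain given in decibels corresponds to an amplitude ratio of 10 ** (db / 20).
plus_3db = torchaudio.transforms.Vol(gain=3.0, gain_type="db")(waveform)
```

Since the output is clamped to [-1, 1], large gains saturate rather than overflow.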