Unverified Commit 99ed0521 authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Move augmentations in transforms (#348)

* sync docs with functionals.

* Adding transforms to the documentation. Moving augmentations into transforms.
parent a61b6472
...@@ -82,3 +82,28 @@ Functions to perform common audio operations. ...@@ -82,3 +82,28 @@ Functions to perform common audio operations.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: highpass_biquad .. autofunction:: highpass_biquad
:hidden:`equalizer_biquad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: equalizer_biquad
:hidden:`mask_along_axis`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: mask_along_axis
:hidden:`mask_along_axis_iid`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: mask_along_axis_iid
:hidden:`compute_deltas`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: compute_deltas
:hidden:`detect_pitch_frequency`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: detect_pitch_frequency
...@@ -64,3 +64,38 @@ Transforms are common audio transforms. They can be chained together using :clas ...@@ -64,3 +64,38 @@ Transforms are common audio transforms. They can be chained together using :clas
.. autoclass:: Resample .. autoclass:: Resample
.. automethod:: torchaudio._docs.Resample.forward .. automethod:: torchaudio._docs.Resample.forward
:hidden:`ComplexNorm`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ComplexNorm
.. automethod:: torchaudio._docs.ComplexNorm.forward
:hidden:`ComputeDeltas`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ComputeDeltas
.. automethod:: torchaudio._docs.ComputeDeltas.forward
:hidden:`TimeStretch`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: TimeStretch
.. automethod:: torchaudio._docs.TimeStretch.forward
:hidden:`FrequencyMasking`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: FrequencyMasking
.. automethod:: torchaudio._docs.FrequencyMasking.forward
:hidden:`TimeMasking`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: TimeMasking
.. automethod:: torchaudio._docs.TimeMasking.forward
...@@ -4,7 +4,6 @@ import os ...@@ -4,7 +4,6 @@ import os
import torch import torch
import torchaudio import torchaudio
import torchaudio.augmentations as A
import torchaudio.transforms as transforms import torchaudio.transforms as transforms
import torchaudio.functional as F import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
...@@ -424,15 +423,15 @@ class Tester(unittest.TestCase): ...@@ -424,15 +423,15 @@ class Tester(unittest.TestCase):
hop_length = 512 hop_length = 512
fixed_rate = 1.3 fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10, 2)) tensor = torch.rand((10, 2, n_freq, 10, 2))
_test_script_module(A.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate) _test_script_module(transforms.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)
def test_scriptmodule_FrequencyMasking(self): def test_scriptmodule_FrequencyMasking(self):
tensor = torch.rand((10, 2, 50, 10, 2)) tensor = torch.rand((10, 2, 50, 10, 2))
_test_script_module(A.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False) _test_script_module(transforms.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)
def test_scriptmodule_TimeMasking(self): def test_scriptmodule_TimeMasking(self):
tensor = torch.rand((10, 2, 50, 10, 2)) tensor = torch.rand((10, 2, 50, 10, 2))
_test_script_module(A.TimeMasking, tensor, time_mask_param=30, iid_masks=False) _test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)
if __name__ == '__main__': if __name__ == '__main__':
......
import math
import torch
from . import functional as F
# Public names exported by a star-import of this module.
__all__ = [
'TimeStretch',
'FrequencyMasking',
'TimeMasking'
]
class TimeStretch(torch.jit.ScriptModule):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None, optional): Number of audio frames between STFT columns.
            (Default: ``n_fft // 2``, with ``n_fft = (n_freq - 1) * 2``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
    """
    __constants__ = ['fixed_rate']

    def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
        super(TimeStretch, self).__init__()

        if hop_length is None:
            # Derive the hop from the FFT size implied by the number of frequency bins.
            n_fft = (n_freq - 1) * 2
            hop_length = n_fft // 2

        # One phase-advance value per frequency bin, stored as a column of shape (n_freq, 1).
        advance = torch.linspace(0, math.pi * hop_length, n_freq).unsqueeze(-1)
        self.phase_advance = torch.jit.Attribute(advance, torch.Tensor)
        self.fixed_rate = fixed_rate

    def forward(self, complex_specgrams, overriding_rate=None):
        # type: (Tensor, Optional[float]) -> Tensor
        r"""
        Args:
            complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2)
            overriding_rate (float or None): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``

        Returns:
            (Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2)

        Raises:
            ValueError: if neither ``overriding_rate`` nor ``fixed_rate`` is set.
        """
        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"

        # A per-call rate takes precedence over the rate fixed at construction.
        if overriding_rate is not None:
            rate = overriding_rate
        else:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified"
                                 ", must pass a valid rate to the forward method.")

        # Unit rate is a no-op: hand the input back untouched.
        if rate == 1.0:
            return complex_specgrams

        # Collapse every leading dimension into one so the vocoder receives a 4D tensor,
        # then restore the leading dimensions around the (possibly resized) result.
        original_shape = complex_specgrams.size()
        flattened = complex_specgrams.reshape([-1] + list(original_shape[-3:]))
        stretched = F.phase_vocoder(flattened, rate, self.phase_advance)
        return stretched.reshape(original_shape[:-3] + stretched.shape[-3:])
class _AxisMasking(torch.nn.Module):
    r"""Shared implementation for masking a spectrogram along one axis.

    Args:
        mask_param (int): Maximum possible length of the mask
        axis (int): Dimension the mask is applied on, counted on the 3D
            (flattened leading dims, freq, time) view used by ``forward``
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param, axis, iid_masks):
        super(_AxisMasking, self).__init__()

        self.iid_masks = iid_masks
        self.axis = axis
        self.mask_param = mask_param

    def forward(self, specgram, mask_value=0.):
        # type: (Tensor, float) -> Tensor
        r"""
        Args:
            specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)
            mask_value (float): Value passed to the masking function for the masked region

        Returns:
            torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
        """
        # The iid path is only taken for 4D input; the extra leading dimension
        # shifts the target axis by one.
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)

        # Otherwise fold all leading dimensions into one, mask the 3D view,
        # and unfold back to the caller's layout.
        original_shape = specgram.size()
        flattened = specgram.reshape([-1] + list(original_shape[-2:]))
        masked = F.mask_along_axis(flattened, self.mask_param, mask_value, self.axis)
        return masked.reshape(original_shape[:-2] + masked.shape[-2:])
class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool): whether to apply a different, independently drawn mask to each
            example in the batch instead of one shared mask. Only takes effect
            when the input is 4D. (Default: ``False``)
    """

    def __init__(self, freq_mask_param, iid_masks=False):
        # Axis 1 selects the frequency dimension in the base class.
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)
class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool): whether to apply a different, independently drawn mask to each
            example in the batch instead of one shared mask. Only takes effect
            when the input is 4D. (Default: ``False``)
    """

    def __init__(self, time_mask_param, iid_masks=False):
        # Axis 2 selects the time dimension in the base class.
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
...@@ -16,7 +16,10 @@ __all__ = [ ...@@ -16,7 +16,10 @@ __all__ = [
'MuLawEncoding', 'MuLawEncoding',
'MuLawDecoding', 'MuLawDecoding',
'Resample', 'Resample',
'ComplexNorm' 'ComplexNorm',
'TimeStretch',
'FrequencyMasking',
'TimeMasking',
] ]
...@@ -408,3 +411,121 @@ class ComputeDeltas(torch.nn.Module): ...@@ -408,3 +411,121 @@ class ComputeDeltas(torch.nn.Module):
deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time) deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
""" """
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode) return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
class TimeStretch(torch.jit.ScriptModule):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None, optional): Number of audio frames between STFT columns.
            (Default: ``n_fft // 2``, with ``n_fft = (n_freq - 1) * 2``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
    """
    __constants__ = ['fixed_rate']

    def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
        super(TimeStretch, self).__init__()

        if hop_length is None:
            # Derive the hop from the FFT size implied by the number of frequency bins.
            n_fft = (n_freq - 1) * 2
            hop_length = n_fft // 2

        # One phase-advance value per frequency bin, stored as a column of shape (n_freq, 1).
        advance = torch.linspace(0, math.pi * hop_length, n_freq).unsqueeze(-1)
        self.phase_advance = torch.jit.Attribute(advance, torch.Tensor)
        self.fixed_rate = fixed_rate

    def forward(self, complex_specgrams, overriding_rate=None):
        # type: (Tensor, Optional[float]) -> Tensor
        r"""
        Args:
            complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2)
            overriding_rate (float or None): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``

        Returns:
            (Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2)

        Raises:
            ValueError: if neither ``overriding_rate`` nor ``fixed_rate`` is set.
        """
        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"

        # A per-call rate takes precedence over the rate fixed at construction.
        if overriding_rate is not None:
            rate = overriding_rate
        else:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified"
                                 ", must pass a valid rate to the forward method.")

        # Unit rate is a no-op: hand the input back untouched.
        if rate == 1.0:
            return complex_specgrams

        # Collapse every leading dimension into one so the vocoder receives a 4D tensor,
        # then restore the leading dimensions around the (possibly resized) result.
        original_shape = complex_specgrams.size()
        flattened = complex_specgrams.reshape([-1] + list(original_shape[-3:]))
        stretched = F.phase_vocoder(flattened, rate, self.phase_advance)
        return stretched.reshape(original_shape[:-3] + stretched.shape[-3:])
class _AxisMasking(torch.nn.Module):
    r"""Shared implementation for masking a spectrogram along one axis.

    Args:
        mask_param (int): Maximum possible length of the mask
        axis (int): Dimension the mask is applied on, counted on the 3D
            (flattened leading dims, freq, time) view used by ``forward``
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param, axis, iid_masks):
        super(_AxisMasking, self).__init__()

        self.iid_masks = iid_masks
        self.axis = axis
        self.mask_param = mask_param

    def forward(self, specgram, mask_value=0.):
        # type: (Tensor, float) -> Tensor
        r"""
        Args:
            specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)
            mask_value (float): Value passed to the masking function for the masked region

        Returns:
            torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
        """
        # The iid path is only taken for 4D input; the extra leading dimension
        # shifts the target axis by one.
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)

        # Otherwise fold all leading dimensions into one, mask the 3D view,
        # and unfold back to the caller's layout.
        original_shape = specgram.size()
        flattened = specgram.reshape([-1] + list(original_shape[-2:]))
        masked = F.mask_along_axis(flattened, self.mask_param, mask_value, self.axis)
        return masked.reshape(original_shape[:-2] + masked.shape[-2:])
class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool): whether to apply a different, independently drawn mask to each
            example in the batch instead of one shared mask. Only takes effect
            when the input is 4D. (Default: ``False``)
    """

    def __init__(self, freq_mask_param, iid_masks=False):
        # Axis 1 selects the frequency dimension in the base class.
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)
class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool): whether to apply a different, independently drawn mask to each
            example in the batch instead of one shared mask. Only takes effect
            when the input is 4D. (Default: ``False``)
    """

    def __init__(self, time_mask_param, iid_masks=False):
        # Axis 2 selects the time dimension in the base class.
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment