Commit 5c0773f6 authored by Kiran Sanjeevan, committed by Vincent QB

torchaudio-contrib: some augmentations (#285)

* TimeStretch and Masking
* Doc stuff and naming
parent e3024341
@@ -223,8 +223,7 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad):
@pytest.mark.parametrize('complex_specgrams', [
torch.randn(2, 1025, 400, 2)
])
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
@pytest.mark.parametrize('hop_length', [256])
@@ -277,5 +276,45 @@ def test_complex_norm(complex_tensor, power):
assert torch.allclose(expected_norm_tensor, norm_tensor, atol=1e-5)
@pytest.mark.parametrize('specgram', [
torch.randn(2, 1025, 400),
torch.randn(1, 201, 100)
])
@pytest.mark.parametrize('mask_param', [100])
@pytest.mark.parametrize('mask_value', [0., 30.])
@pytest.mark.parametrize('axis', [1, 2])
def test_mask_along_axis(specgram, mask_param, mask_value, axis):
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
other_axis = 1 if axis == 2 else 2
masked_columns = (mask_specgram == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
num_masked_columns /= mask_specgram.size(0)
assert mask_specgram.size() == specgram.size()
assert num_masked_columns < mask_param
@pytest.mark.parametrize('specgrams', [
torch.randn(4, 2, 1025, 400),
])
@pytest.mark.parametrize('mask_param', [100])
@pytest.mark.parametrize('mask_value', [0., 30.])
@pytest.mark.parametrize('axis', [2, 3])
def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
other_axis = 2 if axis == 3 else 3
masked_columns = (mask_specgrams == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1)
assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
if __name__ == '__main__':
unittest.main()
import math
import torch
from . import functional as F
__all__ = [
'TimeStretch',
'FrequencyMasking',
'TimeMasking'
]
class TimeStretch(torch.jit.ScriptModule):
r"""Stretch stft in time without modifying pitch for a given rate.
Args:
hop_length (int): Number audio of frames between STFT columns. (Default: ``n_fft // 2``)
n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
fixed_rate (float): rate to speed up or slow down by.
If None is provided, rate must be passed to the forward method. (Default: ``None``)
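    Example (an illustrative sketch; shapes are arbitrary)
        >>> stretch = TimeStretch(hop_length=512, n_freq=1025, fixed_rate=1.3)
        >>> complex_specgrams = torch.randn(2, 1025, 400, 2)
        >>> stretch(complex_specgrams).shape  # 308 == ceil(400 / 1.3)
        torch.Size([2, 1025, 308, 2])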
"""
__constants__ = ['fixed_rate']
def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
super(TimeStretch, self).__init__()
n_fft = (n_freq - 1) * 2
hop_length = hop_length if hop_length is not None else n_fft // 2
self.fixed_rate = fixed_rate
phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]
self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor)
@torch.jit.script_method
def forward(self, complex_specgrams, overriding_rate=None):
# type: (Tensor, Optional[float]) -> Tensor
r"""
Args:
            complex_specgrams (Tensor): complex spectrogram of dimension (*, channel, freq, time, complex=2)
            overriding_rate (float or None): speed-up factor to apply to this batch.
                If no rate is passed, ``self.fixed_rate`` is used. (Default: ``None``)
        Returns:
            (Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2)
"""
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"
if overriding_rate is None:
rate = self.fixed_rate
if rate is None:
raise ValueError("If no fixed_rate is specified"
", must pass a valid rate to the forward method.")
else:
rate = overriding_rate
if rate == 1.0:
return complex_specgrams
shape = complex_specgrams.size()
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:])
class _AxisMasking(torch.jit.ScriptModule):
r"""
Apply masking to a spectrogram.
Args:
mask_param (int): Maximum possible length of the mask
axis: What dimension the mask is applied on
iid_masks (bool): Applies iid masks to each of the examples in the batch dimension
"""
__constants__ = ['mask_param', 'axis', 'iid_masks']
def __init__(self, mask_param, axis, iid_masks):
super(_AxisMasking, self).__init__()
self.mask_param = mask_param
self.axis = axis
self.iid_masks = iid_masks
@torch.jit.script_method
def forward(self, specgram, mask_value=0.):
# type: (Tensor, float) -> Tensor
r"""
Args:
specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)
Returns:
torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
"""
        # if the iid_masks flag is set and the spectrogram has a batch dimension
if self.iid_masks and specgram.dim() == 4:
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
else:
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
return specgram.reshape(shape[:-2] + specgram.shape[-2:])
class FrequencyMasking(_AxisMasking):
r"""
Apply masking to a spectrogram in the frequency domain.
Args:
freq_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool): whether to apply i.i.d. masks to each
            example/channel in the batch. (Default: ``False``)
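    Example (an illustrative sketch; a random tensor stands in for a real spectrogram)
        >>> masking = FrequencyMasking(freq_mask_param=30)
        >>> specgram = torch.randn(1, 201, 400)  # (channel, freq, time)
        >>> masked = masking(specgram)  # zeroes out up to 30 consecutive frequency bins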
"""
def __init__(self, freq_mask_param, iid_masks=False):
super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)
class TimeMasking(_AxisMasking):
r"""
Apply masking to a spectrogram in the time domain.
Args:
time_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool): whether to apply i.i.d. masks to each
            example/channel in the batch. (Default: ``False``)
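    Example (an illustrative sketch; exercises the per-example i.i.d. path on a batched input)
        >>> masking = TimeMasking(time_mask_param=100, iid_masks=True)
        >>> specgrams = torch.randn(4, 2, 201, 400)  # (batch, channel, freq, time)
        >>> masked = masking(specgrams)  # each example/channel gets its own time mask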
"""
def __init__(self, time_mask_param, iid_masks=False):
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
@@ -18,6 +18,8 @@ __all__ = [
"lowpass_biquad",
"highpass_biquad",
"biquad",
    "mask_along_axis",
    "mask_along_axis_iid",
]
@@ -228,8 +230,8 @@ def spectrogram(
normalized (bool): Whether to normalize by magnitude after stft
Returns:
torch.Tensor: Dimension (channel, n_freq, time), where channel
is unchanged, n_freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frames).
"""
assert waveform.dim() == 2
@@ -397,7 +399,9 @@ def mu_law_decoding(x_mu, quantization_channels):
return x
@torch.jit.script
def complex_norm(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tensor
r"""Compute the norm of complex tensor input.
Args:
@@ -439,64 +443,59 @@ def magphase(complex_tensor, power=1.0):
return mag, phase
@torch.jit.script
def phase_vocoder(complex_specgrams, rate, phase_advance):
# type: (Tensor, float, Tensor) -> Tensor
r"""Given a STFT tensor, speed up in time without modifying pitch by a
factor of ``rate``.
Args:
complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)`
rate (float): Speed-up factor
phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension
of (freq, 1)
Returns:
complex_specgrams_stretch (torch.Tensor): Dimension of `(channel,
freq, ceil(time/rate), complex=2)`
Example
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time, complex=2)
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
>>> rate = 1.3 # Speed up by 30%
>>> phase_advance = torch.linspace(
>>> 0, math.pi * hop_length, freq)[..., None]
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231, 2])
"""
time_steps = torch.arange(0,
complex_specgrams.size(-2),
rate,
device=complex_specgrams.device,
dtype=complex_specgrams.dtype)
alphas = time_steps % 1.0
phase_0 = angle(complex_specgrams[:, :, :1])
# Time Padding
complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])
# (new_bins, freq, 2)
complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()]
complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()]
angle_0 = angle(complex_specgrams_0)
angle_1 = angle(complex_specgrams_1)
norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)
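    # Deviation of the measured phase increment from the expected advance, wrapped to [-pi, pi]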
phase = angle_1 - angle_0 - phase_advance
phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))
# Compute Phase Accum
phase = phase + phase_advance
phase = torch.cat([phase_0, phase[:, :, :-1]], dim=-1)
phase_acc = torch.cumsum(phase, -1)
mag = alphas * norm_1 + (1 - alphas) * norm_0
@@ -662,6 +661,79 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
@torch.jit.script
def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
# type: (Tensor, int, float, int) -> Tensor
r"""
    Apply a mask along ``axis``. The mask is applied over indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)`` and ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``.
    A mask interval is sampled independently for each example and channel in the batch.
Args:
specgrams (Tensor): Real spectrograms (batch, channel, n_freq, time)
        mask_param (int): maximum mask length; the number of masked columns is uniformly sampled from [0, mask_param)
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
Returns:
torch.Tensor: Masked spectrograms of dimensions (batch, channel, n_freq, time)
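    Example (an illustrative sketch)
        >>> specgrams = torch.randn(4, 2, 1025, 400)
        >>> masked = mask_along_axis_iid(specgrams, mask_param=100, mask_value=0., axis=3)
        >>> masked.shape
        torch.Size([4, 2, 1025, 400])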
"""
if axis != 2 and axis != 3:
raise ValueError('Only Frequency and Time masking are supported')
value = torch.rand(specgrams.shape[:2]) * mask_param
min_value = torch.rand(specgrams.shape[:2]) * (specgrams.size(axis) - value)
# Create broadcastable mask
mask_start = (min_value.long())[..., None, None].float()
mask_end = (min_value.long() + value.long())[..., None, None].float()
mask = torch.arange(0, specgrams.size(axis)).float()
# Per batch example masking
specgrams = specgrams.transpose(axis, -1)
specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value)
specgrams = specgrams.transpose(axis, -1)
return specgrams
@torch.jit.script
def mask_along_axis(specgram, mask_param, mask_value, axis):
# type: (Tensor, int, float, int) -> Tensor
r"""
    Apply a mask along ``axis``. The mask is applied over indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)`` and ``v_0`` from ``uniform(0, specgram.size(axis) - v)``.
    All examples will have the same mask interval.
Args:
specgram (Tensor): Real spectrogram (channel, n_freq, time)
        mask_param (int): maximum mask length; the number of masked columns is uniformly sampled from [0, mask_param)
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
Returns:
torch.Tensor: Masked spectrogram of dimensions (channel, n_freq, time)
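    Example (an illustrative sketch; note that the input tensor is modified in place)
        >>> specgram = torch.randn(2, 1025, 400)
        >>> masked = mask_along_axis(specgram, mask_param=80, mask_value=0., axis=2)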
"""
value = torch.rand(1) * mask_param
min_value = torch.rand(1) * (specgram.size(axis) - value)
mask_start = (min_value.long()).squeeze()
mask_end = (min_value.long() + value.long()).squeeze()
assert mask_end - mask_start < mask_param
if axis == 1:
specgram[:, mask_start:mask_end] = mask_value
elif axis == 2:
specgram[:, :, mask_start:mask_end] = mask_value
else:
raise ValueError('Only Frequency and Time masking are supported')
return specgram
@torch.jit.script
def compute_deltas(specgram, win_length=5, mode="replicate"):
# type: (Tensor, int, str) -> Tensor
r"""Compute delta coefficients of a tensor, usually a spectrogram:
......
@@ -16,6 +16,7 @@ __all__ = [
'MuLawEncoding',
'MuLawDecoding',
'Resample',
'ComplexNorm'
]
@@ -367,6 +368,28 @@ class Resample(torch.nn.Module):
raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
class ComplexNorm(torch.jit.ScriptModule):
r"""Compute the norm of complex tensor input
Args:
power (float): Power of the norm. Defaults to `1.0`.
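    Example (an illustrative sketch)
        >>> complex_norm = ComplexNorm(power=1.0)
        >>> complex_tensor = torch.randn(1, 201, 400, 2)
        >>> complex_norm(complex_tensor).shape
        torch.Size([1, 201, 400])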
"""
__constants__ = ['power']
def __init__(self, power=1.0):
super(ComplexNorm, self).__init__()
self.power = power
@torch.jit.script_method
def forward(self, complex_tensor):
r"""
Args:
complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
Returns:
Tensor: norm of the input tensor, shape of `(*, )`
"""
return F.complex_norm(complex_tensor, self.power)
class ComputeDeltas(torch.jit.ScriptModule):
r"""Compute delta coefficients of a tensor, usually a spectrogram.
......