Commit 82febc59 authored by Xiaohui Zhang, committed by Facebook GitHub Bot

Add SpecAugment transform (#3309)

Summary:
(2/2 of the previous https://github.com/pytorch/audio/pull/2360, which I accidentally closed)

The previous way of doing SpecAugment via Frequency/TimeMasking transforms has the following problems:
- Only zero masking can be done; masking by mean value is not supported.
- mask_along_axis is hard-coded to mask the 1st dimension, and mask_along_axis_iid is hard-coded to mask the 2nd or 3rd dimension of the input tensor.
- For 3D spectrogram tensors where the first dimension is batch or channel, features from the same batch or from different channels have to share the same mask, because mask_along_axis_iid only supports 4D tensors due to the above hard-coding.
- For 2D spectrogram tensors without a batch or channel dimension, time/frequency masking can't be applied at all, since mask_along_axis only supports 3D tensors, again due to the above hard-coding.
- It's not straightforward to apply multiple time/frequency masks with the current design. If we need N masks along the time or frequency axis, we have to sequentially apply N TimeMasking/FrequencyMasking transforms to the input tensor, which makes for an inconvenient API (see the sketch after this list). We need a separate SpecAugment transform to handle this.
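For instance, getting two time masks and two frequency masks with the existing transforms means chaining one transform application per mask; a minimal sketch of that workflow (tensor shape and parameter values are illustrative):

```python
import torch
import torchaudio.transforms as T

spec = torch.rand(1, 201, 400)  # (channel, freq, time) spectrogram

# Current approach: one transform application per mask, chained sequentially.
n_masks = 2
time_masking = T.TimeMasking(time_mask_param=80)
freq_masking = T.FrequencyMasking(freq_mask_param=27)
for _ in range(n_masks):
    spec = time_masking(spec)  # masks with zeros only
    spec = freq_masking(spec)
```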

To solve these issues, here we:
- [done in the previous [PR](https://github.com/pytorch/audio/pull/3289)] extend mask_along_axis_iid to support 3D+ tensors and mask_along_axis to support 2D+ tensors, so that both can mask either of the last two dimensions (where the frequency and time dimensions live) of the input tensor;
- [done in this PR] introduce the SpecAugment transform (usage sketch below).
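A minimal usage sketch of the new transform; the constructor arguments are the ones added in this PR, while the tensor shape and parameter values are illustrative:

```python
import torch
import torchaudio.transforms as T

spec = torch.rand(4, 201, 400)  # (batch, freq, time); plain (freq, time) inputs also work now

specaug = T.SpecAugment(
    n_time_masks=2,        # number of time masks
    time_mask_param=80,    # maximum length of each time mask
    n_freq_masks=2,        # number of frequency masks
    freq_mask_param=27,    # maximum length of each frequency mask
    iid_masks=True,        # independent masks per example along the leading dimensions
    p=1.0,                 # maximum proportion of an axis that can be masked
    zero_masking=False,    # mask with the spectrogram mean instead of zeros
)
masked = specaug(spec)
```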

Pull Request resolved: https://github.com/pytorch/audio/pull/3309

Reviewed By: nateanl

Differential Revision: D45592926

Pulled By: xiaohui-zhang

fbshipit-source-id: 97cd686dbb6c1c6ff604716b71a876e616aaf1a2
parent 1e3af12f
@@ -410,3 +410,70 @@ class TransformsTestBase(TestBaseMixin):
unmasked_axis_mean = torch.mean(masked, axis)
self.assertTrue(0 in unmasked_axis_mean)
self.assertFalse(False in torch.eq(unmasked_axis_mean[unmasked_axis_mean != 0], 1))
@parameterized.expand(
[
param(10, 20, 10, 20, False),
param(0, 20, 10, 20, False),
param(10, 20, 0, 20, False),
param(10, 20, 10, 20, True),
param(0, 20, 10, 20, True),
param(10, 20, 0, 20, True),
]
)
def test_specaugment(self, n_time_masks, time_mask_param, n_freq_masks, freq_mask_param, iid_masks):
"""Make sure SpecAug masking works as expected"""
spec = torch.ones(2, 200, 200)
transform = T.SpecAugment(
n_time_masks=n_time_masks,
time_mask_param=time_mask_param,
n_freq_masks=n_freq_masks,
freq_mask_param=freq_mask_param,
iid_masks=iid_masks,
zero_masking=True,
)
spec_masked = transform(spec)
f_axis_mean = torch.mean(spec_masked, 1)
t_axis_mean = torch.mean(spec_masked, 2)
if n_time_masks == 0 and n_freq_masks == 0:
self.assertEqual(spec, spec_masked)
elif n_time_masks > 0 and n_freq_masks > 0:
# Across both time and frequency dimensions, the mean tensor should contain
# at least one zero element, and all non-zero elements should be less than 1.
self.assertTrue(0 in t_axis_mean)
self.assertFalse(False in torch.lt(t_axis_mean[t_axis_mean != 0], 1))
self.assertTrue(0 in f_axis_mean)
self.assertFalse(False in torch.lt(f_axis_mean[f_axis_mean != 0], 1))
elif n_freq_masks > 0:
# Across the frequency axis where we apply masking,
# the mean tensor should contain equal elements,
# and the value should be between 0 and 1.
self.assertFalse(False in torch.eq(f_axis_mean[0], f_axis_mean[0][0]))
self.assertFalse(False in torch.eq(f_axis_mean[1], f_axis_mean[1][0]))
self.assertTrue(f_axis_mean[0][0] < 1)
self.assertTrue(f_axis_mean[1][0] > 0)
# Across the time axis where we don't mask, the mean tensor should contain at
# least one zero element, and all non-zero elements should be 1.
self.assertTrue(0 in t_axis_mean)
self.assertFalse(False in torch.eq(t_axis_mean[t_axis_mean != 0], 1))
else:
# Across the time axis where we apply masking,
# the mean tensor should contain equal elements,
# and the value should be between 0 and 1.
self.assertFalse(False in torch.eq(t_axis_mean[0], t_axis_mean[0][0]))
self.assertFalse(False in torch.eq(t_axis_mean[1], t_axis_mean[1][0]))
self.assertTrue(t_axis_mean[0][0] < 1)
self.assertTrue(t_axis_mean[1][0] > 0)
# Across the frequency axis where we don't mask, the mean tensor should contain at
# least one zero element, and all non-zero elements should be 1.
self.assertTrue(0 in f_axis_mean)
self.assertFalse(False in torch.eq(f_axis_mean[f_axis_mean != 0], 1))
# Test if iid_masks gives different masking results for different spectrograms across the 0th dimension.
if iid_masks is True:
self.assertTrue(torch.norm(spec_masked[0] - spec_masked[1]).item() > 0)
else:
self.assertTrue(torch.norm(spec_masked[0] - spec_masked[1]).item() == 0)
@@ -23,6 +23,7 @@ from ._transforms import (
Resample,
RNNTLoss,
SlidingWindowCmn,
SpecAugment,
SpectralCentroid,
Spectrogram,
Speed,
@@ -62,6 +63,7 @@ __all__ = [
"Resample",
"SlidingWindowCmn",
"SoudenMVDR",
"SpecAugment",
"SpectralCentroid", "SpectralCentroid",
"Spectrogram", "Spectrogram",
"Speed", "Speed",
......
...@@ -1304,6 +1304,77 @@ class TimeMasking(_AxisMasking): ...@@ -1304,6 +1304,77 @@ class TimeMasking(_AxisMasking):
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks, p=p) super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks, p=p)
class SpecAugment(torch.nn.Module):
r"""Apply time and frequency masking to a spectrogram.
Args:
n_time_masks (int): Number of time masks. If its value is zero, no time masking will be applied.
time_mask_param (int): Maximum possible length of the time mask.
n_freq_masks (int): Number of frequency masks. If its value is zero, no frequency masking will be applied.
freq_mask_param (int): Maximum possible length of the frequency mask.
iid_masks (bool, optional): Applies iid masks to each of the examples in the batch dimension.
This option is applicable only when the input tensor has 3 or more dimensions. (Default: ``True``)
p (float, optional): maximum proportion of time steps that can be masked.
Must be within range [0.0, 1.0]. (Default: 1.0)
zero_masking (bool, optional): If ``True``, use 0 as the mask value;
otherwise, use the mean of the input tensor. (Default: ``False``)
"""
__constants__ = [
"n_time_masks",
"time_mask_param",
"n_freq_masks",
"freq_mask_param",
"iid_masks",
"p",
"zero_masking",
]
def __init__(
self,
n_time_masks: int,
time_mask_param: int,
n_freq_masks: int,
freq_mask_param: int,
iid_masks: bool = True,
p: float = 1.0,
zero_masking: bool = False,
) -> None:
super(SpecAugment, self).__init__()
self.n_time_masks = n_time_masks
self.time_mask_param = time_mask_param
self.n_freq_masks = n_freq_masks
self.freq_mask_param = freq_mask_param
self.iid_masks = iid_masks
self.p = p
self.zero_masking = zero_masking
def forward(self, specgram: Tensor) -> Tensor:
r"""
Args:
specgram (Tensor): Tensor of shape `(..., freq, time)`.
Returns:
Tensor: Masked spectrogram of shape `(..., freq, time)`.
"""
if self.zero_masking:
mask_value = 0.0
else:
mask_value = specgram.mean()
time_dim = specgram.dim() - 1
freq_dim = time_dim - 1
if specgram.dim() > 2 and self.iid_masks is True:
for _ in range(self.n_time_masks):
specgram = F.mask_along_axis_iid(specgram, self.time_mask_param, mask_value, time_dim, p=self.p)
for _ in range(self.n_freq_masks):
specgram = F.mask_along_axis_iid(specgram, self.freq_mask_param, mask_value, freq_dim, p=self.p)
else:
for _ in range(self.n_time_masks):
specgram = F.mask_along_axis(specgram, self.time_mask_param, mask_value, time_dim, p=self.p)
for _ in range(self.n_freq_masks):
specgram = F.mask_along_axis(specgram, self.freq_mask_param, mask_value, freq_dim, p=self.p)
return specgram
class Loudness(torch.nn.Module):
r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation.