Commit 74bd971a authored by Xiaohui Zhang's avatar Xiaohui Zhang Committed by Facebook GitHub Bot
Browse files

Extend mask_along_axis{,_iid} (#3289)

Summary:
(1/2 of the previous [PR](https://github.com/pytorch/audio/pull/2360) which I accidentally closed)

The previous way of doing SpecAugment via Frequency/TimeMasking transforms has the following problems:
- Only zero masking can be done; masking by mean value is not supported.
- mask_along_axis is hard-coded to mask the 1st dimension and mask_along_axis_iid is hard-code to mask the 2nd or 3rd dimension of the input tensor.
- For 3D spectrogram tensors where the first dimension is batch or channel, features from the same batch or different channels have to use the same mask, because mask_along_axis_iid only support 4D tensors, because of the above hard-coding
- For 2D spectrogram tensors w/o a batch or channel dimension, Time/Frequency masking can't be applied at all, since mask_along_axis only support 3D tensors, because of the above hard-coding.
- It's not straightforward to apply multiple time/frequency masks by the current design.

To solve these issues, here we
- Extend mask_along_axis_iid to support 3D tensors and mask_along_axis to support 2D tensors. Now both of them are able to mask one of the last two dimensions (where the time or frequency dimension lives) of the input tensor.

The introduction of SpecAugment transform will be done in another PR.

Pull Request resolved: https://github.com/pytorch/audio/pull/3289

Reviewed By: hwangjeff

Differential Revision: D45460357

Pulled By: xiaohui-zhang

fbshipit-source-id: 91bf448294799f13789d96a13d4bae2451461ef3
parent c51f20f9
......@@ -397,22 +397,38 @@ class Functional(TestBaseMixin):
close_to_limit = decibels < 6.0207
assert close_to_limit.any(), f"No values were close to the limit. Did it over-clamp?\n{decibels}"
@parameterized.expand(list(itertools.product([(1, 201, 100), (10, 2, 201, 300)])))
def test_mask_along_axis_input_axis_check(self, shape):
specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
message = "Only Frequency and Time masking are supported"
with self.assertRaisesRegex(ValueError, message):
F.mask_along_axis(specgram, 100, 0.0, 0, 1.0)
@parameterized.expand(
list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2], [0.33, 1.0]))
list(
itertools.product([(1025, 400), (1, 201, 100), (10, 2, 201, 300)], [100], [0.0, 30.0], [1, 2], [0.33, 1.0])
)
def test_mask_along_axis(self, shape, mask_param, mask_value, axis, p):
)
def test_mask_along_axis(self, shape, mask_param, mask_value, last_axis, p):
specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
# last_axis = 1 means the last axis; 2 means the second-to-last axis.
axis = len(shape) - last_axis
if p != 1.0:
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis, p=p)
else:
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
other_axis = 1 if axis == 2 else 2
other_axis = axis - 1 if last_axis == 1 else axis + 1
masked_columns = (mask_specgram == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
num_masked_columns = torch.div(num_masked_columns, mask_specgram.size(0), rounding_mode="floor")
den = 1
for i in range(len(shape) - 2):
den *= mask_specgram.size(i)
num_masked_columns = torch.div(num_masked_columns, den, rounding_mode="floor")
if p != 1.0:
mask_param = min(mask_param, int(specgram.shape[axis] * p))
......
......@@ -25,7 +25,6 @@ class Tester(common_utils.TorchaudioTestCase):
return waveform / factor
def test_mu_law_companding(self):
quantization_channels = 256
waveform = self.waveform.clone()
......
......@@ -361,3 +361,52 @@ class TransformsTestBase(TestBaseMixin):
deemphasis = T.Deemphasis(coeff=coeff).to(dtype=self.dtype, device=self.device)
deemphasized = deemphasis(preemphasized)
self.assertEqual(deemphasized, waveform)
@nested_params(
[(100, 200), (5, 10, 20), (50, 50, 100, 200)],
)
def test_time_masking(self, input_shape):
transform = T.TimeMasking(time_mask_param=5)
# Genearte a specgram tensor containing 1's only, for the ease of testing.
specgram = torch.ones(*input_shape)
masked = transform(specgram)
dim = len(input_shape)
# Across the axis (dim-1) where we apply masking,
# the mean tensor should contain equal elements,
# and the value should be between 0 and 1.
m_masked = torch.mean(masked, dim - 1)
self.assertEqual(torch.var(m_masked), 0)
self.assertTrue(torch.mean(m_masked) > 0)
self.assertTrue(torch.mean(m_masked) < 1)
# Across all other dimensions, the mean tensor should contain at least
# one zero element, and all non-zero elements should be 1.
for axis in range(dim - 1):
unmasked_axis_mean = torch.mean(masked, axis)
self.assertTrue(0 in unmasked_axis_mean)
self.assertFalse(False in torch.eq(unmasked_axis_mean[unmasked_axis_mean != 0], 1))
@nested_params(
[(100, 200), (5, 10, 20), (50, 50, 100, 200)],
)
def test_freq_masking(self, input_shape):
transform = T.FrequencyMasking(freq_mask_param=5)
# Genearte a specgram tensor containing 1's only, for the ease of testing.
specgram = torch.ones(*input_shape)
masked = transform(specgram)
dim = len(input_shape)
# Across the axis (dim-2) where we apply masking,
# the mean tensor should contain equal elements,
# and the value should be between 0 and 1.
m_masked = torch.mean(masked, dim - 2)
self.assertEqual(torch.var(m_masked), 0)
self.assertTrue(torch.mean(m_masked) > 0)
self.assertTrue(torch.mean(m_masked) < 1)
# Across all other dimensions, the mean tensor should contain at least
# one zero element, and all non-zero elements should be 1.
for axis in range(dim):
if axis != dim - 2:
unmasked_axis_mean = torch.mean(masked, axis)
self.assertTrue(0 in unmasked_axis_mean)
self.assertFalse(False in torch.eq(unmasked_axis_mean[unmasked_axis_mean != 0], 1))
......@@ -825,18 +825,25 @@ def mask_along_axis_iid(
``max_v = min(mask_param, floor(specgrams.size(axis) * p))`` otherwise.
Args:
specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)`
specgrams (Tensor): Real spectrograms `(..., freq, time)`, with at least 3 dimensions.
mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
axis (int): Axis to apply masking on, which should be the one of the last two dimensions.
p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
Returns:
Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)`
Tensor: Masked spectrograms with the same dimensions as input specgrams Tensor`
"""
if axis not in [2, 3]:
raise ValueError("Only Frequency and Time masking are supported")
dim = specgrams.dim()
if dim < 3:
raise ValueError(f"Spectrogram must have at least three dimensions ({dim} given).")
if axis not in [dim - 2, dim - 1]:
raise ValueError(
f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
)
if not 0.0 <= p <= 1.0:
raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
......@@ -848,8 +855,8 @@ def mask_along_axis_iid(
device = specgrams.device
dtype = specgrams.dtype
value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)
value = torch.rand(specgrams.shape[: (dim - 2)], device=device, dtype=dtype) * mask_param
min_value = torch.rand(specgrams.shape[: (dim - 2)], device=device, dtype=dtype) * (specgrams.size(axis) - value)
# Create broadcastable mask
mask_start = min_value.long()[..., None, None]
......@@ -879,24 +886,31 @@ def mask_along_axis(
Mask will be applied from indices ``[v_0, v_0 + v)``,
where ``v`` is sampled from ``uniform(0, max_v)`` and
``v_0`` from ``uniform(0, specgrams.size(axis) - v)``, with
``v_0`` from ``uniform(0, specgram.size(axis) - v)``, with
``max_v = mask_param`` when ``p = 1.0`` and
``max_v = min(mask_param, floor(specgrams.size(axis) * p))``
``max_v = min(mask_param, floor(specgram.size(axis) * p))``
otherwise.
All examples will have the same mask interval.
Args:
specgram (Tensor): Real spectrogram `(channel, freq, time)`
specgram (Tensor): Real spectrograms `(..., freq, time)`, with at least 2 dimensions.
mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
axis (int): Axis to apply masking on, which should be the one of the last two dimensions.
p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
Returns:
Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
Tensor: Masked spectrograms with the same dimensions as input specgram Tensor
"""
if axis not in [1, 2]:
raise ValueError("Only Frequency and Time masking are supported")
dim = specgram.dim()
if dim < 2:
raise ValueError(f"Spectrogram must have at least two dimensions (time and frequency) ({dim} given).")
if axis not in [dim - 2, dim - 1]:
raise ValueError(
f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
)
if not 0.0 <= p <= 1.0:
raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
......@@ -908,14 +922,17 @@ def mask_along_axis(
# pack batch
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
# After packing, specgram is a 3D tensor, and the axis corresponding to the to-be-masked dimension
# is now (axis - dim + 3), e.g. a tensor of shape (10, 2, 50, 10, 2) becomes a tensor of shape (1000, 10, 2).
value = torch.rand(1) * mask_param
min_value = torch.rand(1) * (specgram.size(axis) - value)
min_value = torch.rand(1) * (specgram.size(axis - dim + 3) - value)
mask_start = (min_value.long()).squeeze()
mask_end = (min_value.long() + value.long()).squeeze()
mask = torch.arange(0, specgram.shape[axis], device=specgram.device, dtype=specgram.dtype)
mask = torch.arange(0, specgram.shape[axis - dim + 3], device=specgram.device, dtype=specgram.dtype)
mask = (mask >= mask_start) & (mask < mask_end)
if axis == 1:
# unsqueeze the mask if the axis is frequency
if axis == dim - 2:
mask = mask.unsqueeze(-1)
if mask_end - mask_start >= mask_param:
......@@ -1439,7 +1456,6 @@ def _get_sinc_resample_kernel(
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None,
):
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
raise Exception(
"Frequencies must be of integer type to ensure quality resampling computation. "
......
......@@ -1196,15 +1196,16 @@ class _AxisMasking(torch.nn.Module):
Args:
mask_param (int): Maximum possible length of the mask.
axis (int): What dimension the mask is applied on.
axis (int): What dimension the mask is applied on (assuming the tensor is 3D).
For frequency masking, axis = 1.
For time masking, axis = 2.
iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
This option is applicable only when the input tensor is 4D.
This option is applicable only when the dimension of the input tensor is >= 3.
p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
"""
__constants__ = ["mask_param", "axis", "iid_masks", "p"]
def __init__(self, mask_param: int, axis: int, iid_masks: bool, p: float = 1.0) -> None:
super(_AxisMasking, self).__init__()
self.mask_param = mask_param
self.axis = axis
......@@ -1221,10 +1222,14 @@ class _AxisMasking(torch.nn.Module):
Tensor: Masked spectrogram of dimensions `(..., freq, time)`.
"""
# if iid_masks flag marked and specgram has a batch dimension
if self.iid_masks and specgram.dim() == 4:
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1, p=self.p)
# self.axis + specgram.dim() - 3 gives the time/frequency dimension (last two dimensions)
# for input tensor for which the dimension is not 3.
if self.iid_masks:
return F.mask_along_axis_iid(
specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p
)
else:
return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis, p=self.p)
return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p)
class FrequencyMasking(_AxisMasking):
......@@ -1241,7 +1246,7 @@ class FrequencyMasking(_AxisMasking):
Indices uniformly sampled from [0, freq_mask_param).
iid_masks (bool, optional): whether to apply different masks to each
example/channel in the batch. (Default: ``False``)
This option is applicable only when the input tensor is 4D.
This option is applicable only when the input tensor >= 3D.
Example
>>> spectrogram = torchaudio.transforms.Spectrogram()
......@@ -1275,7 +1280,7 @@ class TimeMasking(_AxisMasking):
Indices uniformly sampled from [0, time_mask_param).
iid_masks (bool, optional): whether to apply different masks to each
example/channel in the batch. (Default: ``False``)
This option is applicable only when the input tensor is 4D.
This option is applicable only when the input tensor >= 3D.
p (float, optional): maximum proportion of time steps that can be masked.
Must be within range [0.0, 1.0]. (Default: 1.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment