Commit 1ec7ff73 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Add parameter p to TimeMasking (#2090)

Summary:
Adds parameter `p` to `TimeMasking` to allow for enforcing an upper bound on the proportion of time steps that it can mask. This behavior is consistent with the specifications provided in the SpecAugment paper (https://arxiv.org/abs/1904.08779).

Pull Request resolved: https://github.com/pytorch/audio/pull/2090

Reviewed By: carolineechen

Differential Revision: D33344772

Pulled By: hwangjeff

fbshipit-source-id: 6ff65f5304e489fa1c23e15c3d96b9946229fdcf
parent 896ade04
...@@ -328,10 +328,16 @@ class Functional(TestBaseMixin): ...@@ -328,10 +328,16 @@ class Functional(TestBaseMixin):
        close_to_limit = decibels < 6.0207
        assert close_to_limit.any(), f"No values were close to the limit. Did it over-clamp?\n{decibels}"

    @parameterized.expand(
        list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2], [0.33, 1.0]))
    )
    def test_mask_along_axis(self, shape, mask_param, mask_value, axis, p):
        torch.random.manual_seed(42)
        specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
        if p != 1.0:
            mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis, p=p)
        else:
            mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)

        other_axis = 1 if axis == 2 else 2
...@@ -340,14 +346,20 @@ class Functional(TestBaseMixin): ...@@ -340,14 +346,20 @@ class Functional(TestBaseMixin):
        num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
        num_masked_columns = torch.div(num_masked_columns, mask_specgram.size(0), rounding_mode="floor")

        if p != 1.0:
            mask_param = min(mask_param, int(specgram.shape[axis] * p))

        assert mask_specgram.size() == specgram.size()
        assert num_masked_columns < mask_param
    @parameterized.expand(list(itertools.product([100], [0.0, 30.0], [2, 3], [0.2, 1.0])))
    def test_mask_along_axis_iid(self, mask_param, mask_value, axis, p):
        torch.random.manual_seed(42)
        specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
        if p != 1.0:
            mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis, p=p)
        else:
            mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)

        other_axis = 2 if axis == 3 else 3
...@@ -355,6 +367,9 @@ class Functional(TestBaseMixin): ...@@ -355,6 +367,9 @@ class Functional(TestBaseMixin):
        masked_columns = (mask_specgrams == mask_value).sum(other_axis)
        num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1)

        if p != 1.0:
            mask_param = min(mask_param, int(specgrams.shape[axis] * p))

        assert mask_specgrams.size() == specgrams.size()
        assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
......
...@@ -168,6 +168,16 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -168,6 +168,16 @@ class AutogradTestMixin(TestBaseMixin):
        deterministic_transform = _DeterministicWrapper(masking_transform(400, True))
        self.assert_grad(deterministic_transform, [batch])
def test_time_masking_p(self):
    """Autograd check for TimeMasking when a masking-proportion cap (p < 1) is set."""
    sample_rate = 8000
    n_fft = 400
    # Short two-channel white-noise clip keeps the gradcheck fast.
    noise = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
    spectrogram = get_spectrogram(noise, n_fft=n_fft, power=1)
    masking = T.TimeMasking(400, iid_masks=False, p=0.1)
    deterministic_transform = _DeterministicWrapper(masking)
    self.assert_grad(deterministic_transform, [spectrogram])
    def test_spectral_centroid(self):
        sample_rate = 8000
        transform = T.SpectralCentroid(sample_rate=sample_rate)
......
...@@ -704,16 +704,32 @@ def phase_vocoder(complex_specgrams: Tensor, rate: float, phase_advance: Tensor) ...@@ -704,16 +704,32 @@ def phase_vocoder(complex_specgrams: Tensor, rate: float, phase_advance: Tensor)
return complex_specgrams_stretch return complex_specgrams_stretch
def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, axis: int) -> Tensor: def _get_mask_param(mask_param: int, p: float, axis_length: int) -> int:
if p == 1.0:
return mask_param
else:
return min(mask_param, int(axis_length * p))
def mask_along_axis_iid(
specgrams: Tensor,
mask_param: int,
mask_value: float,
axis: int,
p: float = 1.0,
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, max_v)`` and ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``, with
    ``max_v = mask_param`` when ``p = 1.0`` and ``max_v = min(mask_param, floor(specgrams.size(axis) * p))``
    otherwise.

    Args:
        specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)`
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
        p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)

    Returns:
        Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)`
...@@ -722,6 +738,13 @@ def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, a ...@@ -722,6 +738,13 @@ def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, a
    if axis not in [2, 3]:
        raise ValueError("Only Frequency and Time masking are supported")

    if not 0.0 <= p <= 1.0:
        raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")

    mask_param = _get_mask_param(mask_param, p, specgrams.shape[axis])
    if mask_param < 1:
        return specgrams
device = specgrams.device device = specgrams.device
dtype = specgrams.dtype dtype = specgrams.dtype
...@@ -729,8 +752,8 @@ def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, a ...@@ -729,8 +752,8 @@ def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, a
    min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)

    # Create broadcastable mask
    mask_start = min_value.long()[..., None, None]
    mask_end = (min_value.long() + value.long())[..., None, None]
    mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)
# Per batch example masking # Per batch example masking
...@@ -741,17 +764,25 @@ def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, a ...@@ -741,17 +764,25 @@ def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, a
return specgrams return specgrams
def mask_along_axis(
    specgram: Tensor,
    mask_param: int,
    mask_value: float,
    axis: int,
    p: float = 1.0,
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, max_v)`` and ``v_0`` from ``uniform(0, specgram.size(axis) - v)``, with
    ``max_v = mask_param`` when ``p = 1.0`` and ``max_v = min(mask_param, floor(specgram.size(axis) * p))``
    otherwise. All examples will have the same mask interval.

    Args:
        specgram (Tensor): Real spectrogram `(channel, freq, time)`
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
        p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)

    Returns:
        Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
...@@ -759,6 +790,13 @@ def mask_along_axis(specgram: Tensor, mask_param: int, mask_value: float, axis: ...@@ -759,6 +790,13 @@ def mask_along_axis(specgram: Tensor, mask_param: int, mask_value: float, axis:
    if axis not in [1, 2]:
        raise ValueError("Only Frequency and Time masking are supported")

    if not 0.0 <= p <= 1.0:
        raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")

    mask_param = _get_mask_param(mask_param, p, specgram.shape[axis])
    if mask_param < 1:
        return specgram

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))
......
...@@ -1110,15 +1110,17 @@ class _AxisMasking(torch.nn.Module): ...@@ -1110,15 +1110,17 @@ class _AxisMasking(torch.nn.Module):
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
            This option is applicable only when the input tensor is 4D.
        p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
    """
    __constants__ = ["mask_param", "axis", "iid_masks", "p"]

    def __init__(self, mask_param: int, axis: int, iid_masks: bool, p: float = 1.0) -> None:
        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks
        self.p = p
    def forward(self, specgram: Tensor, mask_value: float = 0.0) -> Tensor:
        r"""
...@@ -1131,9 +1133,9 @@ class _AxisMasking(torch.nn.Module): ...@@ -1131,9 +1133,9 @@ class _AxisMasking(torch.nn.Module):
        """
        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1, p=self.p)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis, p=self.p)
class FrequencyMasking(_AxisMasking): class FrequencyMasking(_AxisMasking):
...@@ -1177,6 +1179,8 @@ class TimeMasking(_AxisMasking): ...@@ -1177,6 +1179,8 @@ class TimeMasking(_AxisMasking):
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
        p (float, optional): maximum proportion of time steps that can be masked.
            Must be within range [0.0, 1.0]. (Default: 1.0)

    Example
        >>> spectrogram = torchaudio.transforms.Spectrogram()
...@@ -1192,8 +1196,10 @@ class TimeMasking(_AxisMasking): ...@@ -1192,8 +1196,10 @@ class TimeMasking(_AxisMasking):
        :alt: The spectrogram masked along time axis
    """

    def __init__(self, time_mask_param: int, iid_masks: bool = False, p: float = 1.0) -> None:
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks, p=p)


class Vol(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment