"vscode:/vscode.git/clone" did not exist on "5d00e2b4f794cf12cc89cfc0262bdede353a27a4"
Commit 1ec7ff73 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Add parameter p to TimeMasking (#2090)

Summary:
Adds parameter `p` to `TimeMasking` to allow enforcing an upper bound on the proportion of time steps that the transform can mask. This behavior is consistent with the specifications provided in the SpecAugment paper (https://arxiv.org/abs/1904.08779).

Pull Request resolved: https://github.com/pytorch/audio/pull/2090

Reviewed By: carolineechen

Differential Revision: D33344772

Pulled By: hwangjeff

fbshipit-source-id: 6ff65f5304e489fa1c23e15c3d96b9946229fdcf
parent 896ade04
......@@ -328,10 +328,16 @@ class Functional(TestBaseMixin):
close_to_limit = decibels < 6.0207
assert close_to_limit.any(), f"No values were close to the limit. Did it over-clamp?\n{decibels}"
@parameterized.expand(list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2])))
def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
@parameterized.expand(
list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2], [0.33, 1.0]))
)
def test_mask_along_axis(self, shape, mask_param, mask_value, axis, p):
torch.random.manual_seed(42)
specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
if p != 1.0:
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis, p=p)
else:
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
other_axis = 1 if axis == 2 else 2
......@@ -340,14 +346,20 @@ class Functional(TestBaseMixin):
num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
num_masked_columns = torch.div(num_masked_columns, mask_specgram.size(0), rounding_mode="floor")
if p != 1.0:
mask_param = min(mask_param, int(specgram.shape[axis] * p))
assert mask_specgram.size() == specgram.size()
assert num_masked_columns < mask_param
@parameterized.expand(list(itertools.product([100], [0.0, 30.0], [2, 3])))
def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
@parameterized.expand(list(itertools.product([100], [0.0, 30.0], [2, 3], [0.2, 1.0])))
def test_mask_along_axis_iid(self, mask_param, mask_value, axis, p):
torch.random.manual_seed(42)
specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
if p != 1.0:
mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis, p=p)
else:
mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
other_axis = 2 if axis == 3 else 3
......@@ -355,6 +367,9 @@ class Functional(TestBaseMixin):
masked_columns = (mask_specgrams == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1)
if p != 1.0:
mask_param = min(mask_param, int(specgrams.shape[axis] * p))
assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
......
......@@ -168,6 +168,16 @@ class AutogradTestMixin(TestBaseMixin):
deterministic_transform = _DeterministicWrapper(masking_transform(400, True))
self.assert_grad(deterministic_transform, [batch])
def test_time_masking_p(self):
    """Gradients flow through ``TimeMasking`` when the mask proportion cap ``p`` is set."""
    waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
    spectrogram = get_spectrogram(waveform, n_fft=400, power=1)
    # Wrap in _DeterministicWrapper so the random masking is reproducible for gradcheck.
    transform = _DeterministicWrapper(T.TimeMasking(400, iid_masks=False, p=0.1))
    self.assert_grad(transform, [spectrogram])
def test_spectral_centroid(self):
sample_rate = 8000
transform = T.SpectralCentroid(sample_rate=sample_rate)
......
......@@ -704,16 +704,32 @@ def phase_vocoder(complex_specgrams: Tensor, rate: float, phase_advance: Tensor)
return complex_specgrams_stretch
def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, axis: int) -> Tensor:
def _get_mask_param(mask_param: int, p: float, axis_length: int) -> int:
if p == 1.0:
return mask_param
else:
return min(mask_param, int(axis_length * p))
def mask_along_axis_iid(
specgrams: Tensor,
mask_param: int,
mask_value: float,
axis: int,
p: float = 1.0,
) -> Tensor:
r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
``v`` is sampled from ``uniform(0, max_v)`` and ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``, with
``max_v = mask_param`` when ``p = 1.0`` and ``max_v = min(mask_param, floor(specgrams.size(axis) * p))``
otherwise.
Args:
specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)`
mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
Returns:
Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)`
......@@ -722,6 +738,13 @@ def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, a
if axis not in [2, 3]:
raise ValueError("Only Frequency and Time masking are supported")
if not 0.0 <= p <= 1.0:
raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
mask_param = _get_mask_param(mask_param, p, specgrams.shape[axis])
if mask_param < 1:
return specgrams
device = specgrams.device
dtype = specgrams.dtype
......@@ -729,8 +752,8 @@ def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, a
min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)
# Create broadcastable mask
mask_start = min_value[..., None, None]
mask_end = (min_value + value)[..., None, None]
mask_start = min_value.long()[..., None, None]
mask_end = (min_value.long() + value.long())[..., None, None]
mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)
# Per batch example masking
......@@ -741,17 +764,25 @@ def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, a
return specgrams
def mask_along_axis(specgram: Tensor, mask_param: int, mask_value: float, axis: int) -> Tensor:
def mask_along_axis(
specgram: Tensor,
mask_param: int,
mask_value: float,
axis: int,
p: float = 1.0,
) -> Tensor:
r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
All examples will have the same mask interval.
``v`` is sampled from ``uniform(0, max_v)`` and ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``, with
``max_v = mask_param`` when ``p = 1.0`` and ``max_v = min(mask_param, floor(specgrams.size(axis) * p))``
otherwise. All examples will have the same mask interval.
Args:
specgram (Tensor): Real spectrogram `(channel, freq, time)`
mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
Returns:
Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
......@@ -759,6 +790,13 @@ def mask_along_axis(specgram: Tensor, mask_param: int, mask_value: float, axis:
if axis not in [1, 2]:
raise ValueError("Only Frequency and Time masking are supported")
if not 0.0 <= p <= 1.0:
raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
mask_param = _get_mask_param(mask_param, p, specgram.shape[axis])
if mask_param < 1:
return specgram
# pack batch
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
......
......@@ -1110,15 +1110,17 @@ class _AxisMasking(torch.nn.Module):
axis (int): What dimension the mask is applied on.
iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
This option is applicable only when the input tensor is 4D.
p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
"""
__constants__ = ["mask_param", "axis", "iid_masks"]
__constants__ = ["mask_param", "axis", "iid_masks", "p"]
def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:
def __init__(self, mask_param: int, axis: int, iid_masks: bool, p: float = 1.0) -> None:
    """Store masking configuration on the module.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
        p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
    """
    super(_AxisMasking, self).__init__()
    self.mask_param = mask_param
    self.axis = axis
    self.iid_masks = iid_masks
    self.p = p
def forward(self, specgram: Tensor, mask_value: float = 0.0) -> Tensor:
r"""
......@@ -1131,9 +1133,9 @@ class _AxisMasking(torch.nn.Module):
"""
# if iid_masks flag marked and specgram has a batch dimension
if self.iid_masks and specgram.dim() == 4:
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1, p=self.p)
else:
return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis, p=self.p)
class FrequencyMasking(_AxisMasking):
......@@ -1177,6 +1179,8 @@ class TimeMasking(_AxisMasking):
iid_masks (bool, optional): whether to apply different masks to each
example/channel in the batch. (Default: ``False``)
This option is applicable only when the input tensor is 4D.
p (float, optional): maximum proportion of time steps that can be masked.
Must be within range [0.0, 1.0]. (Default: 1.0)
Example
>>> spectrogram = torchaudio.transforms.Spectrogram()
......@@ -1192,8 +1196,10 @@ class TimeMasking(_AxisMasking):
:alt: The spectrogram masked along time axis
"""
def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
def __init__(self, time_mask_param: int, iid_masks: bool = False, p: float = 1.0) -> None:
    """Create a time-masking transform.

    Args:
        time_mask_param (int): maximum possible length of the mask.
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
        p (float, optional): maximum proportion of time steps that can be masked.
            Must be within range [0.0, 1.0]. (Default: 1.0)

    Raises:
        ValueError: if ``p`` falls outside ``[0.0, 1.0]``.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
    # Axis 2 selects the time dimension of a (channel, freq, time) spectrogram.
    super().__init__(time_mask_param, 2, iid_masks, p=p)
class Vol(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment