Unverified commit 7fd5fce4 authored by Caroline Chen, committed by GitHub

Ensure axis masking operations are not in-place (#1481)

It was reported in #1478 that spectrogram masking operations were done in-place and modified the original input tensors. This PR fixes this behavior and adds tests to ensure that the input tensor is not changed.
parent b540e5d1
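
For context, here is a minimal standalone sketch (not part of the PR) of the underlying pitfall: PyTorch's trailing-underscore methods such as `Tensor.masked_fill_` mutate their input, while the underscore-free variants return a new tensor.

```python
import torch

specgram = torch.randn(2, 201, 100)    # (channel, freq, time)
original = specgram.clone()
mask = torch.zeros(201, dtype=torch.bool)
mask[10:20] = True                     # mask a band of frequency bins

# Out-of-place: returns a new tensor; the caller's input is untouched.
masked = specgram.masked_fill(mask.unsqueeze(-1), 0.0)
assert torch.equal(specgram, original)     # input preserved
assert (masked[:, 10:20] == 0.0).all()     # the copy carries the mask

# In-place: silently overwrites the caller's tensor. This is the reported bug.
specgram.masked_fill_(mask.unsqueeze(-1), 0.0)
assert not torch.equal(specgram, original)
```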
@@ -227,6 +227,38 @@ class Functional(TestBaseMixin):
         assert mask_specgrams.size() == specgrams.size()
         assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
 
+    @parameterized.expand(
+        list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0., 30.], [1, 2]))
+    )
+    def test_mask_along_axis_preserve(self, shape, mask_param, mask_value, axis):
+        """mask_along_axis should not alter original input Tensor
+
+        Test is run 5 times to bound the probability of no masking occurring to 1e-10
+        See https://github.com/pytorch/audio/issues/1478
+        """
+        torch.random.manual_seed(42)
+        for _ in range(5):
+            specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
+            specgram_copy = specgram.clone()
+            F.mask_along_axis(specgram, mask_param, mask_value, axis)
+
+            self.assertEqual(specgram, specgram_copy)
+
+    @parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3])))
+    def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis):
+        """mask_along_axis_iid should not alter original input Tensor
+
+        Test is run 5 times to bound the probability of no masking occurring to 1e-10
+        See https://github.com/pytorch/audio/issues/1478
+        """
+        torch.random.manual_seed(42)
+        for _ in range(5):
+            specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
+            specgrams_copy = specgrams.clone()
+            F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
+
+            self.assertEqual(specgrams, specgrams_copy)
+
 
 class FunctionalComplex(TestBaseMixin):
     complex_dtype = None
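
The "5 times ... 1e-10" bound in the docstrings presumably comes from the mask width: each run draws floor(rand() * mask_param) masked columns, so no masking occurs only when rand() < 1 / mask_param. A quick standalone check of that arithmetic:

```python
import math

# P(no masking in one run) = 1 / mask_param; the five runs are independent draws.
mask_param, runs = 100, 5
assert math.isclose((1 / mask_param) ** runs, 1e-10)
```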
@@ -13,7 +13,7 @@ from torchaudio_unittest.common_utils import (
 class Functional(TempDirMixin, TestBaseMixin):
-    """Implements test for `functinoal` modul that are performed for different devices"""
+    """Implements test for `functional` module that are performed for different devices"""
     def _assert_consistency(self, func, tensor, shape_only=False):
         tensor = tensor.to(device=self.device, dtype=self.dtype)
@@ -21,8 +21,12 @@ class Functional(TempDirMixin, TestBaseMixin):
         torch.jit.script(func).save(path)
         ts_func = torch.jit.load(path)
 
+        torch.random.manual_seed(40)
         output = func(tensor)
+
+        torch.random.manual_seed(40)
         ts_output = ts_func(tensor)
+
         if shape_only:
             ts_output = ts_output.shape
             output = output.shape
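
The added re-seeding is needed because functions under test (such as mask_along_axis) draw from the global RNG: resetting the generator to the same state before the eager call and the TorchScript call makes both produce identical random masks, so their outputs can be compared exactly. A standalone sketch of the idea:

```python
import torch

def f(x):
    return x * torch.rand(1)   # output depends on the global RNG state

x = torch.ones(3)
torch.random.manual_seed(40)
a = f(x)
torch.random.manual_seed(40)   # rewind the generator before the second call
b = f(x)
assert torch.equal(a, b)       # both calls drew the same random number
```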
@@ -746,7 +746,7 @@ def mask_along_axis_iid(
 
     # Per batch example masking
     specgrams = specgrams.transpose(axis, -1)
-    specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value)
+    specgrams = specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
     specgrams = specgrams.transpose(axis, -1)
 
     return specgrams
@@ -772,24 +772,25 @@ def mask_along_axis(
     Returns:
         Tensor: Masked spectrogram of dimensions (channel, freq, time)
     """
+    if axis != 1 and axis != 2:
+        raise ValueError('Only Frequency and Time masking are supported')
 
     # pack batch
     shape = specgram.size()
     specgram = specgram.reshape([-1] + list(shape[-2:]))
 
     value = torch.rand(1) * mask_param
     min_value = torch.rand(1) * (specgram.size(axis) - value)
 
     mask_start = (min_value.long()).squeeze()
     mask_end = (min_value.long() + value.long()).squeeze()
+    mask = torch.arange(0, specgram.shape[axis], device=specgram.device, dtype=specgram.dtype)
+    mask = (mask >= mask_start) & (mask < mask_end)
+    if axis == 1:
+        mask = mask.unsqueeze(-1)
 
     assert mask_end - mask_start < mask_param
-    if axis == 1:
-        specgram[:, mask_start:mask_end] = mask_value
-    elif axis == 2:
-        specgram[:, :, mask_start:mask_end] = mask_value
-    else:
-        raise ValueError('Only Frequency and Time masking are supported')
+    specgram = specgram.masked_fill(mask, mask_value)
 
     # unpack batch
     specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
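
The rewrite in mask_along_axis swaps slice assignment (which wrote into a view of the caller's tensor) for a broadcast boolean mask and an out-of-place masked_fill. A standalone sketch (hypothetical shapes) checking that the mask selects exactly what the old slices overwrote:

```python
import torch

specgram = torch.randn(3, 201, 100)        # packed (batch, freq, time)
mask_value, mask_start, mask_end = 0.0, 5, 12
axis = 1                                   # frequency masking

# Old behavior, reproduced on a copy so the input survives.
expected = specgram.clone()
expected[:, mask_start:mask_end] = mask_value

# New behavior: boolean mask along the axis, broadcast over the rest.
mask = torch.arange(specgram.shape[axis])
mask = (mask >= mask_start) & (mask < mask_end)
mask = mask.unsqueeze(-1)                  # (freq,) -> (freq, 1) broadcasts over time
result = specgram.masked_fill(mask, mask_value)

assert torch.equal(result, expected)        # identical masking
assert not torch.equal(specgram, expected)  # original input untouched
```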