"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "0bb9b9147a1cf6dec81b7130979d33f25c5720a9"
Commit 5af309d3 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Make lengths optional for speed functions and modules (#3072)

Summary:
Makes lengths input optional for `torchaudio.functional.speed`, `torchaudio.transforms.Speed`, and `torchaudio.transforms.SpeedPerturbation`.

Pull Request resolved: https://github.com/pytorch/audio/pull/3072

Reviewed By: nateanl, mthrok

Differential Revision: D43371406

Pulled By: hwangjeff

fbshipit-source-id: ecb38bcc2bfff5c5a396a37eff238b22238e795a
parent e663095c
...@@ -360,7 +360,7 @@ class Autograd(TestBaseMixin): ...@@ -360,7 +360,7 @@ class Autograd(TestBaseMixin):
T = 200 T = 200
waveform = torch.rand(*leading_dims, T, dtype=self.dtype, device=self.device, requires_grad=True) waveform = torch.rand(*leading_dims, T, dtype=self.dtype, device=self.device, requires_grad=True)
lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device) lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device)
self.assert_grad(F.speed, (waveform, lengths, 1000, 1.1), enable_all_grad=False) self.assert_grad(F.speed, (waveform, 1000, 1.1, lengths), enable_all_grad=False)
def test_preemphasis(self): def test_preemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True) waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
......
...@@ -459,12 +459,12 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -459,12 +459,12 @@ class TestFunctional(common_utils.TorchaudioTestCase):
unbatched_input = [torch.ones((int(length),)) * 1.0 for length in input_lengths] unbatched_input = [torch.ones((int(length),)) * 1.0 for length in input_lengths]
batched_input = torch.nn.utils.rnn.pad_sequence(unbatched_input, batch_first=True) batched_input = torch.nn.utils.rnn.pad_sequence(unbatched_input, batch_first=True)
output, output_lengths = F.speed(batched_input, input_lengths, orig_freq=orig_freq, factor=factor) output, output_lengths = F.speed(batched_input, orig_freq=orig_freq, factor=factor, lengths=input_lengths)
unbatched_output = [] unbatched_output = []
unbatched_output_lengths = [] unbatched_output_lengths = []
for idx in range(len(unbatched_input)): for idx in range(len(unbatched_input)):
w, l = F.speed(unbatched_input[idx], input_lengths[idx], orig_freq=orig_freq, factor=factor) w, l = F.speed(unbatched_input[idx], orig_freq=orig_freq, factor=factor, lengths=input_lengths[idx])
unbatched_output.append(w) unbatched_output.append(w)
unbatched_output_lengths.append(l) unbatched_output_lengths.append(l)
......
...@@ -1042,14 +1042,12 @@ class Functional(TestBaseMixin): ...@@ -1042,14 +1042,12 @@ class Functional(TestBaseMixin):
T = 1000 T = 1000
waveform = torch.rand(*leading_dims, T) waveform = torch.rand(*leading_dims, T)
lengths = torch.randint(1, 1000, leading_dims) lengths = torch.randint(1, 1000, leading_dims)
actual_waveform, actual_lengths = F.speed(waveform, lengths, orig_freq=1000, factor=1.0) actual_waveform, actual_lengths = F.speed(waveform, orig_freq=1000, factor=1.0, lengths=lengths)
self.assertEqual(waveform, actual_waveform) self.assertEqual(waveform, actual_waveform)
self.assertEqual(lengths, actual_lengths) self.assertEqual(lengths, actual_lengths)
@nested_params( @nested_params([0.8, 1.1, 1.2], [True, False])
[0.8, 1.1, 1.2], def test_speed_accuracy(self, factor, use_lengths):
)
def test_speed_accuracy(self, factor):
"""sinusoidal waveform is properly compressed by factor""" """sinusoidal waveform is properly compressed by factor"""
n_to_trim = 20 n_to_trim = 20
...@@ -1057,10 +1055,18 @@ class Functional(TestBaseMixin): ...@@ -1057,10 +1055,18 @@ class Functional(TestBaseMixin):
freq = 2 freq = 2
times = torch.arange(0, 5, 1.0 / sample_rate) times = torch.arange(0, 5, 1.0 / sample_rate)
waveform = torch.cos(2 * math.pi * freq * times).unsqueeze(0).to(self.device, self.dtype) waveform = torch.cos(2 * math.pi * freq * times).unsqueeze(0).to(self.device, self.dtype)
lengths = torch.tensor([waveform.size(1)])
output, output_lengths = F.speed(waveform, lengths, orig_freq=sample_rate, factor=factor) if use_lengths:
self.assertEqual(output.size(1), output_lengths[0]) lengths = torch.tensor([waveform.size(1)])
else:
lengths = None
output, output_lengths = F.speed(waveform, orig_freq=sample_rate, factor=factor, lengths=lengths)
if use_lengths:
self.assertEqual(output.size(1), output_lengths[0])
else:
self.assertEqual(None, output_lengths)
new_times = torch.arange(0, 5 / factor, 1.0 / sample_rate) new_times = torch.arange(0, 5 / factor, 1.0 / sample_rate)
expected_waveform = torch.cos(2 * math.pi * freq * factor * new_times).unsqueeze(0).to(self.device, self.dtype) expected_waveform = torch.cos(2 * math.pi * freq * factor * new_times).unsqueeze(0).to(self.device, self.dtype)
......
...@@ -785,12 +785,16 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -785,12 +785,16 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(F.add_noise, (waveform, noise, snr, lengths)) self._assert_consistency(F.add_noise, (waveform, noise, snr, lengths))
def test_speed(self): @common_utils.nested_params([True, False])
def test_speed(self, use_lengths):
leading_dims = (3, 2) leading_dims = (3, 2)
T = 200 T = 200
waveform = torch.rand(*leading_dims, T, dtype=self.dtype, device=self.device, requires_grad=True) waveform = torch.rand(*leading_dims, T, dtype=self.dtype, device=self.device, requires_grad=True)
lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device) if use_lengths:
self._assert_consistency(F.speed, (waveform, lengths, 1000, 1.1)) lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device)
else:
lengths = None
self._assert_consistency(F.speed, (waveform, 1000, 1.1, lengths))
def test_preemphasis(self): def test_preemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype) waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
......
...@@ -207,22 +207,32 @@ class Transforms(TestBaseMixin): ...@@ -207,22 +207,32 @@ class Transforms(TestBaseMixin):
ts_output = torch_script(convolve)(x, y) ts_output = torch_script(convolve)(x, y)
self.assertEqual(ts_output, output) self.assertEqual(ts_output, output)
def test_speed(self): @common_utils.nested_params([True, False])
def test_speed(self, use_lengths):
leading_dims = (3, 2) leading_dims = (3, 2)
time = 200 time = 200
waveform = torch.rand(*leading_dims, time, dtype=self.dtype, device=self.device, requires_grad=True) waveform = torch.rand(*leading_dims, time, dtype=self.dtype, device=self.device, requires_grad=True)
lengths = torch.randint(1, time, leading_dims, dtype=self.dtype, device=self.device)
if use_lengths:
lengths = torch.randint(1, time, leading_dims, dtype=self.dtype, device=self.device)
else:
lengths = None
speed = T.Speed(1000, 0.9).to(self.device, self.dtype) speed = T.Speed(1000, 0.9).to(self.device, self.dtype)
output = speed(waveform, lengths) output = speed(waveform, lengths)
ts_output = torch_script(speed)(waveform, lengths) ts_output = torch_script(speed)(waveform, lengths)
self.assertEqual(ts_output, output) self.assertEqual(ts_output, output)
def test_speed_perturbation(self): @common_utils.nested_params([True, False])
def test_speed_perturbation(self, use_lengths):
leading_dims = (3, 2) leading_dims = (3, 2)
time = 200 time = 200
waveform = torch.rand(*leading_dims, time, dtype=self.dtype, device=self.device, requires_grad=True) waveform = torch.rand(*leading_dims, time, dtype=self.dtype, device=self.device, requires_grad=True)
lengths = torch.randint(1, time, leading_dims, dtype=self.dtype, device=self.device)
if use_lengths:
lengths = torch.randint(1, time, leading_dims, dtype=self.dtype, device=self.device)
else:
lengths = None
speed = T.SpeedPerturbation(1000, [0.9]).to(self.device, self.dtype) speed = T.SpeedPerturbation(1000, [0.9]).to(self.device, self.dtype)
output = speed(waveform, lengths) output = speed(waveform, lengths)
......
...@@ -224,10 +224,8 @@ class TransformsTestBase(TestBaseMixin): ...@@ -224,10 +224,8 @@ class TransformsTestBase(TestBaseMixin):
self.assertEqual(waveform, actual_waveform) self.assertEqual(waveform, actual_waveform)
self.assertEqual(lengths, actual_lengths) self.assertEqual(lengths, actual_lengths)
@nested_params( @nested_params([0.8, 1.1, 1.2], [True, False])
[0.8, 1.1, 1.2], def test_speed_accuracy(self, factor, use_lengths):
)
def test_speed_accuracy(self, factor):
"""sinusoidal waveform is properly compressed by factor""" """sinusoidal waveform is properly compressed by factor"""
n_to_trim = 20 n_to_trim = 20
...@@ -235,11 +233,19 @@ class TransformsTestBase(TestBaseMixin): ...@@ -235,11 +233,19 @@ class TransformsTestBase(TestBaseMixin):
freq = 2 freq = 2
times = torch.arange(0, 5, 1.0 / sample_rate) times = torch.arange(0, 5, 1.0 / sample_rate)
waveform = torch.cos(2 * math.pi * freq * times).unsqueeze(0).to(self.device, self.dtype) waveform = torch.cos(2 * math.pi * freq * times).unsqueeze(0).to(self.device, self.dtype)
lengths = torch.tensor([waveform.size(1)])
if use_lengths:
lengths = torch.tensor([waveform.size(1)])
else:
lengths = None
speed = T.Speed(sample_rate, factor).to(self.device, self.dtype) speed = T.Speed(sample_rate, factor).to(self.device, self.dtype)
output, output_lengths = speed(waveform, lengths) output, output_lengths = speed(waveform, lengths)
self.assertEqual(output.size(1), output_lengths[0])
if use_lengths:
self.assertEqual(output.size(1), output_lengths[0])
else:
self.assertEqual(None, output_lengths)
new_times = torch.arange(0, 5 / factor, 1.0 / sample_rate) new_times = torch.arange(0, 5 / factor, 1.0 / sample_rate)
expected_waveform = torch.cos(2 * math.pi * freq * factor * new_times).unsqueeze(0).to(self.device, self.dtype) expected_waveform = torch.cos(2 * math.pi * freq * factor * new_times).unsqueeze(0).to(self.device, self.dtype)
......
...@@ -2492,8 +2492,8 @@ def add_noise( ...@@ -2492,8 +2492,8 @@ def add_noise(
def speed( def speed(
waveform: torch.Tensor, lengths: torch.Tensor, orig_freq: int, factor: float waveform: torch.Tensor, orig_freq: int, factor: float, lengths: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
r"""Adjusts waveform speed. r"""Adjusts waveform speed.
.. devices:: CPU CUDA .. devices:: CPU CUDA
...@@ -2502,17 +2502,19 @@ def speed( ...@@ -2502,17 +2502,19 @@ def speed(
Args: Args:
waveform (torch.Tensor): Input signals, with shape `(..., time)`. waveform (torch.Tensor): Input signals, with shape `(..., time)`.
lengths (torch.Tensor): Valid lengths of signals in ``waveform``, with shape `(...)`.
orig_freq (int): Original frequency of the signals in ``waveform``. orig_freq (int): Original frequency of the signals in ``waveform``.
factor (float): Factor by which to adjust speed of input. Values greater than 1.0 factor (float): Factor by which to adjust speed of input. Values greater than 1.0
compress ``waveform`` in time, whereas values less than 1.0 stretch ``waveform`` in time. compress ``waveform`` in time, whereas values less than 1.0 stretch ``waveform`` in time.
lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform``, with shape `(...)`.
If ``None``, all elements in ``waveform`` are treated as valid. (Default: ``None``)
Returns: Returns:
(torch.Tensor, torch.Tensor): (torch.Tensor, torch.Tensor or None):
torch.Tensor torch.Tensor
Speed-adjusted waveform, with shape `(..., new_time).` Speed-adjusted waveform, with shape `(..., new_time).`
torch.Tensor torch.Tensor or None
Valid lengths of signals in speed-adjusted waveform, with shape `(...)`. If ``lengths`` is not ``None``, valid lengths of signals in speed-adjusted waveform,
with shape `(...)`; otherwise, ``None``.
""" """
source_sample_rate = int(factor * orig_freq) source_sample_rate = int(factor * orig_freq)
...@@ -2522,9 +2524,12 @@ def speed( ...@@ -2522,9 +2524,12 @@ def speed(
source_sample_rate = source_sample_rate // gcd source_sample_rate = source_sample_rate // gcd
target_sample_rate = target_sample_rate // gcd target_sample_rate = target_sample_rate // gcd
return resample(waveform, source_sample_rate, target_sample_rate), torch.ceil( if lengths is None:
lengths * target_sample_rate / source_sample_rate out_lengths = None
).to(lengths.dtype) else:
out_lengths = torch.ceil(lengths * target_sample_rate / source_sample_rate).to(lengths.dtype)
return resample(waveform, source_sample_rate, target_sample_rate), out_lengths
def preemphasis(waveform, coeff: float = 0.97) -> torch.Tensor: def preemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
......
...@@ -1927,23 +1927,28 @@ class Speed(torch.nn.Module): ...@@ -1927,23 +1927,28 @@ class Speed(torch.nn.Module):
self.source_sample_rate, self.target_sample_rate = _source_target_sample_rate(orig_freq, factor) self.source_sample_rate, self.target_sample_rate = _source_target_sample_rate(orig_freq, factor)
self.resampler = Resample(orig_freq=self.source_sample_rate, new_freq=self.target_sample_rate) self.resampler = Resample(orig_freq=self.source_sample_rate, new_freq=self.target_sample_rate)
def forward(self, waveform, lengths) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, waveform, lengths: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
r""" r"""
Args: Args:
waveform (torch.Tensor): Input signals, with shape `(..., time)`. waveform (torch.Tensor): Input signals, with shape `(..., time)`.
lengths (torch.Tensor): Valid lengths of signals in ``waveform``, with shape `(...)`. lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform``, with shape `(...)`.
If ``None``, all elements in ``waveform`` are treated as valid. (Default: ``None``)
Returns: Returns:
(torch.Tensor, torch.Tensor): (torch.Tensor, torch.Tensor or None):
torch.Tensor torch.Tensor
Speed-adjusted waveform, with shape `(..., new_time).` Speed-adjusted waveform, with shape `(..., new_time).`
torch.Tensor torch.Tensor or None
Valid lengths of signals in speed-adjusted waveform, with shape `(...)`. If ``lengths`` is not ``None``, valid lengths of signals in speed-adjusted waveform,
with shape `(...)`; otherwise, ``None``.
""" """
return (
self.resampler(waveform), if lengths is None:
torch.ceil(lengths * self.target_sample_rate / self.source_sample_rate).to(lengths.dtype), out_lengths = None
) else:
out_lengths = torch.ceil(lengths * self.target_sample_rate / self.source_sample_rate).to(lengths.dtype)
return self.resampler(waveform), out_lengths
class SpeedPerturbation(torch.nn.Module): class SpeedPerturbation(torch.nn.Module):
...@@ -1973,18 +1978,22 @@ class SpeedPerturbation(torch.nn.Module): ...@@ -1973,18 +1978,22 @@ class SpeedPerturbation(torch.nn.Module):
self.speeders = torch.nn.ModuleList([Speed(orig_freq=orig_freq, factor=factor) for factor in factors]) self.speeders = torch.nn.ModuleList([Speed(orig_freq=orig_freq, factor=factor) for factor in factors])
def forward(self, waveform: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def forward(
self, waveform: torch.Tensor, lengths: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
r""" r"""
Args: Args:
waveform (torch.Tensor): Input signals, with shape `(..., time)`. waveform (torch.Tensor): Input signals, with shape `(..., time)`.
lengths (torch.Tensor): Valid lengths of signals in ``waveform``, with shape `(...)`. lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform``, with shape `(...)`.
If ``None``, all elements in ``waveform`` are treated as valid. (Default: ``None``)
Returns: Returns:
(torch.Tensor, torch.Tensor): (torch.Tensor, torch.Tensor or None):
torch.Tensor torch.Tensor
Speed-adjusted waveform, with shape `(..., new_time).` Speed-adjusted waveform, with shape `(..., new_time).`
torch.Tensor torch.Tensor or None
Valid lengths of signals in speed-adjusted waveform, with shape `(...)`. If ``lengths`` is not ``None``, valid lengths of signals in speed-adjusted waveform,
with shape `(...)`; otherwise, ``None``.
""" """
idx = int(torch.randint(len(self.speeders), ())) idx = int(torch.randint(len(self.speeders), ()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment