Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot
Browse files

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
import torch import torch
from torchaudio_unittest.prototype.emformer_test_impl import EmformerTestImpl
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.emformer_test_impl import EmformerTestImpl
class EmformerFloat32CPUTest(EmformerTestImpl, PytorchTestCase): class EmformerFloat32CPUTest(EmformerTestImpl, PytorchTestCase):
......
import torch import torch
from torchaudio_unittest.prototype.emformer_test_impl import EmformerTestImpl
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from torchaudio_unittest.prototype.emformer_test_impl import EmformerTestImpl
@skipIfNoCuda @skipIfNoCuda
......
import torch import torch
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
from torchaudio.prototype import Emformer from torchaudio.prototype import Emformer
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class EmformerTestImpl(TestBaseMixin): class EmformerTestImpl(TestBaseMixin):
...@@ -18,9 +18,7 @@ class EmformerTestImpl(TestBaseMixin): ...@@ -18,9 +18,7 @@ class EmformerTestImpl(TestBaseMixin):
return emformer return emformer
def _gen_inputs(self, input_dim, batch_size, num_frames, right_context_length): def _gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
input = torch.rand(batch_size, num_frames, input_dim).to( input = torch.rand(batch_size, num_frames, input_dim).to(device=self.device, dtype=self.dtype)
device=self.device, dtype=self.dtype
)
lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to( lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to(
device=self.device, dtype=self.dtype device=self.device, dtype=self.dtype
) )
...@@ -34,9 +32,7 @@ class EmformerTestImpl(TestBaseMixin): ...@@ -34,9 +32,7 @@ class EmformerTestImpl(TestBaseMixin):
right_context_length = 1 right_context_length = 1
emformer = self._gen_model(input_dim, right_context_length) emformer = self._gen_model(input_dim, right_context_length)
input, lengths = self._gen_inputs( input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
input_dim, batch_size, num_frames, right_context_length
)
scripted = torch_script(emformer) scripted = torch_script(emformer)
ref_out, ref_len = emformer(input, lengths) ref_out, ref_len = emformer(input, lengths)
...@@ -59,9 +55,7 @@ class EmformerTestImpl(TestBaseMixin): ...@@ -59,9 +55,7 @@ class EmformerTestImpl(TestBaseMixin):
for _ in range(3): for _ in range(3):
input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0) input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0)
ref_out, ref_len, ref_state = emformer.infer(input, lengths, ref_state) ref_out, ref_len, ref_state = emformer.infer(input, lengths, ref_state)
scripted_out, scripted_len, scripted_state = scripted.infer( scripted_out, scripted_len, scripted_state = scripted.infer(input, lengths, scripted_state)
input, lengths, scripted_state
)
self.assertEqual(ref_out, scripted_out) self.assertEqual(ref_out, scripted_out)
self.assertEqual(ref_len, scripted_len) self.assertEqual(ref_len, scripted_len)
self.assertEqual(ref_state, scripted_state) self.assertEqual(ref_state, scripted_state)
...@@ -74,15 +68,11 @@ class EmformerTestImpl(TestBaseMixin): ...@@ -74,15 +68,11 @@ class EmformerTestImpl(TestBaseMixin):
right_context_length = 9 right_context_length = 9
emformer = self._gen_model(input_dim, right_context_length) emformer = self._gen_model(input_dim, right_context_length)
input, lengths = self._gen_inputs( input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
input_dim, batch_size, num_frames, right_context_length
)
output, output_lengths = emformer(input, lengths) output, output_lengths = emformer(input, lengths)
self.assertEqual( self.assertEqual((batch_size, num_frames - right_context_length, input_dim), output.shape)
(batch_size, num_frames - right_context_length, input_dim), output.shape
)
self.assertEqual((batch_size,), output_lengths.shape) self.assertEqual((batch_size,), output_lengths.shape)
def test_output_shape_infer(self): def test_output_shape_infer(self):
...@@ -98,9 +88,7 @@ class EmformerTestImpl(TestBaseMixin): ...@@ -98,9 +88,7 @@ class EmformerTestImpl(TestBaseMixin):
for _ in range(3): for _ in range(3):
input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0) input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0)
output, output_lengths, state = emformer.infer(input, lengths, state) output, output_lengths, state = emformer.infer(input, lengths, state)
self.assertEqual( self.assertEqual((batch_size, num_frames - right_context_length, input_dim), output.shape)
(batch_size, num_frames - right_context_length, input_dim), output.shape
)
self.assertEqual((batch_size,), output_lengths.shape) self.assertEqual((batch_size,), output_lengths.shape)
def test_output_lengths_forward(self): def test_output_lengths_forward(self):
...@@ -111,9 +99,7 @@ class EmformerTestImpl(TestBaseMixin): ...@@ -111,9 +99,7 @@ class EmformerTestImpl(TestBaseMixin):
right_context_length = 2 right_context_length = 2
emformer = self._gen_model(input_dim, right_context_length) emformer = self._gen_model(input_dim, right_context_length)
input, lengths = self._gen_inputs( input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
input_dim, batch_size, num_frames, right_context_length
)
_, output_lengths = emformer(input, lengths) _, output_lengths = emformer(input, lengths)
self.assertEqual(lengths, output_lengths) self.assertEqual(lengths, output_lengths)
...@@ -127,6 +113,4 @@ class EmformerTestImpl(TestBaseMixin): ...@@ -127,6 +113,4 @@ class EmformerTestImpl(TestBaseMixin):
emformer = self._gen_model(input_dim, right_context_length).eval() emformer = self._gen_model(input_dim, right_context_length).eval()
input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0) input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, 0)
_, output_lengths, _ = emformer.infer(input, lengths) _, output_lengths, _ = emformer.infer(input, lengths)
self.assertEqual( self.assertEqual(torch.clamp(lengths - right_context_length, min=0), output_lengths)
torch.clamp(lengths - right_context_length, min=0), output_lengths
)
import torch import torch
from torchaudio_unittest.prototype.rnnt_test_impl import RNNTTestImpl
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.rnnt_test_impl import RNNTTestImpl
class RNNTFloat32CPUTest(RNNTTestImpl, PytorchTestCase): class RNNTFloat32CPUTest(RNNTTestImpl, PytorchTestCase):
......
import torch import torch
from torchaudio_unittest.prototype.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
class RNNTBeamSearchFloat32CPUTest(RNNTBeamSearchTestImpl, PytorchTestCase): class RNNTBeamSearchFloat32CPUTest(RNNTBeamSearchTestImpl, PytorchTestCase):
......
import torch import torch
from torchaudio_unittest.prototype.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from torchaudio_unittest.prototype.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
@skipIfNoCuda @skipIfNoCuda
......
import torch import torch
from torchaudio.prototype import RNNTBeamSearch, emformer_rnnt_model from torchaudio.prototype import RNNTBeamSearch, emformer_rnnt_model
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
...@@ -42,11 +41,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin): ...@@ -42,11 +41,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
} }
def _get_model(self): def _get_model(self):
return ( return emformer_rnnt_model(**self._get_model_config()).to(device=self.device, dtype=self.dtype).eval()
emformer_rnnt_model(**self._get_model_config())
.to(device=self.device, dtype=self.dtype)
.eval()
)
def test_torchscript_consistency_forward(self): def test_torchscript_consistency_forward(self):
r"""Verify that scripting RNNTBeamSearch does not change the behavior of method `forward`.""" r"""Verify that scripting RNNTBeamSearch does not change the behavior of method `forward`."""
...@@ -62,12 +57,10 @@ class RNNTBeamSearchTestImpl(TestBaseMixin): ...@@ -62,12 +57,10 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
blank_idx = num_symbols - 1 blank_idx = num_symbols - 1
beam_width = 5 beam_width = 5
input = torch.rand( input = torch.rand(batch_size, max_input_length + right_context_length, input_dim).to(
batch_size, max_input_length + right_context_length, input_dim device=self.device, dtype=self.dtype
).to(device=self.device, dtype=self.dtype)
lengths = torch.randint(1, max_input_length + 1, (batch_size,)).to(
device=self.device, dtype=torch.int32
) )
lengths = torch.randint(1, max_input_length + 1, (batch_size,)).to(device=self.device, dtype=torch.int32)
model = self._get_model() model = self._get_model()
beam_search = RNNTBeamSearch(model, blank_idx) beam_search = RNNTBeamSearch(model, blank_idx)
...@@ -91,12 +84,10 @@ class RNNTBeamSearchTestImpl(TestBaseMixin): ...@@ -91,12 +84,10 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
blank_idx = num_symbols - 1 blank_idx = num_symbols - 1
beam_width = 5 beam_width = 5
input = torch.rand( input = torch.rand(segment_length + right_context_length, input_dim).to(device=self.device, dtype=self.dtype)
segment_length + right_context_length, input_dim lengths = torch.randint(1, segment_length + right_context_length + 1, ()).to(
).to(device=self.device, dtype=self.dtype) device=self.device, dtype=torch.int32
lengths = torch.randint( )
1, segment_length + right_context_length + 1, ()
).to(device=self.device, dtype=torch.int32)
model = self._get_model() model = self._get_model()
...@@ -107,9 +98,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin): ...@@ -107,9 +98,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
scripted = torch_script(beam_search) scripted = torch_script(beam_search)
res = beam_search.infer(input, lengths, beam_width, state=state, hypothesis=hypo) res = beam_search.infer(input, lengths, beam_width, state=state, hypothesis=hypo)
scripted_res = scripted.infer( scripted_res = scripted.infer(input, lengths, beam_width, state=scripted_state, hypothesis=scripted_hypo)
input, lengths, beam_width, state=scripted_state, hypothesis=scripted_hypo
)
self.assertEqual(res, scripted_res) self.assertEqual(res, scripted_res)
......
import torch import torch
from torchaudio_unittest.prototype.rnnt_test_impl import RNNTTestImpl
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from torchaudio_unittest.prototype.rnnt_test_impl import RNNTTestImpl
@skipIfNoCuda @skipIfNoCuda
......
import torch import torch
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
from torchaudio.prototype.rnnt import emformer_rnnt_model from torchaudio.prototype.rnnt import emformer_rnnt_model
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class RNNTTestImpl(TestBaseMixin): class RNNTTestImpl(TestBaseMixin):
...@@ -45,11 +45,7 @@ class RNNTTestImpl(TestBaseMixin): ...@@ -45,11 +45,7 @@ class RNNTTestImpl(TestBaseMixin):
} }
def _get_model(self): def _get_model(self):
return ( return emformer_rnnt_model(**self._get_model_config()).to(device=self.device, dtype=self.dtype).eval()
emformer_rnnt_model(**self._get_model_config())
.to(device=self.device, dtype=self.dtype)
.eval()
)
def _get_transcriber_input(self): def _get_transcriber_input(self):
input_config = self._get_input_config() input_config = self._get_input_config()
...@@ -59,12 +55,10 @@ class RNNTTestImpl(TestBaseMixin): ...@@ -59,12 +55,10 @@ class RNNTTestImpl(TestBaseMixin):
right_context_length = input_config["right_context_length"] right_context_length = input_config["right_context_length"]
torch.random.manual_seed(31) torch.random.manual_seed(31)
input = torch.rand( input = torch.rand(batch_size, max_input_length + right_context_length, input_dim).to(
batch_size, max_input_length + right_context_length, input_dim device=self.device, dtype=self.dtype
).to(device=self.device, dtype=self.dtype)
lengths = torch.randint(1, max_input_length + 1, (batch_size,)).to(
device=self.device, dtype=torch.int32
) )
lengths = torch.randint(1, max_input_length + 1, (batch_size,)).to(device=self.device, dtype=torch.int32)
return input, lengths return input, lengths
def _get_transcriber_streaming_input(self): def _get_transcriber_streaming_input(self):
...@@ -75,12 +69,12 @@ class RNNTTestImpl(TestBaseMixin): ...@@ -75,12 +69,12 @@ class RNNTTestImpl(TestBaseMixin):
right_context_length = input_config["right_context_length"] right_context_length = input_config["right_context_length"]
torch.random.manual_seed(31) torch.random.manual_seed(31)
input = torch.rand( input = torch.rand(batch_size, segment_length + right_context_length, input_dim).to(
batch_size, segment_length + right_context_length, input_dim device=self.device, dtype=self.dtype
).to(device=self.device, dtype=self.dtype) )
lengths = torch.randint( lengths = torch.randint(1, segment_length + right_context_length + 1, (batch_size,)).to(
1, segment_length + right_context_length + 1, (batch_size,) device=self.device, dtype=torch.int32
).to(device=self.device, dtype=torch.int32) )
return input, lengths return input, lengths
def _get_predictor_input(self): def _get_predictor_input(self):
...@@ -90,12 +84,8 @@ class RNNTTestImpl(TestBaseMixin): ...@@ -90,12 +84,8 @@ class RNNTTestImpl(TestBaseMixin):
max_target_length = input_config["max_target_length"] max_target_length = input_config["max_target_length"]
torch.random.manual_seed(31) torch.random.manual_seed(31)
input = torch.randint(0, num_symbols, (batch_size, max_target_length)).to( input = torch.randint(0, num_symbols, (batch_size, max_target_length)).to(device=self.device, dtype=torch.int32)
device=self.device, dtype=torch.int32 lengths = torch.randint(1, max_target_length + 1, (batch_size,)).to(device=self.device, dtype=torch.int32)
)
lengths = torch.randint(1, max_target_length + 1, (batch_size,)).to(
device=self.device, dtype=torch.int32
)
return input, lengths return input, lengths
def _get_joiner_input(self): def _get_joiner_input(self):
...@@ -106,15 +96,13 @@ class RNNTTestImpl(TestBaseMixin): ...@@ -106,15 +96,13 @@ class RNNTTestImpl(TestBaseMixin):
input_dim = input_config["encoding_dim"] input_dim = input_config["encoding_dim"]
torch.random.manual_seed(31) torch.random.manual_seed(31)
utterance_encodings = torch.rand( utterance_encodings = torch.rand(batch_size, joiner_max_input_length, input_dim).to(
batch_size, joiner_max_input_length, input_dim
).to(device=self.device, dtype=self.dtype)
utterance_lengths = torch.randint(
0, joiner_max_input_length + 1, (batch_size,)
).to(device=self.device, dtype=torch.int32)
target_encodings = torch.rand(batch_size, max_target_length, input_dim).to(
device=self.device, dtype=self.dtype device=self.device, dtype=self.dtype
) )
utterance_lengths = torch.randint(0, joiner_max_input_length + 1, (batch_size,)).to(
device=self.device, dtype=torch.int32
)
target_encodings = torch.rand(batch_size, max_target_length, input_dim).to(device=self.device, dtype=self.dtype)
target_lengths = torch.randint(0, max_target_length + 1, (batch_size,)).to( target_lengths = torch.randint(0, max_target_length + 1, (batch_size,)).to(
device=self.device, dtype=torch.int32 device=self.device, dtype=torch.int32
) )
...@@ -167,9 +155,7 @@ class RNNTTestImpl(TestBaseMixin): ...@@ -167,9 +155,7 @@ class RNNTTestImpl(TestBaseMixin):
ref_state, scripted_state = None, None ref_state, scripted_state = None, None
for _ in range(2): for _ in range(2):
ref_out, ref_lengths, ref_state = rnnt.transcribe_streaming( ref_out, ref_lengths, ref_state = rnnt.transcribe_streaming(input, lengths, ref_state)
input, lengths, ref_state
)
( (
scripted_out, scripted_out,
scripted_lengths, scripted_lengths,
...@@ -190,9 +176,7 @@ class RNNTTestImpl(TestBaseMixin): ...@@ -190,9 +176,7 @@ class RNNTTestImpl(TestBaseMixin):
ref_state, scripted_state = None, None ref_state, scripted_state = None, None
for _ in range(2): for _ in range(2):
ref_out, ref_lengths, ref_state = rnnt.predict(input, lengths, ref_state) ref_out, ref_lengths, ref_state = rnnt.predict(input, lengths, ref_state)
scripted_out, scripted_lengths, scripted_state = scripted.predict( scripted_out, scripted_lengths, scripted_state = scripted.predict(input, lengths, scripted_state)
input, lengths, scripted_state
)
self.assertEqual(ref_out, scripted_out) self.assertEqual(ref_out, scripted_out)
self.assertEqual(ref_lengths, scripted_lengths) self.assertEqual(ref_lengths, scripted_lengths)
self.assertEqual(ref_state, scripted_state) self.assertEqual(ref_state, scripted_state)
...@@ -234,9 +218,7 @@ class RNNTTestImpl(TestBaseMixin): ...@@ -234,9 +218,7 @@ class RNNTTestImpl(TestBaseMixin):
state = None state = None
for _ in range(2): for _ in range(2):
out, out_lengths, target_lengths, state = rnnt( out, out_lengths, target_lengths, state = rnnt(inputs, input_lengths, targets, target_lengths, state)
inputs, input_lengths, targets, target_lengths, state
)
self.assertEqual( self.assertEqual(
(batch_size, joiner_max_input_length, max_target_length, num_symbols), (batch_size, joiner_max_input_length, max_target_length, num_symbols),
out.shape, out.shape,
......
import json import json
from parameterized import param from parameterized import param
from torchaudio_unittest.common_utils import get_asset_path from torchaudio_unittest.common_utils import get_asset_path
...@@ -10,15 +9,15 @@ def name_func(func, _, params): ...@@ -10,15 +9,15 @@ def name_func(func, _, params):
args = "_".join([str(arg) for arg in params.args]) args = "_".join([str(arg) for arg in params.args])
else: else:
args = "_".join([str(arg) for arg in params.args[0]]) args = "_".join([str(arg) for arg in params.args[0]])
return f'{func.__name__}_{args}' return f"{func.__name__}_{args}"
def load_params(*paths): def load_params(*paths):
params = [] params = []
with open(get_asset_path(*paths), 'r') as file: with open(get_asset_path(*paths), "r") as file:
for line in file: for line in file:
data = json.loads(line) data = json.loads(line)
for effect in data['effects']: for effect in data["effects"]:
for i, arg in enumerate(effect): for i, arg in enumerate(effect):
if arg.startswith("<ASSET_DIR>"): if arg.startswith("<ASSET_DIR>"):
effect[i] = arg.replace("<ASSET_DIR>", get_asset_path()) effect[i] = arg.replace("<ASSET_DIR>", get_asset_path())
......
import os import os
import sys
import platform import platform
from unittest import skipIf import sys
from typing import List, Tuple
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from typing import List, Tuple
from unittest import skipIf
import numpy as np import numpy as np
import torch import torch
import torchaudio import torchaudio
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
PytorchTestCase, PytorchTestCase,
...@@ -20,6 +19,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -20,6 +19,7 @@ from torchaudio_unittest.common_utils import (
class RandomPerturbationFile(torch.utils.data.Dataset): class RandomPerturbationFile(torch.utils.data.Dataset):
"""Given flist, apply random speed perturbation""" """Given flist, apply random speed perturbation"""
def __init__(self, flist: List[str], sample_rate: int): def __init__(self, flist: List[str], sample_rate: int):
super().__init__() super().__init__()
self.flist = flist self.flist = flist
...@@ -29,11 +29,11 @@ class RandomPerturbationFile(torch.utils.data.Dataset): ...@@ -29,11 +29,11 @@ class RandomPerturbationFile(torch.utils.data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
speed = self.rng.uniform(0.5, 2.0) speed = self.rng.uniform(0.5, 2.0)
effects = [ effects = [
['gain', '-n', '-10'], ["gain", "-n", "-10"],
['speed', f'{speed:.5f}'], # duration of data is 0.5 ~ 2.0 seconds. ["speed", f"{speed:.5f}"], # duration of data is 0.5 ~ 2.0 seconds.
['rate', f'{self.sample_rate}'], ["rate", f"{self.sample_rate}"],
['pad', '0', '1.5'], # add 1.5 seconds silence at the end ["pad", "0", "1.5"], # add 1.5 seconds silence at the end
['trim', '0', '2'], # get the first 2 seconds ["trim", "0", "2"], # get the first 2 seconds
] ]
data, _ = torchaudio.sox_effects.apply_effects_file(self.flist[index], effects) data, _ = torchaudio.sox_effects.apply_effects_file(self.flist[index], effects)
return data return data
...@@ -44,6 +44,7 @@ class RandomPerturbationFile(torch.utils.data.Dataset): ...@@ -44,6 +44,7 @@ class RandomPerturbationFile(torch.utils.data.Dataset):
class RandomPerturbationTensor(torch.utils.data.Dataset): class RandomPerturbationTensor(torch.utils.data.Dataset):
"""Apply speed purturbation to (synthetic) Tensor data""" """Apply speed purturbation to (synthetic) Tensor data"""
def __init__(self, signals: List[Tuple[torch.Tensor, int]], sample_rate: int): def __init__(self, signals: List[Tuple[torch.Tensor, int]], sample_rate: int):
super().__init__() super().__init__()
self.signals = signals self.signals = signals
...@@ -53,11 +54,11 @@ class RandomPerturbationTensor(torch.utils.data.Dataset): ...@@ -53,11 +54,11 @@ class RandomPerturbationTensor(torch.utils.data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
speed = self.rng.uniform(0.5, 2.0) speed = self.rng.uniform(0.5, 2.0)
effects = [ effects = [
['gain', '-n', '-10'], ["gain", "-n", "-10"],
['speed', f'{speed:.5f}'], # duration of data is 0.5 ~ 2.0 seconds. ["speed", f"{speed:.5f}"], # duration of data is 0.5 ~ 2.0 seconds.
['rate', f'{self.sample_rate}'], ["rate", f"{self.sample_rate}"],
['pad', '0', '1.5'], # add 1.5 seconds silence at the end ["pad", "0", "1.5"], # add 1.5 seconds silence at the end
['trim', '0', '2'], # get the first 2 seconds ["trim", "0", "2"], # get the first 2 seconds
] ]
tensor, sample_rate = self.signals[index] tensor, sample_rate = self.signals[index]
data, _ = torchaudio.sox_effects.apply_effects_tensor(tensor, sample_rate, effects) data, _ = torchaudio.sox_effects.apply_effects_tensor(tensor, sample_rate, effects)
...@@ -74,11 +75,9 @@ def init_random_seed(worker_id): ...@@ -74,11 +75,9 @@ def init_random_seed(worker_id):
@skipIfNoSox @skipIfNoSox
@skipIf( @skipIf(
platform.system() == 'Darwin' and platform.system() == "Darwin" and sys.version_info.major == 3 and sys.version_info.minor in [6, 7],
sys.version_info.major == 3 and "This test is known to get stuck for macOS with Python < 3.8. "
sys.version_info.minor in [6, 7], "See https://github.com/pytorch/pytorch/issues/46409",
'This test is known to get stuck for macOS with Python < 3.8. '
'See https://github.com/pytorch/pytorch/issues/46409'
) )
class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase): class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
"""Test `apply_effects_file` in multi-process dataloader setting""" """Test `apply_effects_file` in multi-process dataloader setting"""
...@@ -87,9 +86,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase): ...@@ -87,9 +86,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
flist = [] flist = []
for i in range(num_samples): for i in range(num_samples):
sample_rate = np.random.choice([8000, 16000, 44100]) sample_rate = np.random.choice([8000, 16000, 44100])
dtype = np.random.choice(['float32', 'int32', 'int16', 'uint8']) dtype = np.random.choice(["float32", "int32", "int16", "uint8"])
data = get_whitenoise(n_channels=2, sample_rate=sample_rate, duration=1, dtype=dtype) data = get_whitenoise(n_channels=2, sample_rate=sample_rate, duration=1, dtype=dtype)
path = self.get_temp_path(f'{i:03d}_{dtype}_{sample_rate}.wav') path = self.get_temp_path(f"{i:03d}_{dtype}_{sample_rate}.wav")
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
flist.append(path) flist.append(path)
return flist return flist
...@@ -99,7 +98,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase): ...@@ -99,7 +98,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
flist = self._generate_dataset() flist = self._generate_dataset()
dataset = RandomPerturbationFile(flist, sample_rate) dataset = RandomPerturbationFile(flist, sample_rate)
loader = torch.utils.data.DataLoader( loader = torch.utils.data.DataLoader(
dataset, batch_size=32, num_workers=16, dataset,
batch_size=32,
num_workers=16,
worker_init_fn=init_random_seed, worker_init_fn=init_random_seed,
) )
for batch in loader: for batch in loader:
...@@ -109,8 +110,7 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase): ...@@ -109,8 +110,7 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
signals = [] signals = []
for _ in range(num_samples): for _ in range(num_samples):
sample_rate = np.random.choice([8000, 16000, 44100]) sample_rate = np.random.choice([8000, 16000, 44100])
data = get_whitenoise( data = get_whitenoise(n_channels=2, sample_rate=sample_rate, duration=1, dtype="float32")
n_channels=2, sample_rate=sample_rate, duration=1, dtype='float32')
signals.append((data, sample_rate)) signals.append((data, sample_rate))
return signals return signals
...@@ -119,7 +119,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase): ...@@ -119,7 +119,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
signals = self._generate_signals() signals = self._generate_signals()
dataset = RandomPerturbationTensor(signals, sample_rate) dataset = RandomPerturbationTensor(signals, sample_rate)
loader = torch.utils.data.DataLoader( loader = torch.utils.data.DataLoader(
dataset, batch_size=32, num_workers=16, dataset,
batch_size=32,
num_workers=16,
worker_init_fn=init_random_seed, worker_init_fn=init_random_seed,
) )
for batch in loader: for batch in loader:
...@@ -129,8 +131,8 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase): ...@@ -129,8 +131,8 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
def speed(path): def speed(path):
wav, sample_rate = torchaudio.backend.sox_io_backend.load(path) wav, sample_rate = torchaudio.backend.sox_io_backend.load(path)
effects = [ effects = [
['speed', '1.03756523535464655'], ["speed", "1.03756523535464655"],
['rate', f'{sample_rate}'], ["rate", f"{sample_rate}"],
] ]
return torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)[0] return torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)[0]
...@@ -143,12 +145,12 @@ class TestProcessPoolExecutor(TempDirMixin, PytorchTestCase): ...@@ -143,12 +145,12 @@ class TestProcessPoolExecutor(TempDirMixin, PytorchTestCase):
sample_rate = 16000 sample_rate = 16000
self.flist = [] self.flist = []
for i in range(10): for i in range(10):
path = self.get_temp_path(f'{i}.wav') path = self.get_temp_path(f"{i}.wav")
data = get_whitenoise(n_channels=1, sample_rate=sample_rate, duration=1, dtype='float') data = get_whitenoise(n_channels=1, sample_rate=sample_rate, duration=1, dtype="float")
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
self.flist.append(path) self.flist.append(path)
@skipIf(os.environ.get("CI") == 'true', "This test now hangs in CI") @skipIf(os.environ.get("CI") == "true", "This test now hangs in CI")
def test_executor(self): def test_executor(self):
"""Test that apply_effects_tensor with speed + rate does not crush """Test that apply_effects_tensor with speed + rate does not crush
......
from torchaudio import sox_effects
from parameterized import parameterized from parameterized import parameterized
from torchaudio import sox_effects
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -9,6 +8,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -9,6 +8,7 @@ from torchaudio_unittest.common_utils import (
get_sinusoid, get_sinusoid,
save_wav, save_wav,
) )
from .common import ( from .common import (
load_params, load_params,
) )
...@@ -24,18 +24,17 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase): ...@@ -24,18 +24,17 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
This test suite should be able to run without any additional tools (such as sox command), This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified. however without such tools, the correctness of each function cannot be verified.
""" """
@parameterized.expand( @parameterized.expand(
load_params("sox_effect_test_args.jsonl"), load_params("sox_effect_test_args.jsonl"),
name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
) )
def test_apply_effects_tensor(self, args): def test_apply_effects_tensor(self, args):
"""`apply_effects_tensor` should not crash""" """`apply_effects_tensor` should not crash"""
effects = args['effects'] effects = args["effects"]
num_channels = args.get("num_channels", 2) num_channels = args.get("num_channels", 2)
input_sr = args.get("input_sample_rate", 8000) input_sr = args.get("input_sample_rate", 8000)
original = get_sinusoid( original = get_sinusoid(frequency=800, sample_rate=input_sr, n_channels=num_channels, dtype="float32")
frequency=800, sample_rate=input_sr,
n_channels=num_channels, dtype='float32')
_found, _sr = sox_effects.apply_effects_tensor(original, input_sr, effects) _found, _sr = sox_effects.apply_effects_tensor(original, input_sr, effects)
@parameterized.expand( @parameterized.expand(
...@@ -44,18 +43,19 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase): ...@@ -44,18 +43,19 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
) )
def test_apply_effects_file(self, args): def test_apply_effects_file(self, args):
"""`apply_effects_file` should return identical data as sox command""" """`apply_effects_file` should return identical data as sox command"""
dtype = 'int32' dtype = "int32"
channels_first = True channels_first = True
effects = args['effects'] effects = args["effects"]
num_channels = args.get("num_channels", 2) num_channels = args.get("num_channels", 2)
input_sr = args.get("input_sample_rate", 8000) input_sr = args.get("input_sample_rate", 8000)
input_path = self.get_temp_path('input.wav') input_path = self.get_temp_path("input.wav")
data = get_wav_data(dtype, num_channels, channels_first=channels_first) data = get_wav_data(dtype, num_channels, channels_first=channels_first)
save_wav(input_path, data, input_sr, channels_first=channels_first) save_wav(input_path, data, input_sr, channels_first=channels_first)
_found, _sr = sox_effects.apply_effects_file( _found, _sr = sox_effects.apply_effects_file(
input_path, effects, normalize=False, channels_first=channels_first) input_path, effects, normalize=False, channels_first=channels_first
)
@parameterized.expand( @parameterized.expand(
load_params("sox_effect_test_args.jsonl"), load_params("sox_effect_test_args.jsonl"),
...@@ -63,16 +63,17 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase): ...@@ -63,16 +63,17 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
) )
def test_apply_effects_fileobj(self, args):
    """`apply_effects_file` should not crash when given a file object"""
    dtype = "int32"
    channels_first = True
    effects = args["effects"]
    num_channels = args.get("num_channels", 2)
    input_sr = args.get("input_sample_rate", 8000)

    # Materialize a wav file, then hand it over as an open binary file object.
    input_path = self.get_temp_path("input.wav")
    data = get_wav_data(dtype, num_channels, channels_first=channels_first)
    save_wav(input_path, data, input_sr, channels_first=channels_first)

    with open(input_path, "rb") as fileobj:
        # Smoke test: the outputs are intentionally unused; only a clean return matters.
        _found, _sr = sox_effects.apply_effects_file(
            fileobj, effects, normalize=False, channels_first=channels_first
        )
import io import io
import itertools import itertools
from pathlib import Path
import tarfile import tarfile
from pathlib import Path
from parameterized import parameterized from parameterized import parameterized
from torchaudio import sox_effects from torchaudio import sox_effects
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
HttpServerMixin, HttpServerMixin,
...@@ -21,6 +20,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -21,6 +20,7 @@ from torchaudio_unittest.common_utils import (
load_wav, load_wav,
sox_utils, sox_utils,
) )
from .common import ( from .common import (
load_params, load_params,
name_func, name_func,
...@@ -42,18 +42,16 @@ class TestSoxEffects(PytorchTestCase): ...@@ -42,18 +42,16 @@ class TestSoxEffects(PytorchTestCase):
@skipIfNoSox @skipIfNoSox
class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase): class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase):
"""Test suite for `apply_effects_tensor` function""" """Test suite for `apply_effects_tensor` function"""
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'], @parameterized.expand(
[8000, 16000], list(itertools.product(["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2, 4, 8], [True, False])),
[1, 2, 4, 8], name_func=name_func,
[True, False] )
)), name_func=name_func)
def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first): def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first):
"""`apply_effects_tensor` without effects should return identical data as input""" """`apply_effects_tensor` without effects should return identical data as input"""
original = get_wav_data(dtype, num_channels, channels_first=channels_first) original = get_wav_data(dtype, num_channels, channels_first=channels_first)
expected = original.clone() expected = original.clone()
found, output_sample_rate = sox_effects.apply_effects_tensor( found, output_sample_rate = sox_effects.apply_effects_tensor(expected, sample_rate, [], channels_first)
expected, sample_rate, [], channels_first)
assert output_sample_rate == sample_rate assert output_sample_rate == sample_rate
# SoxEffect should not alter the input Tensor object # SoxEffect should not alter the input Tensor object
...@@ -69,20 +67,17 @@ class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase): ...@@ -69,20 +67,17 @@ class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase):
) )
def test_apply_effects(self, args): def test_apply_effects(self, args):
"""`apply_effects_tensor` should return identical data as sox command""" """`apply_effects_tensor` should return identical data as sox command"""
effects = args['effects'] effects = args["effects"]
num_channels = args.get("num_channels", 2) num_channels = args.get("num_channels", 2)
input_sr = args.get("input_sample_rate", 8000) input_sr = args.get("input_sample_rate", 8000)
output_sr = args.get("output_sample_rate") output_sr = args.get("output_sample_rate")
input_path = self.get_temp_path('input.wav') input_path = self.get_temp_path("input.wav")
reference_path = self.get_temp_path('reference.wav') reference_path = self.get_temp_path("reference.wav")
original = get_sinusoid( original = get_sinusoid(frequency=800, sample_rate=input_sr, n_channels=num_channels, dtype="float32")
frequency=800, sample_rate=input_sr,
n_channels=num_channels, dtype='float32')
save_wav(input_path, original, input_sr) save_wav(input_path, original, input_sr)
sox_utils.run_sox_effect( sox_utils.run_sox_effect(input_path, reference_path, effects, output_sample_rate=output_sr)
input_path, reference_path, effects, output_sample_rate=output_sr)
expected, expected_sr = load_wav(reference_path) expected, expected_sr = load_wav(reference_path)
found, sr = sox_effects.apply_effects_tensor(original, input_sr, effects) found, sr = sox_effects.apply_effects_tensor(original, input_sr, effects)
...@@ -94,20 +89,27 @@ class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase): ...@@ -94,20 +89,27 @@ class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase):
@skipIfNoSox @skipIfNoSox
class TestSoxEffectsFile(TempDirMixin, PytorchTestCase): class TestSoxEffectsFile(TempDirMixin, PytorchTestCase):
"""Test suite for `apply_effects_file` function""" """Test suite for `apply_effects_file` function"""
@parameterized.expand(
    list(
        itertools.product(
            ("float32", "int32", "int16", "uint8"),  # dtype
            (8000, 16000),  # sample_rate
            (1, 2, 4, 8),  # num_channels
            (False, True),  # channels_first
        )
    ),
    name_func=name_func,
)
def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first):
    """An empty effect chain passed to `apply_effects_file` should give back the saved data"""
    wav_path = self.get_temp_path("input.wav")
    expected = get_wav_data(dtype, num_channels, channels_first=channels_first)
    save_wav(wav_path, expected, sample_rate, channels_first=channels_first)

    result = sox_effects.apply_effects_file(wav_path, [], normalize=False, channels_first=channels_first)
    found, output_sample_rate = result

    assert output_sample_rate == sample_rate
    self.assertEqual(expected, found)
...@@ -118,46 +120,44 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase): ...@@ -118,46 +120,44 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase):
) )
def test_apply_effects_str(self, args):
    """`apply_effects_file` given a ``str`` path should match the output of the sox command"""
    dtype, channels_first = "int32", True
    effects = args["effects"]
    num_channels = args.get("num_channels", 2)
    input_sr = args.get("input_sample_rate", 8000)
    output_sr = args.get("output_sample_rate")

    input_path = self.get_temp_path("input.wav")
    reference_path = self.get_temp_path("reference.wav")

    # Save the input once; both sox and torchaudio read the same file.
    waveform = get_wav_data(dtype, num_channels, channels_first=channels_first)
    save_wav(input_path, waveform, input_sr, channels_first=channels_first)

    # Reference result produced by the sox command.
    sox_utils.run_sox_effect(input_path, reference_path, effects, output_sample_rate=output_sr)
    expected, expected_sr = load_wav(reference_path)

    found, sr = sox_effects.apply_effects_file(
        input_path,
        effects,
        normalize=False,
        channels_first=channels_first,
    )

    assert sr == expected_sr
    self.assertEqual(found, expected)
def test_apply_effects_path(self):
    """`apply_effects_file` given a `pathlib.Path` should match the output of the sox command"""
    dtype, channels_first = "int32", True
    effects = [["hilbert"]]
    num_channels = 2
    input_sr = output_sr = 8000

    input_path = self.get_temp_path("input.wav")
    reference_path = self.get_temp_path("reference.wav")

    # Save the input once; both sox and torchaudio read the same file.
    waveform = get_wav_data(dtype, num_channels, channels_first=channels_first)
    save_wav(input_path, waveform, input_sr, channels_first=channels_first)

    # Reference result produced by the sox command.
    sox_utils.run_sox_effect(input_path, reference_path, effects, output_sample_rate=output_sr)
    expected, expected_sr = load_wav(reference_path)

    # The point of this test: the input location is wrapped in a Path object.
    path_obj = Path(input_path)
    found, sr = sox_effects.apply_effects_file(path_obj, effects, normalize=False, channels_first=channels_first)

    assert sr == expected_sr
    self.assertEqual(found, expected)
...@@ -166,91 +166,108 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase): ...@@ -166,91 +166,108 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase):
def _name_from_args(func, _, param):
    """Build a test name from the test function's name and its positional parameters."""
    suffix = "_".join(str(arg) for arg in param.args)
    return f"{func.__name__}_{suffix}"


@skipIfNoSox
class TestFileFormats(TempDirMixin, PytorchTestCase):
    """`apply_effects_file` gives the same result as sox on various file formats"""

    @parameterized.expand(
        list(itertools.product(("float32", "int32", "int16", "uint8"), (8000, 16000), (1, 2))),
        name_func=_name_from_args,
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """`apply_effects_file` works on various wav format"""
        channels_first = True
        effects = [["band", "300", "10"]]

        input_path = self.get_temp_path("input.wav")
        reference_path = self.get_temp_path("reference.wav")
        waveform = get_wav_data(dtype, num_channels, channels_first=channels_first)
        save_wav(input_path, waveform, sample_rate, channels_first=channels_first)

        # Reference result produced by the sox command.
        sox_utils.run_sox_effect(input_path, reference_path, effects)
        expected, expected_sr = load_wav(reference_path)

        found, sr = sox_effects.apply_effects_file(
            input_path,
            effects,
            normalize=False,
            channels_first=channels_first,
        )

        assert sr == expected_sr
        self.assertEqual(found, expected)

    @parameterized.expand(
        list(itertools.product((8000, 16000), (1, 2))),
        name_func=_name_from_args,
    )
    def test_mp3(self, sample_rate, num_channels):
        """`apply_effects_file` works on various mp3 format"""
        channels_first = True
        effects = [["band", "300", "10"]]

        input_path = self.get_temp_path("input.mp3")
        reference_path = self.get_temp_path("reference.wav")
        sox_utils.gen_audio_file(input_path, sample_rate, num_channels)

        # Reference result produced by the sox command.
        sox_utils.run_sox_effect(input_path, reference_path, effects)
        expected, expected_sr = load_wav(reference_path)

        found, sr = sox_effects.apply_effects_file(input_path, effects, channels_first=channels_first)
        # The result is also written to the temp directory.
        result_path = self.get_temp_path("result.wav")
        save_wav(result_path, found, sr, channels_first=channels_first)

        assert sr == expected_sr
        # Unlike the other formats, mp3 is compared with explicit tolerances.
        self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)

    @parameterized.expand(
        list(itertools.product((8000, 16000), (1, 2))),
        name_func=_name_from_args,
    )
    def test_flac(self, sample_rate, num_channels):
        """`apply_effects_file` works on various flac format"""
        channels_first = True
        effects = [["band", "300", "10"]]

        input_path = self.get_temp_path("input.flac")
        reference_path = self.get_temp_path("reference.wav")
        sox_utils.gen_audio_file(input_path, sample_rate, num_channels)

        # Reference result produced by the sox command.
        sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)

        found, sr = sox_effects.apply_effects_file(input_path, effects, channels_first=channels_first)
        # The result is also written to the temp directory.
        result_path = self.get_temp_path("result.wav")
        save_wav(result_path, found, sr, channels_first=channels_first)

        assert sr == expected_sr
        self.assertEqual(found, expected)

    @parameterized.expand(
        list(itertools.product((8000, 16000), (1, 2))),
        name_func=_name_from_args,
    )
    def test_vorbis(self, sample_rate, num_channels):
        """`apply_effects_file` works on various vorbis format"""
        channels_first = True
        effects = [["band", "300", "10"]]

        input_path = self.get_temp_path("input.vorbis")
        reference_path = self.get_temp_path("reference.wav")
        sox_utils.gen_audio_file(input_path, sample_rate, num_channels)

        # Reference result produced by the sox command.
        sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)

        found, sr = sox_effects.apply_effects_file(input_path, effects, channels_first=channels_first)
        # The result is also written to the temp directory.
        result_path = self.get_temp_path("result.wav")
        save_wav(result_path, found, sr, channels_first=channels_first)

        assert sr == expected_sr
        self.assertEqual(found, expected)
...@@ -268,156 +285,152 @@ class TestApplyEffectFileWithoutExtension(PytorchTestCase): ...@@ -268,156 +285,152 @@ class TestApplyEffectFileWithoutExtension(PytorchTestCase):
The file was generated with the following command The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
""" """
effects = [['band', '300', '10']] effects = [["band", "300", "10"]]
path = get_asset_path("mp3_without_ext") path = get_asset_path("mp3_without_ext")
_, sr = sox_effects.apply_effects_file(path, effects, format="mp3") _, sr = sox_effects.apply_effects_file(path, effects, format="mp3")
assert sr == 16000 assert sr == 16000
# (extension, compression) combinations exercised by every test in TestFileObject.
_FILE_OBJECT_FORMATS = [
    ("wav", None),
    ("mp3", 128),
    ("mp3", 320),
    ("flac", 0),
    ("flac", 5),
    ("flac", 8),
    ("vorbis", -1),
    ("vorbis", 10),
    ("amb", None),
]


@skipIfNoExec("sox")
@skipIfNoSox
class TestFileObject(TempDirMixin, PytorchTestCase):
    """`apply_effects_file` accepts file-like objects in place of a path"""

    @parameterized.expand(_FILE_OBJECT_FORMATS)
    def test_fileobj(self, ext, compression):
        """Applying effects via file object works"""
        sample_rate = 16000
        channels_first = True
        effects = [["band", "300", "10"]]
        # The format is passed explicitly only for mp3; every other extension passes None.
        format_ = ext if ext == "mp3" else None

        input_path = self.get_temp_path(f"input.{ext}")
        reference_path = self.get_temp_path("reference.wav")

        sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
        # Reference result produced by the sox command.
        sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)

        with open(input_path, "rb") as fileobj:
            found, sr = sox_effects.apply_effects_file(
                fileobj, effects, channels_first=channels_first, format=format_
            )
        # The result is also written to the temp directory.
        result_path = self.get_temp_path("result.wav")
        save_wav(result_path, found, sr, channels_first=channels_first)

        assert sr == expected_sr
        self.assertEqual(found, expected)

    @parameterized.expand(_FILE_OBJECT_FORMATS)
    def test_bytesio(self, ext, compression):
        """Applying effects via BytesIO object works"""
        sample_rate = 16000
        channels_first = True
        effects = [["band", "300", "10"]]
        # The format is passed explicitly only for mp3; every other extension passes None.
        format_ = ext if ext == "mp3" else None

        input_path = self.get_temp_path(f"input.{ext}")
        reference_path = self.get_temp_path("reference.wav")

        sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
        # Reference result produced by the sox command.
        sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)

        # Load the whole file into memory so the effect chain reads from a BytesIO.
        with open(input_path, "rb") as file_:
            fileobj = io.BytesIO(file_.read())
        found, sr = sox_effects.apply_effects_file(
            fileobj, effects, channels_first=channels_first, format=format_
        )
        # The result is also written to the temp directory.
        result_path = self.get_temp_path("result.wav")
        save_wav(result_path, found, sr, channels_first=channels_first)

        assert sr == expected_sr
        self.assertEqual(found, expected)

    @parameterized.expand(_FILE_OBJECT_FORMATS)
    def test_tarfile(self, ext, compression):
        """Applying effects to compressed audio via file-like file works"""
        sample_rate = 16000
        channels_first = True
        effects = [["band", "300", "10"]]
        # The format is passed explicitly only for mp3; every other extension passes None.
        format_ = ext if ext == "mp3" else None

        audio_file = f"input.{ext}"
        input_path = self.get_temp_path(audio_file)
        reference_path = self.get_temp_path("reference.wav")
        archive_path = self.get_temp_path("archive.tar.gz")

        sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
        # Reference result produced by the sox command.
        sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)

        # Pack the audio into an archive, then read it back through the archive member.
        with tarfile.TarFile(archive_path, "w") as tarobj:
            tarobj.add(input_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, "r") as tarobj:
            fileobj = tarobj.extractfile(audio_file)
            found, sr = sox_effects.apply_effects_file(
                fileobj, effects, channels_first=channels_first, format=format_
            )
        # The result is also written to the temp directory.
        result_path = self.get_temp_path("result.wav")
        save_wav(result_path, found, sr, channels_first=channels_first)

        assert sr == expected_sr
        self.assertEqual(found, expected)
@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
    """`apply_effects_file` accepts the raw stream of an HTTP response"""

    @parameterized.expand(
        [
            ("wav", None),
            ("mp3", 128),
            ("mp3", 320),
            ("flac", 0),
            ("flac", 5),
            ("flac", 8),
            ("vorbis", -1),
            ("vorbis", 10),
            ("amb", None),
        ]
    )
    def test_requests(self, ext, compression):
        """Applying effects via the raw stream of a `requests` response works"""
        sample_rate = 16000
        channels_first = True
        effects = [["band", "300", "10"]]
        # The format is passed explicitly only for mp3; every other extension passes None.
        format_ = ext if ext == "mp3" else None

        audio_file = f"input.{ext}"
        input_path = self.get_temp_path(audio_file)
        reference_path = self.get_temp_path("reference.wav")

        sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
        # Reference result produced by the sox command.
        sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)

        url = self.get_url(audio_file)
        with requests.get(url, stream=True) as resp:
            found, sr = sox_effects.apply_effects_file(
                resp.raw, effects, channels_first=channels_first, format=format_
            )
        # The result is also written to the temp directory.
        result_path = self.get_temp_path("result.wav")
        save_wav(result_path, found, sr, channels_first=channels_first)

        assert sr == expected_sr
        self.assertEqual(found, expected)
from typing import List from typing import List
import torch import torch
from torchaudio import sox_effects
from parameterized import parameterized from parameterized import parameterized
from torchaudio import sox_effects
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -12,6 +11,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -12,6 +11,7 @@ from torchaudio_unittest.common_utils import (
save_wav, save_wav,
torch_script, torch_script,
) )
from .common import ( from .common import (
load_params, load_params,
) )
...@@ -27,8 +27,7 @@ class SoxEffectTensorTransform(torch.nn.Module): ...@@ -27,8 +27,7 @@ class SoxEffectTensorTransform(torch.nn.Module):
self.channels_first = channels_first self.channels_first = channels_first
def forward(self, tensor: torch.Tensor):
    """Run `sox_effects.apply_effects_tensor` on ``tensor`` with the stored settings.

    The stored ``sample_rate``, ``effects`` and ``channels_first`` attributes are
    passed through positionally, and the call's result is returned as-is.
    """
    result = sox_effects.apply_effects_tensor(
        tensor,
        self.sample_rate,
        self.effects,
        self.channels_first,
    )
    return result
class SoxEffectFileTransform(torch.nn.Module): class SoxEffectFileTransform(torch.nn.Module):
...@@ -51,7 +50,7 @@ class TestTorchScript(TempDirMixin, TorchaudioTestCase): ...@@ -51,7 +50,7 @@ class TestTorchScript(TempDirMixin, TorchaudioTestCase):
name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
) )
def test_apply_effects_tensor(self, args): def test_apply_effects_tensor(self, args):
effects = args['effects'] effects = args["effects"]
channels_first = True channels_first = True
num_channels = args.get("num_channels", 2) num_channels = args.get("num_channels", 2)
input_sr = args.get("input_sample_rate", 8000) input_sr = args.get("input_sample_rate", 8000)
...@@ -61,11 +60,10 @@ class TestTorchScript(TempDirMixin, TorchaudioTestCase): ...@@ -61,11 +60,10 @@ class TestTorchScript(TempDirMixin, TorchaudioTestCase):
trans = torch_script(trans) trans = torch_script(trans)
wav = get_sinusoid( wav = get_sinusoid(
frequency=800, sample_rate=input_sr, frequency=800, sample_rate=input_sr, n_channels=num_channels, dtype="float32", channels_first=channels_first
n_channels=num_channels, dtype='float32', channels_first=channels_first) )
found, sr_found = trans(wav) found, sr_found = trans(wav)
expected, sr_expected = sox_effects.apply_effects_tensor( expected, sr_expected = sox_effects.apply_effects_tensor(wav, input_sr, effects, channels_first)
wav, input_sr, effects, channels_first)
assert sr_found == sr_expected assert sr_found == sr_expected
self.assertEqual(expected, found) self.assertEqual(expected, found)
...@@ -75,7 +73,7 @@ class TestTorchScript(TempDirMixin, TorchaudioTestCase): ...@@ -75,7 +73,7 @@ class TestTorchScript(TempDirMixin, TorchaudioTestCase):
name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
) )
def test_apply_effects_file(self, args): def test_apply_effects_file(self, args):
effects = args['effects'] effects = args["effects"]
channels_first = True channels_first = True
num_channels = args.get("num_channels", 2) num_channels = args.get("num_channels", 2)
input_sr = args.get("input_sample_rate", 8000) input_sr = args.get("input_sample_rate", 8000)
...@@ -83,10 +81,10 @@ class TestTorchScript(TempDirMixin, TorchaudioTestCase): ...@@ -83,10 +81,10 @@ class TestTorchScript(TempDirMixin, TorchaudioTestCase):
trans = SoxEffectFileTransform(effects, channels_first) trans = SoxEffectFileTransform(effects, channels_first)
trans = torch_script(trans) trans = torch_script(trans)
path = self.get_temp_path('input.wav') path = self.get_temp_path("input.wav")
wav = get_sinusoid( wav = get_sinusoid(
frequency=800, sample_rate=input_sr, frequency=800, sample_rate=input_sr, n_channels=num_channels, dtype="float32", channels_first=channels_first
n_channels=num_channels, dtype='float32', channels_first=channels_first) )
save_wav(path, wav, sample_rate=input_sr, channels_first=channels_first) save_wav(path, wav, sample_rate=input_sr, channels_first=channels_first)
found, sr_found = trans(path) found, sr_found = trans(path)
......
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .autograd_test_impl import AutogradTestMixin, AutogradTestFloat32 from .autograd_test_impl import AutogradTestMixin, AutogradTestFloat32
class AutogradCPUTest(AutogradTestMixin, PytorchTestCase): class AutogradCPUTest(AutogradTestMixin, PytorchTestCase):
device = 'cpu' device = "cpu"
class AutogradRNNTCPUTest(AutogradTestFloat32, PytorchTestCase): class AutogradRNNTCPUTest(AutogradTestFloat32, PytorchTestCase):
device = 'cpu' device = "cpu"
...@@ -2,14 +2,15 @@ from torchaudio_unittest.common_utils import ( ...@@ -2,14 +2,15 @@ from torchaudio_unittest.common_utils import (
PytorchTestCase, PytorchTestCase,
skipIfNoCuda, skipIfNoCuda,
) )
from .autograd_test_impl import AutogradTestMixin, AutogradTestFloat32 from .autograd_test_impl import AutogradTestMixin, AutogradTestFloat32
@skipIfNoCuda @skipIfNoCuda
class AutogradCUDATest(AutogradTestMixin, PytorchTestCase): class AutogradCUDATest(AutogradTestMixin, PytorchTestCase):
device = 'cuda' device = "cuda"
@skipIfNoCuda @skipIfNoCuda
class AutogradRNNTCUDATest(AutogradTestFloat32, PytorchTestCase): class AutogradRNNTCUDATest(AutogradTestFloat32, PytorchTestCase):
device = 'cuda' device = "cuda"
from typing import List
import unittest import unittest
from typing import List
from parameterized import parameterized
import torch import torch
from torch.autograd import gradcheck, gradgradcheck
import torchaudio.transforms as T import torchaudio.transforms as T
from parameterized import parameterized
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin, TestBaseMixin,
get_whitenoise, get_whitenoise,
...@@ -17,6 +16,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -17,6 +16,7 @@ from torchaudio_unittest.common_utils import (
class _DeterministicWrapper(torch.nn.Module): class _DeterministicWrapper(torch.nn.Module):
"""Helper transform wrapper to make the given transform deterministic""" """Helper transform wrapper to make the given transform deterministic"""
def __init__(self, transform, seed=0): def __init__(self, transform, seed=0):
super().__init__() super().__init__()
self.seed = seed self.seed = seed
...@@ -29,11 +29,11 @@ class _DeterministicWrapper(torch.nn.Module): ...@@ -29,11 +29,11 @@ class _DeterministicWrapper(torch.nn.Module):
class AutogradTestMixin(TestBaseMixin): class AutogradTestMixin(TestBaseMixin):
def assert_grad( def assert_grad(
self, self,
transform: torch.nn.Module, transform: torch.nn.Module,
inputs: List[torch.Tensor], inputs: List[torch.Tensor],
*, *,
nondet_tol: float = 0.0, nondet_tol: float = 0.0,
): ):
transform = transform.to(dtype=torch.float64, device=self.device) transform = transform.to(dtype=torch.float64, device=self.device)
...@@ -42,32 +42,32 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -42,32 +42,32 @@ class AutogradTestMixin(TestBaseMixin):
inputs_ = [] inputs_ = []
for i in inputs: for i in inputs:
if torch.is_tensor(i): if torch.is_tensor(i):
i = i.to( i = i.to(dtype=torch.cdouble if i.is_complex() else torch.double, device=self.device)
dtype=torch.cdouble if i.is_complex() else torch.double,
device=self.device)
i.requires_grad = True i.requires_grad = True
inputs_.append(i) inputs_.append(i)
assert gradcheck(transform, inputs_) assert gradcheck(transform, inputs_)
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol) assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)
@parameterized.expand([ @parameterized.expand(
({'pad': 0, 'normalized': False, 'power': None, 'return_complex': True}, ), [
({'pad': 3, 'normalized': False, 'power': None, 'return_complex': True}, ), ({"pad": 0, "normalized": False, "power": None, "return_complex": True},),
({'pad': 0, 'normalized': True, 'power': None, 'return_complex': True}, ), ({"pad": 3, "normalized": False, "power": None, "return_complex": True},),
({'pad': 3, 'normalized': True, 'power': None, 'return_complex': True}, ), ({"pad": 0, "normalized": True, "power": None, "return_complex": True},),
({'pad': 0, 'normalized': False, 'power': None}, ), ({"pad": 3, "normalized": True, "power": None, "return_complex": True},),
({'pad': 3, 'normalized': False, 'power': None}, ), ({"pad": 0, "normalized": False, "power": None},),
({'pad': 0, 'normalized': True, 'power': None}, ), ({"pad": 3, "normalized": False, "power": None},),
({'pad': 3, 'normalized': True, 'power': None}, ), ({"pad": 0, "normalized": True, "power": None},),
({'pad': 0, 'normalized': False, 'power': 1.0}, ), ({"pad": 3, "normalized": True, "power": None},),
({'pad': 3, 'normalized': False, 'power': 1.0}, ), ({"pad": 0, "normalized": False, "power": 1.0},),
({'pad': 0, 'normalized': True, 'power': 1.0}, ), ({"pad": 3, "normalized": False, "power": 1.0},),
({'pad': 3, 'normalized': True, 'power': 1.0}, ), ({"pad": 0, "normalized": True, "power": 1.0},),
({'pad': 0, 'normalized': False, 'power': 2.0}, ), ({"pad": 3, "normalized": True, "power": 1.0},),
({'pad': 3, 'normalized': False, 'power': 2.0}, ), ({"pad": 0, "normalized": False, "power": 2.0},),
({'pad': 0, 'normalized': True, 'power': 2.0}, ), ({"pad": 3, "normalized": False, "power": 2.0},),
({'pad': 3, 'normalized': True, 'power': 2.0}, ), ({"pad": 0, "normalized": True, "power": 2.0},),
]) ({"pad": 3, "normalized": True, "power": 2.0},),
]
)
def test_spectrogram(self, kwargs): def test_spectrogram(self, kwargs):
# replication_pad1d_backward_cuda is not deterministic and # replication_pad1d_backward_cuda is not deterministic and
# gives very small (~2.7756e-17) difference. # gives very small (~2.7756e-17) difference.
...@@ -105,21 +105,20 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -105,21 +105,20 @@ class AutogradTestMixin(TestBaseMixin):
power = 1 power = 1
n_iter = 2 n_iter = 2
spec = get_spectrogram( spec = get_spectrogram(get_whitenoise(sample_rate=8000, duration=0.01, n_channels=2), n_fft=n_fft, power=power)
get_whitenoise(sample_rate=8000, duration=0.01, n_channels=2),
n_fft=n_fft, power=power)
transform = _DeterministicWrapper( transform = _DeterministicWrapper(
T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=rand_init, power=power)) T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=rand_init, power=power)
)
self.assert_grad(transform, [spec]) self.assert_grad(transform, [spec])
@parameterized.expand([(False, ), (True, )]) @parameterized.expand([(False,), (True,)])
def test_mfcc(self, log_mels): def test_mfcc(self, log_mels):
sample_rate = 8000 sample_rate = 8000
transform = T.MFCC(sample_rate=sample_rate, log_mels=log_mels) transform = T.MFCC(sample_rate=sample_rate, log_mels=log_mels)
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform]) self.assert_grad(transform, [waveform])
@parameterized.expand([(False, ), (True, )]) @parameterized.expand([(False,), (True,)])
def test_lfcc(self, log_lf): def test_lfcc(self, log_lf):
sample_rate = 8000 sample_rate = 8000
transform = T.LFCC(sample_rate=sample_rate, log_lf=log_lf) transform = T.LFCC(sample_rate=sample_rate, log_lf=log_lf)
...@@ -137,7 +136,7 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -137,7 +136,7 @@ class AutogradTestMixin(TestBaseMixin):
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform]) self.assert_grad(transform, [waveform])
@parameterized.expand([("linear", ), ("exponential", ), ("logarithmic", ), ("quarter_sine", ), ("half_sine", )]) @parameterized.expand([("linear",), ("exponential",), ("logarithmic",), ("quarter_sine",), ("half_sine",)])
def test_fade(self, fade_shape): def test_fade(self, fade_shape):
transform = T.Fade(fade_shape=fade_shape) transform = T.Fade(fade_shape=fade_shape)
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
...@@ -148,8 +147,8 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -148,8 +147,8 @@ class AutogradTestMixin(TestBaseMixin):
sample_rate = 8000 sample_rate = 8000
n_fft = 400 n_fft = 400
spectrogram = get_spectrogram( spectrogram = get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), n_fft=n_fft, power=1
n_fft=n_fft, power=1) )
deterministic_transform = _DeterministicWrapper(masking_transform(400)) deterministic_transform = _DeterministicWrapper(masking_transform(400))
self.assert_grad(deterministic_transform, [spectrogram]) self.assert_grad(deterministic_transform, [spectrogram])
...@@ -157,9 +156,10 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -157,9 +156,10 @@ class AutogradTestMixin(TestBaseMixin):
def test_masking_iid(self, masking_transform): def test_masking_iid(self, masking_transform):
sample_rate = 8000 sample_rate = 8000
n_fft = 400 n_fft = 400
specs = [get_spectrogram( specs = [
get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2, seed=i), get_spectrogram(
n_fft=n_fft, power=1) get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2, seed=i), n_fft=n_fft, power=1
)
for i in range(3) for i in range(3)
] ]
...@@ -186,8 +186,8 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -186,8 +186,8 @@ class AutogradTestMixin(TestBaseMixin):
n_mels = n_fft // 2 + 1 n_mels = n_fft // 2 + 1
transform = T.MelScale(sample_rate=sample_rate, n_mels=n_mels) transform = T.MelScale(sample_rate=sample_rate, n_mels=n_mels)
spec = get_spectrogram( spec = get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), n_fft=n_fft, power=1
n_fft=n_fft, power=1) )
self.assert_grad(transform, [spec]) self.assert_grad(transform, [spec])
@parameterized.expand([(1.5, "amplitude"), (2, "power"), (10, "db")]) @parameterized.expand([(1.5, "amplitude"), (2, "power"), (10, "db")])
...@@ -197,18 +197,18 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -197,18 +197,18 @@ class AutogradTestMixin(TestBaseMixin):
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform]) self.assert_grad(transform, [waveform])
@parameterized.expand([ @parameterized.expand(
({'cmn_window': 100, 'min_cmn_window': 50, 'center': False, 'norm_vars': False}, ), [
({'cmn_window': 100, 'min_cmn_window': 50, 'center': True, 'norm_vars': False}, ), ({"cmn_window": 100, "min_cmn_window": 50, "center": False, "norm_vars": False},),
({'cmn_window': 100, 'min_cmn_window': 50, 'center': False, 'norm_vars': True}, ), ({"cmn_window": 100, "min_cmn_window": 50, "center": True, "norm_vars": False},),
({'cmn_window': 100, 'min_cmn_window': 50, 'center': True, 'norm_vars': True}, ), ({"cmn_window": 100, "min_cmn_window": 50, "center": False, "norm_vars": True},),
]) ({"cmn_window": 100, "min_cmn_window": 50, "center": True, "norm_vars": True},),
]
)
def test_sliding_window_cmn(self, kwargs): def test_sliding_window_cmn(self, kwargs):
n_fft = 10 n_fft = 10
power = 1 power = 1
spec = get_spectrogram( spec = get_spectrogram(get_whitenoise(sample_rate=200, duration=0.05, n_channels=2), n_fft=n_fft, power=power)
get_whitenoise(sample_rate=200, duration=0.05, n_channels=2),
n_fft=n_fft, power=power)
spec_reshaped = spec.transpose(-1, -2) spec_reshaped = spec.transpose(-1, -2)
transform = T.SlidingWindowCmn(**kwargs) transform = T.SlidingWindowCmn(**kwargs)
...@@ -260,10 +260,12 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -260,10 +260,12 @@ class AutogradTestMixin(TestBaseMixin):
spectrogram = get_spectrogram(waveform, n_fft=400) spectrogram = get_spectrogram(waveform, n_fft=400)
self.assert_grad(transform, [spectrogram]) self.assert_grad(transform, [spectrogram])
@parameterized.expand([ @parameterized.expand(
[True], [
[False], [True],
]) [False],
]
)
def test_psd_with_mask(self, multi_mask): def test_psd_with_mask(self, multi_mask):
transform = T.PSD(multi_mask=multi_mask) transform = T.PSD(multi_mask=multi_mask)
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
...@@ -275,12 +277,14 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -275,12 +277,14 @@ class AutogradTestMixin(TestBaseMixin):
self.assert_grad(transform, [spectrogram, mask]) self.assert_grad(transform, [spectrogram, mask])
@parameterized.expand([ @parameterized.expand(
"ref_channel", [
# stv_power and stv_evd test time too long, comment for now "ref_channel",
# "stv_power", # stv_power and stv_evd test time too long, comment for now
# "stv_evd", # "stv_power",
]) # "stv_evd",
]
)
def test_mvdr(self, solution): def test_mvdr(self, solution):
transform = T.MVDR(solution=solution) transform = T.MVDR(solution=solution)
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
...@@ -292,9 +296,9 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -292,9 +296,9 @@ class AutogradTestMixin(TestBaseMixin):
class AutogradTestFloat32(TestBaseMixin): class AutogradTestFloat32(TestBaseMixin):
def assert_grad( def assert_grad(
self, self,
transform: torch.nn.Module, transform: torch.nn.Module,
inputs: List[torch.Tensor], inputs: List[torch.Tensor],
): ):
inputs_ = [] inputs_ = []
for i in inputs: for i in inputs:
...@@ -302,13 +306,15 @@ class AutogradTestFloat32(TestBaseMixin): ...@@ -302,13 +306,15 @@ class AutogradTestFloat32(TestBaseMixin):
i = i.to(dtype=torch.float32, device=self.device) i = i.to(dtype=torch.float32, device=self.device)
inputs_.append(i) inputs_.append(i)
# gradcheck with float32 requires higher atol and epsilon # gradcheck with float32 requires higher atol and epsilon
assert gradcheck(transform, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.) assert gradcheck(transform, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.0)
@parameterized.expand([ @parameterized.expand(
(rnnt_utils.get_B1_T10_U3_D4_data, ), [
(rnnt_utils.get_B2_T4_U3_D3_data, ), (rnnt_utils.get_B1_T10_U3_D4_data,),
(rnnt_utils.get_B1_T2_U3_D5_data, ), (rnnt_utils.get_B2_T4_U3_D3_data,),
]) (rnnt_utils.get_B1_T2_U3_D5_data,),
]
)
def test_rnnt_loss(self, data_func): def test_rnnt_loss(self, data_func):
def get_data(data_func, device): def get_data(data_func, device):
data = data_func() data = data_func()
......
...@@ -2,25 +2,21 @@ ...@@ -2,25 +2,21 @@
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torchaudio import transforms as T from torchaudio import transforms as T
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
class TestTransforms(common_utils.TorchaudioTestCase): class TestTransforms(common_utils.TorchaudioTestCase):
"""Test suite for classes defined in `transforms` module""" """Test suite for classes defined in `transforms` module"""
backend = 'default'
def assert_batch_consistency( backend = "default"
self, transform, batch, *args, atol=1e-8, rtol=1e-5, seed=42,
**kwargs): def assert_batch_consistency(self, transform, batch, *args, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
n = batch.size(0) n = batch.size(0)
# Compute items separately, then batch the result # Compute items separately, then batch the result
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
items_input = batch.clone() items_input = batch.clone()
items_result = torch.stack([ items_result = torch.stack([transform(items_input[i], *args, **kwargs) for i in range(n)])
transform(items_input[i], *args, **kwargs) for i in range(n)
])
# Batch the input and run # Batch the input and run
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
...@@ -131,11 +127,7 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -131,11 +127,7 @@ class TestTransforms(common_utils.TorchaudioTestCase):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch) tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch)
spec = common_utils.get_spectrogram(tensor, n_fft=num_freq) spec = common_utils.get_spectrogram(tensor, n_fft=num_freq)
transform = T.TimeStretch( transform = T.TimeStretch(fixed_rate=rate, n_freq=num_freq // 2 + 1, hop_length=512)
fixed_rate=rate,
n_freq=num_freq // 2 + 1,
hop_length=512
)
self.assert_batch_consistency(transform, spec, atol=1e-5, rtol=1e-5) self.assert_batch_consistency(transform, spec, atol=1e-5, rtol=1e-5)
...@@ -197,10 +189,12 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -197,10 +189,12 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
@parameterized.expand([ @parameterized.expand(
[True], [
[False], [True],
]) [False],
]
)
def test_MVDR(self, multi_mask): def test_MVDR(self, multi_mask):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
specgram = common_utils.get_spectrogram(waveform, n_fft=400) specgram = common_utils.get_spectrogram(waveform, n_fft=400)
......
import torch import torch
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from .kaldi_compatibility_impl import Kaldi from .kaldi_compatibility_impl import Kaldi
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase): class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device("cpu")
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase): class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device("cpu")
import torch import torch
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from .kaldi_compatibility_impl import Kaldi from .kaldi_compatibility_impl import Kaldi
@common_utils.skipIfNoCuda @common_utils.skipIfNoCuda
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase): class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cuda') device = torch.device("cuda")
@common_utils.skipIfNoCuda @common_utils.skipIfNoCuda
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase): class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cuda') device = torch.device("cuda")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment