"sgl-kernel/csrc/vscode:/vscode.git/clone" did not exist on "c9bcffd2a53423e6a183e312a58675fb48435d2a"
Commit 5859923a authored by Joao Gomes's avatar Joao Gomes Committed by Facebook GitHub Bot
Browse files

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
import os import os
from pathlib import Path from pathlib import Path
from torchaudio.datasets import speechcommands
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -9,8 +10,6 @@ from torchaudio_unittest.common_utils import ( ...@@ -9,8 +10,6 @@ from torchaudio_unittest.common_utils import (
save_wav, save_wav,
) )
from torchaudio.datasets import speechcommands
_LABELS = [ _LABELS = [
"bed", "bed",
"bird", "bird",
...@@ -93,10 +92,10 @@ def get_mock_dataset(dataset_dir): ...@@ -93,10 +92,10 @@ def get_mock_dataset(dataset_dir):
if j < 2: if j < 2:
mocked_train_samples.append(sample) mocked_train_samples.append(sample)
elif j < 4: elif j < 4:
valid.write(f'{label}/{filename}\n') valid.write(f"{label}/{filename}\n")
mocked_valid_samples.append(sample) mocked_valid_samples.append(sample)
elif j < 6: elif j < 6:
test.write(f'{label}/{filename}\n') test.write(f"{label}/{filename}\n")
mocked_test_samples.append(sample) mocked_test_samples.append(sample)
return mocked_samples, mocked_train_samples, mocked_valid_samples, mocked_test_samples return mocked_samples, mocked_train_samples, mocked_valid_samples, mocked_test_samples
...@@ -113,16 +112,12 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase): ...@@ -113,16 +112,12 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir() cls.root_dir = cls.get_base_temp_dir()
dataset_dir = os.path.join( dataset_dir = os.path.join(cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL)
cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL
)
cls.samples, cls.train_samples, cls.valid_samples, cls.test_samples = get_mock_dataset(dataset_dir) cls.samples, cls.train_samples, cls.valid_samples, cls.test_samples = get_mock_dataset(dataset_dir)
def _testSpeechCommands(self, dataset, data_samples): def _testSpeechCommands(self, dataset, data_samples):
num_samples = 0 num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate( for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(dataset):
dataset
):
self.assertEqual(data, data_samples[i][0], atol=5e-5, rtol=1e-8) self.assertEqual(data, data_samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == data_samples[i][1] assert sample_rate == data_samples[i][1]
assert label == data_samples[i][2] assert label == data_samples[i][2]
......
...@@ -2,15 +2,8 @@ import os ...@@ -2,15 +2,8 @@ import os
import platform import platform
from pathlib import Path from pathlib import Path
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_whitenoise,
save_wav,
skipIfNoSox
)
from torchaudio.datasets import tedlium from torchaudio.datasets import tedlium
from torchaudio_unittest.common_utils import TempDirMixin, TorchaudioTestCase, get_whitenoise, save_wav, skipIfNoSox
# Used to generate a unique utterance for each dummy audio file # Used to generate a unique utterance for each dummy audio file
_UTTERANCES = [ _UTTERANCES = [
...@@ -145,6 +138,7 @@ class TestTedliumSoundfile(Tedlium, TorchaudioTestCase): ...@@ -145,6 +138,7 @@ class TestTedliumSoundfile(Tedlium, TorchaudioTestCase):
if platform.system() != "Windows": if platform.system() != "Windows":
@skipIfNoSox @skipIfNoSox
class TestTedliumSoxIO(Tedlium, TorchaudioTestCase): class TestTedliumSoxIO(Tedlium, TorchaudioTestCase):
backend = "sox_io" backend = "sox_io"
...@@ -2,7 +2,6 @@ import os ...@@ -2,7 +2,6 @@ import os
from pathlib import Path from pathlib import Path
from torchaudio.datasets import vctk from torchaudio.datasets import vctk
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -13,17 +12,17 @@ from torchaudio_unittest.common_utils import ( ...@@ -13,17 +12,17 @@ from torchaudio_unittest.common_utils import (
# Used to generate a unique transcript for each dummy audio file # Used to generate a unique transcript for each dummy audio file
_TRANSCRIPT = [ _TRANSCRIPT = [
'Please call Stella', "Please call Stella",
'Ask her to bring these things', "Ask her to bring these things",
'with her from the store', "with her from the store",
'Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob', "Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob",
'We also need a small plastic snake and a big toy frog for the kids', "We also need a small plastic snake and a big toy frog for the kids",
'She can scoop these things into three red bags, and we will go meet her Wednesday at the train station', "She can scoop these things into three red bags, and we will go meet her Wednesday at the train station",
'When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow', "When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow",
'The rainbow is a division of white light into many beautiful colors', "The rainbow is a division of white light into many beautiful colors",
'These take the shape of a long round arch, with its path high above, and its two ends \ "These take the shape of a long round arch, with its path high above, and its two ends \
apparently beyond the horizon', apparently beyond the horizon",
'There is, according to legend, a boiling pot of gold at one end' "There is, according to legend, a boiling pot of gold at one end",
] ]
...@@ -32,51 +31,39 @@ def get_mock_dataset(root_dir): ...@@ -32,51 +31,39 @@ def get_mock_dataset(root_dir):
root_dir: root directory of the mocked data root_dir: root directory of the mocked data
""" """
mocked_samples = [] mocked_samples = []
dataset_dir = os.path.join(root_dir, 'VCTK-Corpus-0.92') dataset_dir = os.path.join(root_dir, "VCTK-Corpus-0.92")
os.makedirs(dataset_dir, exist_ok=True) os.makedirs(dataset_dir, exist_ok=True)
sample_rate = 48000 sample_rate = 48000
seed = 0 seed = 0
for speaker in range(225, 230): for speaker in range(225, 230):
speaker_id = 'p' + str(speaker) speaker_id = "p" + str(speaker)
audio_dir = os.path.join(dataset_dir, 'wav48_silence_trimmed', speaker_id) audio_dir = os.path.join(dataset_dir, "wav48_silence_trimmed", speaker_id)
os.makedirs(audio_dir, exist_ok=True) os.makedirs(audio_dir, exist_ok=True)
file_dir = os.path.join(dataset_dir, 'txt', speaker_id) file_dir = os.path.join(dataset_dir, "txt", speaker_id)
os.makedirs(file_dir, exist_ok=True) os.makedirs(file_dir, exist_ok=True)
for utterance_id in range(1, 11): for utterance_id in range(1, 11):
filename = f'{speaker_id}_{utterance_id:03d}_mic2' filename = f"{speaker_id}_{utterance_id:03d}_mic2"
audio_file_path = os.path.join(audio_dir, filename + '.wav') audio_file_path = os.path.join(audio_dir, filename + ".wav")
data = get_whitenoise( data = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype="float32", seed=seed)
sample_rate=sample_rate,
duration=0.01,
n_channels=1,
dtype='float32',
seed=seed
)
save_wav(audio_file_path, data, sample_rate) save_wav(audio_file_path, data, sample_rate)
txt_file_path = os.path.join(file_dir, filename[:-5] + '.txt') txt_file_path = os.path.join(file_dir, filename[:-5] + ".txt")
transcript = _TRANSCRIPT[utterance_id - 1] transcript = _TRANSCRIPT[utterance_id - 1]
with open(txt_file_path, 'w') as f: with open(txt_file_path, "w") as f:
f.write(transcript) f.write(transcript)
sample = ( sample = (normalize_wav(data), sample_rate, transcript, speaker_id, utterance_id)
normalize_wav(data),
sample_rate,
transcript,
speaker_id,
utterance_id
)
mocked_samples.append(sample) mocked_samples.append(sample)
seed += 1 seed += 1
return mocked_samples return mocked_samples
class TestVCTK(TempDirMixin, TorchaudioTestCase): class TestVCTK(TempDirMixin, TorchaudioTestCase):
backend = 'default' backend = "default"
root_dir = None root_dir = None
samples = [] samples = []
......
...@@ -2,7 +2,6 @@ import os ...@@ -2,7 +2,6 @@ import os
from pathlib import Path from pathlib import Path
from torchaudio.datasets import yesno from torchaudio.datasets import yesno
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -18,19 +17,19 @@ def get_mock_data(root_dir, labels): ...@@ -18,19 +17,19 @@ def get_mock_data(root_dir, labels):
labels: list of labels labels: list of labels
""" """
mocked_data = [] mocked_data = []
base_dir = os.path.join(root_dir, 'waves_yesno') base_dir = os.path.join(root_dir, "waves_yesno")
os.makedirs(base_dir, exist_ok=True) os.makedirs(base_dir, exist_ok=True)
for i, label in enumerate(labels): for i, label in enumerate(labels):
filename = f'{"_".join(str(l) for l in label)}.wav' filename = f'{"_".join(str(l) for l in label)}.wav'
path = os.path.join(base_dir, filename) path = os.path.join(base_dir, filename)
data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16', seed=i) data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype="int16", seed=i)
save_wav(path, data, 8000) save_wav(path, data, 8000)
mocked_data.append(normalize_wav(data)) mocked_data.append(normalize_wav(data))
return mocked_data return mocked_data
class TestYesNo(TempDirMixin, TorchaudioTestCase): class TestYesNo(TempDirMixin, TorchaudioTestCase):
backend = 'default' backend = "default"
root_dir = None root_dir = None
data = [] data = []
......
...@@ -2,7 +2,4 @@ import os ...@@ -2,7 +2,4 @@ import os
import sys import sys
sys.path.append( sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples"))
os.path.join(
os.path.dirname(__file__),
'..', '..', '..', 'examples'))
from itertools import product from itertools import product
import torch import torch
from torch.testing._internal.common_utils import TestCase
from parameterized import parameterized from parameterized import parameterized
from source_separation.utils import metrics
from torch.testing._internal.common_utils import TestCase
from . import sdr_reference from . import sdr_reference
from source_separation.utils import metrics
class TestSDR(TestCase): class TestSDR(TestCase):
@parameterized.expand([(1, ), (2, ), (32, )]) @parameterized.expand([(1,), (2,), (32,)])
def test_sdr(self, batch_size): def test_sdr(self, batch_size):
"""sdr produces the same result as the reference implementation""" """sdr produces the same result as the reference implementation"""
num_frames = 256 num_frames = 256
......
"""Reference Implementation of SDR and PIT SDR. """Reference Implementation of SDR and PIT SDR.
This module was taken from the following implementation This module was taken from the following implementation
https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py
which was made available by Yi Luo under the following liscence, which was made available by Yi Luo under the following liscence,
Creative Commons Attribution-NonCommercial-ShareAlike 3.0 United States License. Creative Commons Attribution-NonCommercial-ShareAlike 3.0 United States License.
The module was modified in the following manner; The module was modified in the following manner;
- Remove the functions other than `calc_sdr_torch` and `batch_SDR_torch`, - Remove the functions other than `calc_sdr_torch` and `batch_SDR_torch`,
- Remove the import statements required only for the removed functions. - Remove the import statements required only for the removed functions.
- Add `# flake8: noqa` so as not to report any format issue on this module. - Add `# flake8: noqa` so as not to report any format issue on this module.
The implementation of the retained functions and their formats are kept as-is. The implementation of the retained functions and their formats are kept as-is.
""" """
# flake8: noqa from itertools import permutations
import numpy as np # flake8: noqa
from itertools import permutations
import numpy as np
import torch import torch
def calc_sdr_torch(estimation, origin, mask=None): def calc_sdr_torch(estimation, origin, mask=None):
""" """
batch-wise SDR caculation for one audio file on pytorch Variables. batch-wise SDR caculation for one audio file on pytorch Variables.
estimation: (batch, nsample) estimation: (batch, nsample)
origin: (batch, nsample) origin: (batch, nsample)
mask: optional, (batch, nsample), binary mask: optional, (batch, nsample), binary
""" """
if mask is not None: if mask is not None:
origin = origin * mask origin = origin * mask
estimation = estimation * mask estimation = estimation * mask
origin_power = torch.pow(origin, 2).sum(1, keepdim=True) + 1e-8 # (batch, 1) origin_power = torch.pow(origin, 2).sum(1, keepdim=True) + 1e-8 # (batch, 1)
scale = torch.sum(origin*estimation, 1, keepdim=True) / origin_power # (batch, 1) scale = torch.sum(origin * estimation, 1, keepdim=True) / origin_power # (batch, 1)
est_true = scale * origin # (batch, nsample) est_true = scale * origin # (batch, nsample)
est_res = estimation - est_true # (batch, nsample) est_res = estimation - est_true # (batch, nsample)
true_power = torch.pow(est_true, 2).sum(1) true_power = torch.pow(est_true, 2).sum(1)
res_power = torch.pow(est_res, 2).sum(1) res_power = torch.pow(est_res, 2).sum(1)
return 10*torch.log10(true_power) - 10*torch.log10(res_power) # (batch, 1) return 10 * torch.log10(true_power) - 10 * torch.log10(res_power) # (batch, 1)
def batch_SDR_torch(estimation, origin, mask=None): def batch_SDR_torch(estimation, origin, mask=None):
""" """
batch-wise SDR caculation for multiple audio files. batch-wise SDR caculation for multiple audio files.
estimation: (batch, nsource, nsample) estimation: (batch, nsource, nsample)
origin: (batch, nsource, nsample) origin: (batch, nsource, nsample)
mask: optional, (batch, nsample), binary mask: optional, (batch, nsample), binary
""" """
batch_size_est, nsource_est, nsample_est = estimation.size() batch_size_est, nsource_est, nsample_est = estimation.size()
batch_size_ori, nsource_ori, nsample_ori = origin.size() batch_size_ori, nsource_ori, nsample_ori = origin.size()
assert batch_size_est == batch_size_ori, "Estimation and original sources should have same shape." assert batch_size_est == batch_size_ori, "Estimation and original sources should have same shape."
assert nsource_est == nsource_ori, "Estimation and original sources should have same shape." assert nsource_est == nsource_ori, "Estimation and original sources should have same shape."
assert nsample_est == nsample_ori, "Estimation and original sources should have same shape." assert nsample_est == nsample_ori, "Estimation and original sources should have same shape."
assert nsource_est < nsample_est, "Axis 1 should be the number of sources, and axis 2 should be the signal." assert nsource_est < nsample_est, "Axis 1 should be the number of sources, and axis 2 should be the signal."
batch_size = batch_size_est batch_size = batch_size_est
nsource = nsource_est nsource = nsource_est
nsample = nsample_est nsample = nsample_est
# zero mean signals # zero mean signals
estimation = estimation - torch.mean(estimation, 2, keepdim=True).expand_as(estimation) estimation = estimation - torch.mean(estimation, 2, keepdim=True).expand_as(estimation)
origin = origin - torch.mean(origin, 2, keepdim=True).expand_as(estimation) origin = origin - torch.mean(origin, 2, keepdim=True).expand_as(estimation)
# possible permutations # possible permutations
perm = list(set(permutations(np.arange(nsource)))) perm = list(set(permutations(np.arange(nsource))))
# pair-wise SDR # pair-wise SDR
SDR = torch.zeros((batch_size, nsource, nsource)).type(estimation.type()) SDR = torch.zeros((batch_size, nsource, nsource)).type(estimation.type())
for i in range(nsource): for i in range(nsource):
for j in range(nsource): for j in range(nsource):
SDR[:,i,j] = calc_sdr_torch(estimation[:,i], origin[:,j], mask) SDR[:, i, j] = calc_sdr_torch(estimation[:, i], origin[:, j], mask)
# choose the best permutation # choose the best permutation
SDR_max = [] SDR_max = []
SDR_perm = [] SDR_perm = []
for permute in perm: for permute in perm:
sdr = [] sdr = []
for idx in range(len(permute)): for idx in range(len(permute)):
sdr.append(SDR[:,idx,permute[idx]].view(batch_size,-1)) sdr.append(SDR[:, idx, permute[idx]].view(batch_size, -1))
sdr = torch.sum(torch.cat(sdr, 1), 1) sdr = torch.sum(torch.cat(sdr, 1), 1)
SDR_perm.append(sdr.view(batch_size, 1)) SDR_perm.append(sdr.view(batch_size, 1))
SDR_perm = torch.cat(SDR_perm, 1) SDR_perm = torch.cat(SDR_perm, 1)
SDR_max, _ = torch.max(SDR_perm, dim=1) SDR_max, _ = torch.max(SDR_perm, dim=1)
return SDR_max / nsource return SDR_max / nsource
import os import os
from source_separation.utils.dataset import wsj0mix
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -8,8 +9,6 @@ from torchaudio_unittest.common_utils import ( ...@@ -8,8 +9,6 @@ from torchaudio_unittest.common_utils import (
normalize_wav, normalize_wav,
) )
from source_separation.utils.dataset import wsj0mix
_FILENAMES = [ _FILENAMES = [
"012c0207_1.9952_01cc0202_-1.9952.wav", "012c0207_1.9952_01cc0202_-1.9952.wav",
...@@ -45,9 +44,7 @@ def _mock_dataset(root_dir, num_speaker): ...@@ -45,9 +44,7 @@ def _mock_dataset(root_dir, num_speaker):
mix = None mix = None
src = [] src = []
for dirname in dirnames: for dirname in dirnames:
waveform = get_whitenoise( waveform = get_whitenoise(sample_rate=8000, duration=1, n_channels=1, dtype="int16", seed=seed)
sample_rate=8000, duration=1, n_channels=1, dtype="int16", seed=seed
)
seed += 1 seed += 1
path = os.path.join(root_dir, dirname, filename) path = os.path.join(root_dir, dirname, filename)
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .tacotron2_loss_impl import ( from .tacotron2_loss_impl import (
Tacotron2LossShapeTests, Tacotron2LossShapeTests,
Tacotron2LossTorchscriptTests, Tacotron2LossTorchscriptTests,
Tacotron2LossGradcheckTests, Tacotron2LossGradcheckTests,
) )
from torchaudio_unittest.common_utils import PytorchTestCase
class TestTacotron2LossShapeFloat32CPU(Tacotron2LossShapeTests, PytorchTestCase): class TestTacotron2LossShapeFloat32CPU(Tacotron2LossShapeTests, PytorchTestCase):
...@@ -19,5 +19,5 @@ class TestTacotron2TorchsciptFloat32CPU(Tacotron2LossTorchscriptTests, PytorchTe ...@@ -19,5 +19,5 @@ class TestTacotron2TorchsciptFloat32CPU(Tacotron2LossTorchscriptTests, PytorchTe
class TestTacotron2GradcheckFloat64CPU(Tacotron2LossGradcheckTests, PytorchTestCase): class TestTacotron2GradcheckFloat64CPU(Tacotron2LossGradcheckTests, PytorchTestCase):
dtype = torch.float64 # gradcheck needs a higher numerical accuracy dtype = torch.float64 # gradcheck needs a higher numerical accuracy
device = torch.device("cpu") device = torch.device("cpu")
import torch import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .tacotron2_loss_impl import ( from .tacotron2_loss_impl import (
Tacotron2LossShapeTests, Tacotron2LossShapeTests,
Tacotron2LossTorchscriptTests, Tacotron2LossTorchscriptTests,
Tacotron2LossGradcheckTests, Tacotron2LossGradcheckTests,
) )
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
@skipIfNoCuda @skipIfNoCuda
...@@ -22,5 +22,5 @@ class TestTacotron2TorchsciptFloat32CUDA(PytorchTestCase, Tacotron2LossTorchscri ...@@ -22,5 +22,5 @@ class TestTacotron2TorchsciptFloat32CUDA(PytorchTestCase, Tacotron2LossTorchscri
@skipIfNoCuda @skipIfNoCuda
class TestTacotron2GradcheckFloat64CUDA(PytorchTestCase, Tacotron2LossGradcheckTests): class TestTacotron2GradcheckFloat64CUDA(PytorchTestCase, Tacotron2LossGradcheckTests):
dtype = torch.float64 # gradcheck needs a higher numerical accuracy dtype = torch.float64 # gradcheck needs a higher numerical accuracy
device = torch.device("cuda") device = torch.device("cuda")
import torch import torch
from torch.autograd import gradcheck, gradgradcheck
from pipeline_tacotron2.loss import Tacotron2Loss from pipeline_tacotron2.loss import Tacotron2Loss
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin, TestBaseMixin,
torch_script, torch_script,
...@@ -9,18 +8,11 @@ from torchaudio_unittest.common_utils import ( ...@@ -9,18 +8,11 @@ from torchaudio_unittest.common_utils import (
class Tacotron2LossInputMixin(TestBaseMixin): class Tacotron2LossInputMixin(TestBaseMixin):
def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300): def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300):
mel_specgram = torch.rand( mel_specgram = torch.rand(n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device)
n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device mel_specgram_postnet = torch.rand(n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device)
)
mel_specgram_postnet = torch.rand(
n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device
)
gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device) gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device)
truth_mel_specgram = torch.rand( truth_mel_specgram = torch.rand(n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device)
n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device
)
truth_gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device) truth_gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device)
truth_mel_specgram.requires_grad = False truth_mel_specgram.requires_grad = False
...@@ -36,7 +28,6 @@ class Tacotron2LossInputMixin(TestBaseMixin): ...@@ -36,7 +28,6 @@ class Tacotron2LossInputMixin(TestBaseMixin):
class Tacotron2LossShapeTests(Tacotron2LossInputMixin): class Tacotron2LossShapeTests(Tacotron2LossInputMixin):
def test_tacotron2_loss_shape(self): def test_tacotron2_loss_shape(self):
"""Validate the output shape of Tacotron2Loss.""" """Validate the output shape of Tacotron2Loss."""
n_batch = 16 n_batch = 16
...@@ -50,8 +41,7 @@ class Tacotron2LossShapeTests(Tacotron2LossInputMixin): ...@@ -50,8 +41,7 @@ class Tacotron2LossShapeTests(Tacotron2LossInputMixin):
) = self._get_inputs(n_batch=n_batch) ) = self._get_inputs(n_batch=n_batch)
mel_loss, mel_postnet_loss, gate_loss = Tacotron2Loss()( mel_loss, mel_postnet_loss, gate_loss = Tacotron2Loss()(
(mel_specgram, mel_specgram_postnet, gate_out), (mel_specgram, mel_specgram_postnet, gate_out), (truth_mel_specgram, truth_gate_out)
(truth_mel_specgram, truth_gate_out)
) )
self.assertEqual(mel_loss.size(), torch.Size([])) self.assertEqual(mel_loss.size(), torch.Size([]))
...@@ -60,7 +50,6 @@ class Tacotron2LossShapeTests(Tacotron2LossInputMixin): ...@@ -60,7 +50,6 @@ class Tacotron2LossShapeTests(Tacotron2LossInputMixin):
class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin): class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin):
def _assert_torchscript_consistency(self, fn, tensors): def _assert_torchscript_consistency(self, fn, tensors):
ts_func = torch_script(fn) ts_func = torch_script(fn)
...@@ -77,7 +66,6 @@ class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin): ...@@ -77,7 +66,6 @@ class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin):
class Tacotron2LossGradcheckTests(Tacotron2LossInputMixin): class Tacotron2LossGradcheckTests(Tacotron2LossInputMixin):
def test_tacotron2_loss_gradcheck(self): def test_tacotron2_loss_gradcheck(self):
"""Performing gradient check on Tacotron2Loss.""" """Performing gradient check on Tacotron2Loss."""
( (
......
from parameterized import parameterized from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule
if is_module_available("unidecode") and is_module_available("inflect"): if is_module_available("unidecode") and is_module_available("inflect"):
from pipeline_tacotron2.text.text_preprocessing import text_to_sequence
from pipeline_tacotron2.text.numbers import ( from pipeline_tacotron2.text.numbers import (
_remove_commas, _remove_commas,
_expand_pounds, _expand_pounds,
...@@ -13,25 +11,62 @@ if is_module_available("unidecode") and is_module_available("inflect"): ...@@ -13,25 +11,62 @@ if is_module_available("unidecode") and is_module_available("inflect"):
_expand_ordinal, _expand_ordinal,
_expand_number, _expand_number,
) )
from pipeline_tacotron2.text.text_preprocessing import text_to_sequence
@skipIfNoModule("unidecode") @skipIfNoModule("unidecode")
@skipIfNoModule("inflect") @skipIfNoModule("inflect")
class TestTextPreprocessor(TorchaudioTestCase): class TestTextPreprocessor(TorchaudioTestCase):
@parameterized.expand( @parameterized.expand(
[ [
["dr. Strange?", [15, 26, 14, 31, 26, 29, 11, 30, 31, 29, 12, 25, 18, 16, 10]], ["dr. Strange?", [15, 26, 14, 31, 26, 29, 11, 30, 31, 29, 12, 25, 18, 16, 10]],
["ML, is fun.", [24, 23, 6, 11, 20, 30, 11, 17, 32, 25, 7]], ["ML, is fun.", [24, 23, 6, 11, 20, 30, 11, 17, 32, 25, 7]],
["I love torchaudio!", [20, 11, 23, 26, 33, 16, 11, 31, 26, 29, 14, 19, 12, 32, 15, 20, 26, 2]], ["I love torchaudio!", [20, 11, 23, 26, 33, 16, 11, 31, 26, 29, 14, 19, 12, 32, 15, 20, 26, 2]],
# 'one thousand dollars, twenty cents' # 'one thousand dollars, twenty cents'
["$1,000.20", [26, 25, 16, 11, 31, 19, 26, 32, 30, 12, 25, 15, 11, 15, 26, 23, 23, [
12, 29, 30, 6, 11, 31, 34, 16, 25, 31, 36, 11, 14, 16, 25, 31, 30]], "$1,000.20",
[
26,
25,
16,
11,
31,
19,
26,
32,
30,
12,
25,
15,
11,
15,
26,
23,
23,
12,
29,
30,
6,
11,
31,
34,
16,
25,
31,
36,
11,
14,
16,
25,
31,
30,
],
],
] ]
) )
def test_text_to_sequence(self, sent, seq): def test_text_to_sequence(self, sent, seq):
assert (text_to_sequence(sent) == seq) assert text_to_sequence(sent) == seq
@parameterized.expand( @parameterized.expand(
[ [
...@@ -40,7 +75,7 @@ class TestTextPreprocessor(TorchaudioTestCase): ...@@ -40,7 +75,7 @@ class TestTextPreprocessor(TorchaudioTestCase):
) )
def test_remove_commas(self, sent, truth): def test_remove_commas(self, sent, truth):
assert (_remove_commas(sent) == truth) assert _remove_commas(sent) == truth
@parameterized.expand( @parameterized.expand(
[ [
...@@ -49,19 +84,21 @@ class TestTextPreprocessor(TorchaudioTestCase): ...@@ -49,19 +84,21 @@ class TestTextPreprocessor(TorchaudioTestCase):
) )
def test_expand_pounds(self, sent, truth): def test_expand_pounds(self, sent, truth):
assert (_expand_pounds(sent) == truth) assert _expand_pounds(sent) == truth
@parameterized.expand( @parameterized.expand(
[ [
["He, she, and I have $1000", "He, she, and I have 1000 dollars"], ["He, she, and I have $1000", "He, she, and I have 1000 dollars"],
["He, she, and I have $3000.01", "He, she, and I have 3000 dollars, 1 cent"], ["He, she, and I have $3000.01", "He, she, and I have 3000 dollars, 1 cent"],
["He has $500.20 and she has $1000.50.", [
"He has 500 dollars, 20 cents and she has 1000 dollars, 50 cents."], "He has $500.20 and she has $1000.50.",
"He has 500 dollars, 20 cents and she has 1000 dollars, 50 cents.",
],
] ]
) )
def test_expand_dollars(self, sent, truth): def test_expand_dollars(self, sent, truth):
assert (_expand_dollars(sent) == truth) assert _expand_dollars(sent) == truth
@parameterized.expand( @parameterized.expand(
[ [
...@@ -71,7 +108,7 @@ class TestTextPreprocessor(TorchaudioTestCase): ...@@ -71,7 +108,7 @@ class TestTextPreprocessor(TorchaudioTestCase):
) )
def test_expand_decimal_point(self, sent, truth): def test_expand_decimal_point(self, sent, truth):
assert (_expand_decimal_point(sent) == truth) assert _expand_decimal_point(sent) == truth
@parameterized.expand( @parameterized.expand(
[ [
...@@ -82,16 +119,19 @@ class TestTextPreprocessor(TorchaudioTestCase): ...@@ -82,16 +119,19 @@ class TestTextPreprocessor(TorchaudioTestCase):
) )
def test_expand_ordinal(self, sent, truth): def test_expand_ordinal(self, sent, truth):
assert (_expand_ordinal(sent) == truth) assert _expand_ordinal(sent) == truth
_expand_ordinal, _expand_ordinal,
@parameterized.expand( @parameterized.expand(
[ [
["100020 dollars.", "one hundred thousand twenty dollars."], ["100020 dollars.", "one hundred thousand twenty dollars."],
["1234567890!", "one billion, two hundred thirty-four million, " [
"five hundred sixty-seven thousand, eight hundred ninety!"], "1234567890!",
"one billion, two hundred thirty-four million, "
"five hundred sixty-seven thousand, eight hundred ninety!",
],
] ]
) )
def test_expand_number(self, sent, truth): def test_expand_number(self, sent, truth):
assert (_expand_number(sent) == truth) assert _expand_number(sent) == truth
import torch import torch
from .autograd_impl import Autograd, AutogradFloat32
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from .autograd_impl import Autograd, AutogradFloat32
class TestAutogradLfilterCPU(Autograd, common_utils.PytorchTestCase): class TestAutogradLfilterCPU(Autograd, common_utils.PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device("cpu")
class TestAutogradRNNTCPU(AutogradFloat32, common_utils.PytorchTestCase): class TestAutogradRNNTCPU(AutogradFloat32, common_utils.PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device("cpu")
import torch import torch
from .autograd_impl import Autograd, AutogradFloat32
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from .autograd_impl import Autograd, AutogradFloat32
@common_utils.skipIfNoCuda @common_utils.skipIfNoCuda
class TestAutogradLfilterCUDA(Autograd, common_utils.PytorchTestCase): class TestAutogradLfilterCUDA(Autograd, common_utils.PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cuda') device = torch.device("cuda")
@common_utils.skipIfNoCuda @common_utils.skipIfNoCuda
class TestAutogradRNNTCUDA(AutogradFloat32, common_utils.PytorchTestCase): class TestAutogradRNNTCUDA(AutogradFloat32, common_utils.PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cuda') device = torch.device("cuda")
from typing import Callable, Tuple
from functools import partial from functools import partial
from typing import Callable, Tuple
import torch import torch
import torchaudio.functional as F
from parameterized import parameterized from parameterized import parameterized
from torch import Tensor from torch import Tensor
import torchaudio.functional as F
from torch.autograd import gradcheck, gradgradcheck from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin, TestBaseMixin,
...@@ -14,11 +15,11 @@ from torchaudio_unittest.common_utils import ( ...@@ -14,11 +15,11 @@ from torchaudio_unittest.common_utils import (
class Autograd(TestBaseMixin): class Autograd(TestBaseMixin):
def assert_grad( def assert_grad(
self, self,
transform: Callable[..., Tensor], transform: Callable[..., Tensor],
inputs: Tuple[torch.Tensor], inputs: Tuple[torch.Tensor],
*, *,
enable_all_grad: bool = True, enable_all_grad: bool = True,
): ):
inputs_ = [] inputs_ = []
for i in inputs: for i in inputs:
...@@ -64,19 +65,15 @@ class Autograd(TestBaseMixin): ...@@ -64,19 +65,15 @@ class Autograd(TestBaseMixin):
def test_lfilter_filterbanks(self): def test_lfilter_filterbanks(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
a = torch.tensor([[0.7, 0.2, 0.6], a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
[0.8, 0.2, 0.9]]) b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
b = torch.tensor([[0.4, 0.2, 0.9],
[0.7, 0.2, 0.6]])
self.assert_grad(partial(F.lfilter, batching=False), (x, a, b)) self.assert_grad(partial(F.lfilter, batching=False), (x, a, b))
def test_lfilter_batching(self): def test_lfilter_batching(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([[0.7, 0.2, 0.6], a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
[0.8, 0.2, 0.9]]) b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
b = torch.tensor([[0.4, 0.2, 0.9],
[0.7, 0.2, 0.6]])
self.assert_grad(F.lfilter, (x, a, b)) self.assert_grad(F.lfilter, (x, a, b))
def test_filtfilt_a(self): def test_filtfilt_a(self):
...@@ -105,10 +102,8 @@ class Autograd(TestBaseMixin): ...@@ -105,10 +102,8 @@ class Autograd(TestBaseMixin):
def test_filtfilt_batching(self): def test_filtfilt_batching(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([[0.7, 0.2, 0.6], a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
[0.8, 0.2, 0.9]]) b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
b = torch.tensor([[0.4, 0.2, 0.9],
[0.7, 0.2, 0.6]])
self.assert_grad(F.filtfilt, (x, a, b)) self.assert_grad(F.filtfilt, (x, a, b))
def test_biquad(self): def test_biquad(self):
...@@ -118,10 +113,12 @@ class Autograd(TestBaseMixin): ...@@ -118,10 +113,12 @@ class Autograd(TestBaseMixin):
b = torch.tensor([0.4, 0.2, 0.9]) b = torch.tensor([0.4, 0.2, 0.9])
self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2])) self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))
@parameterized.expand([ @parameterized.expand(
(800, 0.7, True), [
(800, 0.7, False), (800, 0.7, True),
]) (800, 0.7, False),
]
)
def test_band_biquad(self, central_freq, Q, noise): def test_band_biquad(self, central_freq, Q, noise):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
...@@ -130,10 +127,12 @@ class Autograd(TestBaseMixin): ...@@ -130,10 +127,12 @@ class Autograd(TestBaseMixin):
Q = torch.tensor(Q) Q = torch.tensor(Q)
self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise)) self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))
@parameterized.expand([ @parameterized.expand(
(800, 0.7, 10), [
(800, 0.7, -10), (800, 0.7, 10),
]) (800, 0.7, -10),
]
)
def test_bass_biquad(self, central_freq, Q, gain): def test_bass_biquad(self, central_freq, Q, gain):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
...@@ -143,11 +142,12 @@ class Autograd(TestBaseMixin): ...@@ -143,11 +142,12 @@ class Autograd(TestBaseMixin):
gain = torch.tensor(gain) gain = torch.tensor(gain)
self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q)) self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))
@parameterized.expand([ @parameterized.expand(
(3000, 0.7, 10), [
(3000, 0.7, -10), (3000, 0.7, 10),
(3000, 0.7, -10),
]) ]
)
def test_treble_biquad(self, central_freq, Q, gain): def test_treble_biquad(self, central_freq, Q, gain):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
...@@ -157,9 +157,14 @@ class Autograd(TestBaseMixin): ...@@ -157,9 +157,14 @@ class Autograd(TestBaseMixin):
gain = torch.tensor(gain) gain = torch.tensor(gain)
self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q)) self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))
@parameterized.expand([ @parameterized.expand(
(800, 0.7, ), [
]) (
800,
0.7,
),
]
)
def test_allpass_biquad(self, central_freq, Q): def test_allpass_biquad(self, central_freq, Q):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
...@@ -168,9 +173,14 @@ class Autograd(TestBaseMixin): ...@@ -168,9 +173,14 @@ class Autograd(TestBaseMixin):
Q = torch.tensor(Q) Q = torch.tensor(Q)
self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q)) self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))
@parameterized.expand([ @parameterized.expand(
(800, 0.7, ), [
]) (
800,
0.7,
),
]
)
def test_lowpass_biquad(self, cutoff_freq, Q): def test_lowpass_biquad(self, cutoff_freq, Q):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
...@@ -179,9 +189,14 @@ class Autograd(TestBaseMixin): ...@@ -179,9 +189,14 @@ class Autograd(TestBaseMixin):
Q = torch.tensor(Q) Q = torch.tensor(Q)
self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q)) self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))
@parameterized.expand([ @parameterized.expand(
(800, 0.7, ), [
]) (
800,
0.7,
),
]
)
def test_highpass_biquad(self, cutoff_freq, Q): def test_highpass_biquad(self, cutoff_freq, Q):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
...@@ -190,10 +205,12 @@ class Autograd(TestBaseMixin): ...@@ -190,10 +205,12 @@ class Autograd(TestBaseMixin):
Q = torch.tensor(Q) Q = torch.tensor(Q)
self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q)) self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))
@parameterized.expand([ @parameterized.expand(
(800, 0.7, True), [
(800, 0.7, False), (800, 0.7, True),
]) (800, 0.7, False),
]
)
def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain): def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
...@@ -202,10 +219,12 @@ class Autograd(TestBaseMixin): ...@@ -202,10 +219,12 @@ class Autograd(TestBaseMixin):
Q = torch.tensor(Q) Q = torch.tensor(Q)
self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain)) self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))
@parameterized.expand([ @parameterized.expand(
(800, 0.7, 10), [
(800, 0.7, -10), (800, 0.7, 10),
]) (800, 0.7, -10),
]
)
def test_equalizer_biquad(self, central_freq, Q, gain): def test_equalizer_biquad(self, central_freq, Q, gain):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
...@@ -215,9 +234,14 @@ class Autograd(TestBaseMixin): ...@@ -215,9 +234,14 @@ class Autograd(TestBaseMixin):
gain = torch.tensor(gain) gain = torch.tensor(gain)
self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q)) self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))
@parameterized.expand([ @parameterized.expand(
(800, 0.7, ), [
]) (
800,
0.7,
),
]
)
def test_bandreject_biquad(self, central_freq, Q): def test_bandreject_biquad(self, central_freq, Q):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
...@@ -229,10 +253,10 @@ class Autograd(TestBaseMixin): ...@@ -229,10 +253,10 @@ class Autograd(TestBaseMixin):
class AutogradFloat32(TestBaseMixin): class AutogradFloat32(TestBaseMixin):
def assert_grad( def assert_grad(
self, self,
transform: Callable[..., Tensor], transform: Callable[..., Tensor],
inputs: Tuple[torch.Tensor], inputs: Tuple[torch.Tensor],
enable_all_grad: bool = True, enable_all_grad: bool = True,
): ):
inputs_ = [] inputs_ = []
for i in inputs: for i in inputs:
...@@ -242,13 +266,15 @@ class AutogradFloat32(TestBaseMixin): ...@@ -242,13 +266,15 @@ class AutogradFloat32(TestBaseMixin):
i.requires_grad = True i.requires_grad = True
inputs_.append(i) inputs_.append(i)
# gradcheck with float32 requires higher atol and epsilon # gradcheck with float32 requires higher atol and epsilon
assert gradcheck(transform, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.) assert gradcheck(transform, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.0)
@parameterized.expand([ @parameterized.expand(
(rnnt_utils.get_B1_T10_U3_D4_data, ), [
(rnnt_utils.get_B2_T4_U3_D3_data, ), (rnnt_utils.get_B1_T10_U3_D4_data,),
(rnnt_utils.get_B1_T2_U3_D5_data, ), (rnnt_utils.get_B2_T4_U3_D3_data,),
]) (rnnt_utils.get_B1_T2_U3_D5_data,),
]
)
def test_rnnt_loss(self, data_func): def test_rnnt_loss(self, data_func):
def get_data(data_func, device): def get_data(data_func, device):
data = data_func() data = data_func()
...@@ -259,11 +285,11 @@ class AutogradFloat32(TestBaseMixin): ...@@ -259,11 +285,11 @@ class AutogradFloat32(TestBaseMixin):
data = get_data(data_func, self.device) data = get_data(data_func, self.device)
inputs = ( inputs = (
data["logits"].to(torch.float32), # logits data["logits"].to(torch.float32), # logits
data["targets"], # targets data["targets"], # targets
data["logit_lengths"], # logit_lengths data["logit_lengths"], # logit_lengths
data["target_lengths"], # target_lengths data["target_lengths"], # target_lengths
data["blank"], # blank data["blank"], # blank
-1, # clamp -1, # clamp
) )
self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False) self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False)
...@@ -2,41 +2,37 @@ ...@@ -2,41 +2,37 @@
import itertools import itertools
import math import math
from parameterized import parameterized, parameterized_class
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
from parameterized import parameterized, parameterized_class
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
def _name_from_args(func, _, params): def _name_from_args(func, _, params):
"""Return a parameterized test name, based on parameter values.""" """Return a parameterized test name, based on parameter values."""
return "{}_{}".format( return "{}_{}".format(func.__name__, "_".join(str(arg) for arg in params.args))
func.__name__,
"_".join(str(arg) for arg in params.args))
@parameterized_class([ @parameterized_class(
# Single-item batch isolates problems that come purely from adding a [
# dimension (rather than processing multiple items) # Single-item batch isolates problems that come purely from adding a
{"batch_size": 1}, # dimension (rather than processing multiple items)
{"batch_size": 3}, {"batch_size": 1},
]) {"batch_size": 3},
]
)
class TestFunctional(common_utils.TorchaudioTestCase): class TestFunctional(common_utils.TorchaudioTestCase):
"""Test functions defined in `functional` module""" """Test functions defined in `functional` module"""
backend = 'default'
def assert_batch_consistency( backend = "default"
self, functional, batch, *args, atol=1e-8, rtol=1e-5, seed=42,
**kwargs): def assert_batch_consistency(self, functional, batch, *args, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
n = batch.size(0) n = batch.size(0)
# Compute items separately, then batch the result # Compute items separately, then batch the result
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
items_input = batch.clone() items_input = batch.clone()
items_result = torch.stack([ items_result = torch.stack([functional(items_input[i], *args, **kwargs) for i in range(n)])
functional(items_input[i], *args, **kwargs) for i in range(n)
])
# Batch the input and run # Batch the input and run
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
...@@ -58,43 +54,45 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -58,43 +54,45 @@ class TestFunctional(common_utils.TorchaudioTestCase):
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch = torch.rand(self.batch_size, 1, 201, 6) batch = torch.rand(self.batch_size, 1, 201, 6)
self.assert_batch_consistency( self.assert_batch_consistency(
F.griffinlim, batch, window, n_fft, hop, ws, power, F.griffinlim, batch, window, n_fft, hop, ws, power, n_iter, momentum, length, 0, atol=5e-5
n_iter, momentum, length, 0, atol=5e-5) )
@parameterized.expand(list(itertools.product( @parameterized.expand(
[8000, 16000, 44100], list(
[1, 2], itertools.product(
)), name_func=_name_from_args) [8000, 16000, 44100],
[1, 2],
)
),
name_func=_name_from_args,
)
def test_detect_pitch_frequency(self, sample_rate, n_channels): def test_detect_pitch_frequency(self, sample_rate, n_channels):
# Use different frequencies to ensure each item in the batch returns a # Use different frequencies to ensure each item in the batch returns a
# different answer. # different answer.
torch.manual_seed(0) torch.manual_seed(0)
frequencies = torch.randint(100, 1000, [self.batch_size]) frequencies = torch.randint(100, 1000, [self.batch_size])
waveforms = torch.stack([ waveforms = torch.stack(
common_utils.get_sinusoid( [
frequency=frequency, sample_rate=sample_rate, common_utils.get_sinusoid(
n_channels=n_channels, duration=5) frequency=frequency, sample_rate=sample_rate, n_channels=n_channels, duration=5
for frequency in frequencies )
]) for frequency in frequencies
self.assert_batch_consistency( ]
F.detect_pitch_frequency, waveforms, sample_rate) )
self.assert_batch_consistency(F.detect_pitch_frequency, waveforms, sample_rate)
def test_amplitude_to_DB(self): def test_amplitude_to_DB(self):
torch.manual_seed(0) torch.manual_seed(0)
spec = torch.rand(self.batch_size, 2, 100, 100) * 200 spec = torch.rand(self.batch_size, 2, 100, 100) * 200
amplitude_mult = 20. amplitude_mult = 20.0
amin = 1e-10 amin = 1e-10
ref = 1.0 ref = 1.0
db_mult = math.log10(max(amin, ref)) db_mult = math.log10(max(amin, ref))
# Test with & without a `top_db` clamp # Test with & without a `top_db` clamp
self.assert_batch_consistency( self.assert_batch_consistency(F.amplitude_to_DB, spec, amplitude_mult, amin, db_mult, top_db=None)
F.amplitude_to_DB, spec, amplitude_mult, self.assert_batch_consistency(F.amplitude_to_DB, spec, amplitude_mult, amin, db_mult, top_db=40.0)
amin, db_mult, top_db=None)
self.assert_batch_consistency(
F.amplitude_to_DB, spec, amplitude_mult,
amin, db_mult, top_db=40.)
def test_amplitude_to_DB_itemwise_clamps(self): def test_amplitude_to_DB_itemwise_clamps(self):
"""Ensure that the clamps are separate for each spectrogram in a batch. """Ensure that the clamps are separate for each spectrogram in a batch.
...@@ -106,11 +104,11 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -106,11 +104,11 @@ class TestFunctional(common_utils.TorchaudioTestCase):
https://github.com/pytorch/audio/issues/994 https://github.com/pytorch/audio/issues/994
""" """
amplitude_mult = 20. amplitude_mult = 20.0
amin = 1e-10 amin = 1e-10
ref = 1.0 ref = 1.0
db_mult = math.log10(max(amin, ref)) db_mult = math.log10(max(amin, ref))
top_db = 20. top_db = 20.0
# Make a batch of noise # Make a batch of noise
torch.manual_seed(0) torch.manual_seed(0)
...@@ -118,36 +116,30 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -118,36 +116,30 @@ class TestFunctional(common_utils.TorchaudioTestCase):
# Make one item blow out the other # Make one item blow out the other
spec[0] += 50 spec[0] += 50
batchwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin, batchwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=top_db)
db_mult, top_db=top_db) itemwise_dbs = torch.stack(
itemwise_dbs = torch.stack([ [F.amplitude_to_DB(item, amplitude_mult, amin, db_mult, top_db=top_db) for item in spec]
F.amplitude_to_DB(item, amplitude_mult, amin, )
db_mult, top_db=top_db)
for item in spec
])
self.assertEqual(batchwise_dbs, itemwise_dbs) self.assertEqual(batchwise_dbs, itemwise_dbs)
def test_amplitude_to_DB_not_channelwise_clamps(self): def test_amplitude_to_DB_not_channelwise_clamps(self):
"""Check that clamps are applied per-item, not per channel.""" """Check that clamps are applied per-item, not per channel."""
amplitude_mult = 20. amplitude_mult = 20.0
amin = 1e-10 amin = 1e-10
ref = 1.0 ref = 1.0
db_mult = math.log10(max(amin, ref)) db_mult = math.log10(max(amin, ref))
top_db = 40. top_db = 40.0
torch.manual_seed(0) torch.manual_seed(0)
spec = torch.rand([1, 2, 100, 100]) * 200 spec = torch.rand([1, 2, 100, 100]) * 200
# Make one channel blow out the other # Make one channel blow out the other
spec[:, 0] += 50 spec[:, 0] += 50
specwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin, specwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=top_db)
db_mult, top_db=top_db) channelwise_dbs = torch.stack(
channelwise_dbs = torch.stack([ [F.amplitude_to_DB(spec[:, i], amplitude_mult, amin, db_mult, top_db=top_db) for i in range(spec.size(-3))]
F.amplitude_to_DB(spec[:, i], amplitude_mult, amin, )
db_mult, top_db=top_db)
for i in range(spec.size(-3))
])
# Just check channelwise gives a different answer. # Just check channelwise gives a different answer.
difference = (specwise_dbs - channelwise_dbs).abs() difference = (specwise_dbs - channelwise_dbs).abs()
...@@ -156,27 +148,24 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -156,27 +148,24 @@ class TestFunctional(common_utils.TorchaudioTestCase):
def test_contrast(self): def test_contrast(self):
torch.random.manual_seed(0) torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency( self.assert_batch_consistency(F.contrast, waveforms, enhancement_amount=80.0)
F.contrast, waveforms, enhancement_amount=80.)
def test_dcshift(self): def test_dcshift(self):
torch.random.manual_seed(0) torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency( self.assert_batch_consistency(F.dcshift, waveforms, shift=0.5, limiter_gain=0.05)
F.dcshift, waveforms, shift=0.5, limiter_gain=0.05)
def test_overdrive(self): def test_overdrive(self):
torch.random.manual_seed(0) torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency( self.assert_batch_consistency(F.overdrive, waveforms, gain=45, colour=30)
F.overdrive, waveforms, gain=45, colour=30)
def test_phaser(self): def test_phaser(self):
sample_rate = 44100 sample_rate = 44100
n_channels = 2 n_channels = 2
waveform = common_utils.get_whitenoise( waveform = common_utils.get_whitenoise(
sample_rate=sample_rate, n_channels=self.batch_size * n_channels, sample_rate=sample_rate, n_channels=self.batch_size * n_channels, duration=1
duration=1) )
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1)) batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(F.phaser, batch, sample_rate) self.assert_batch_consistency(F.phaser, batch, sample_rate)
...@@ -186,37 +175,48 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -186,37 +175,48 @@ class TestFunctional(common_utils.TorchaudioTestCase):
sample_rate = 44100 sample_rate = 44100
self.assert_batch_consistency(F.flanger, waveforms, sample_rate) self.assert_batch_consistency(F.flanger, waveforms, sample_rate)
@parameterized.expand(list(itertools.product( @parameterized.expand(
[True, False], # center list(
[True, False], # norm_vars itertools.product(
)), name_func=_name_from_args) [True, False], # center
[True, False], # norm_vars
)
),
name_func=_name_from_args,
)
def test_sliding_window_cmn(self, center, norm_vars): def test_sliding_window_cmn(self, center, norm_vars):
torch.manual_seed(0) torch.manual_seed(0)
spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200 spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200
self.assert_batch_consistency( self.assert_batch_consistency(F.sliding_window_cmn, spectrogram, center=center, norm_vars=norm_vars)
F.sliding_window_cmn, spectrogram, center=center,
norm_vars=norm_vars)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform(self, resampling_method): def test_resample_waveform(self, resampling_method):
num_channels = 3 num_channels = 3
sr = 16000 sr = 16000
new_sr = sr // 2 new_sr = sr // 2
multi_sound = common_utils.get_whitenoise(sample_rate=sr, n_channels=num_channels, duration=0.5,) multi_sound = common_utils.get_whitenoise(
sample_rate=sr,
n_channels=num_channels,
duration=0.5,
)
self.assert_batch_consistency( self.assert_batch_consistency(
F.resample, multi_sound, orig_freq=sr, new_freq=new_sr, F.resample,
resampling_method=resampling_method, rtol=1e-4, atol=1e-7) multi_sound,
orig_freq=sr,
new_freq=new_sr,
resampling_method=resampling_method,
rtol=1e-4,
atol=1e-7,
)
@common_utils.skipIfNoKaldi @common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self): def test_compute_kaldi_pitch(self):
sample_rate = 44100 sample_rate = 44100
n_channels = 2 n_channels = 2
waveform = common_utils.get_whitenoise( waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=self.batch_size * n_channels)
sample_rate=sample_rate, n_channels=self.batch_size * n_channels)
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1)) batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency( self.assert_batch_consistency(F.compute_kaldi_pitch, batch, sample_rate=sample_rate)
F.compute_kaldi_pitch, batch, sample_rate=sample_rate)
def test_lfilter(self): def test_lfilter(self):
signal_length = 2048 signal_length = 2048
...@@ -226,10 +226,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -226,10 +226,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
b = torch.rand(self.batch_size, 3) b = torch.rand(self.batch_size, 3)
batchwise_output = F.lfilter(x, a, b, batching=True) batchwise_output = F.lfilter(x, a, b, batching=True)
itemwise_output = torch.stack([ itemwise_output = torch.stack([F.lfilter(x[i], a[i], b[i]) for i in range(self.batch_size)])
F.lfilter(x[i], a[i], b[i])
for i in range(self.batch_size)
])
self.assertEqual(batchwise_output, itemwise_output) self.assertEqual(batchwise_output, itemwise_output)
...@@ -241,9 +238,6 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -241,9 +238,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
b = torch.rand(self.batch_size, 3) b = torch.rand(self.batch_size, 3)
batchwise_output = F.filtfilt(x, a, b) batchwise_output = F.filtfilt(x, a, b)
itemwise_output = torch.stack([ itemwise_output = torch.stack([F.filtfilt(x[i], a[i], b[i]) for i in range(self.batch_size)])
F.filtfilt(x[i], a[i], b[i])
for i in range(self.batch_size)
])
self.assertEqual(batchwise_output, itemwise_output) self.assertEqual(batchwise_output, itemwise_output)
import unittest
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
import unittest
from parameterized import parameterized from parameterized import parameterized
from torchaudio_unittest.common_utils import PytorchTestCase, TorchaudioTestCase, skipIfNoSox from torchaudio_unittest.common_utils import PytorchTestCase, TorchaudioTestCase, skipIfNoSox
from .functional_impl import Functional, FunctionalCPUOnly from .functional_impl import Functional, FunctionalCPUOnly
class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase): class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device("cpu")
@unittest.expectedFailure @unittest.expectedFailure
def test_lfilter_9th_order_filter_stability(self): def test_lfilter_9th_order_filter_stability(self):
...@@ -18,7 +19,7 @@ class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase): ...@@ -18,7 +19,7 @@ class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase): class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device("cpu")
@skipIfNoSox @skipIfNoSox
...@@ -36,12 +37,7 @@ class TestApplyCodec(TorchaudioTestCase): ...@@ -36,12 +37,7 @@ class TestApplyCodec(TorchaudioTestCase):
num_channels = 2 num_channels = 2
waveform = torch.rand(num_channels, num_frames) waveform = torch.rand(num_channels, num_frames)
augmented = F.apply_codec(waveform, augmented = F.apply_codec(waveform, sample_rate, format, True, compression)
sample_rate,
format,
True,
compression
)
assert augmented.dtype == waveform.dtype assert augmented.dtype == waveform.dtype
assert augmented.shape[0] == num_channels assert augmented.shape[0] == num_channels
if check_num_frames: if check_num_frames:
......
import torch
import unittest import unittest
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .functional_impl import Functional from .functional_impl import Functional
@skipIfNoCuda @skipIfNoCuda
class TestFunctionalFloat32(Functional, PytorchTestCase): class TestFunctionalFloat32(Functional, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cuda') device = torch.device("cuda")
@unittest.expectedFailure @unittest.expectedFailure
def test_lfilter_9th_order_filter_stability(self): def test_lfilter_9th_order_filter_stability(self):
...@@ -18,4 +19,4 @@ class TestFunctionalFloat32(Functional, PytorchTestCase): ...@@ -18,4 +19,4 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
@skipIfNoCuda @skipIfNoCuda
class TestLFilterFloat64(Functional, PytorchTestCase): class TestLFilterFloat64(Functional, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cuda') device = torch.device("cuda")
"""Test definition common to CPU and CUDA""" """Test definition common to CPU and CUDA"""
import math
import itertools import itertools
import math
import warnings import warnings
import numpy as np import numpy as np
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
import torchaudio.functional as F import torchaudio.functional as F
from parameterized import parameterized from parameterized import parameterized
from scipy import signal from scipy import signal
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin, TestBaseMixin,
get_sinusoid, get_sinusoid,
...@@ -19,8 +18,9 @@ from torchaudio_unittest.common_utils import ( ...@@ -19,8 +18,9 @@ from torchaudio_unittest.common_utils import (
class Functional(TestBaseMixin): class Functional(TestBaseMixin):
def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None, def _test_resample_waveform_accuracy(
resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4): self, up_scale_factor=None, down_scale_factor=None, resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4
):
# resample the signal and compare it to the ground truth # resample the signal and compare it to the ground truth
n_to_trim = 20 n_to_trim = 20
sample_rate = 1000 sample_rate = 1000
...@@ -36,10 +36,9 @@ class Functional(TestBaseMixin): ...@@ -36,10 +36,9 @@ class Functional(TestBaseMixin):
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate) original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)
sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0) sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
estimate = F.resample(sound, sample_rate, new_sample_rate, estimate = F.resample(sound, sample_rate, new_sample_rate, resampling_method=resampling_method).squeeze()
resampling_method=resampling_method).squeeze()
new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)] new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[: estimate.size(0)]
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps) ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
# trim the first/last n samples as these points have boundary effects # trim the first/last n samples as these points have boundary effects
...@@ -48,9 +47,7 @@ class Functional(TestBaseMixin): ...@@ -48,9 +47,7 @@ class Functional(TestBaseMixin):
self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol) self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)
def _test_costs_and_gradients( def _test_costs_and_gradients(self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2):
self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
):
logits_shape = data["logits"].shape logits_shape = data["logits"].shape
costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data) costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data)
self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol) self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
...@@ -81,15 +78,41 @@ class Functional(TestBaseMixin): ...@@ -81,15 +78,41 @@ class Functional(TestBaseMixin):
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False) output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False)
assert output_signal.max() > 1 assert output_signal.max() > 1
@parameterized.expand([ @parameterized.expand(
((44100,), (4,), (44100,)), [
((3, 44100), (4,), (3, 44100,)), ((44100,), (4,), (44100,)),
((2, 3, 44100), (4,), (2, 3, 44100,)), (
((1, 2, 3, 44100), (4,), (1, 2, 3, 44100,)), (3, 44100),
((44100,), (2, 4), (2, 44100)), (4,),
((3, 44100), (1, 4), (3, 1, 44100)), (
((1, 2, 44100), (3, 4), (1, 2, 3, 44100)) 3,
]) 44100,
),
),
(
(2, 3, 44100),
(4,),
(
2,
3,
44100,
),
),
(
(1, 2, 3, 44100),
(4,),
(
1,
2,
3,
44100,
),
),
((44100,), (2, 4), (2, 44100)),
((3, 44100), (1, 4), (3, 1, 44100)),
((1, 2, 44100), (3, 4), (1, 2, 3, 44100)),
]
)
def test_lfilter_shape(self, input_shape, coeff_shape, target_shape): def test_lfilter_shape(self, input_shape, coeff_shape, target_shape):
torch.random.manual_seed(42) torch.random.manual_seed(42)
waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device) waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
...@@ -109,13 +132,12 @@ class Functional(TestBaseMixin): ...@@ -109,13 +132,12 @@ class Functional(TestBaseMixin):
x[0] = 1 x[0] = 1
# get target impulse response # get target impulse response
sos = signal.butter(9, 850, 'hp', fs=22050, output='sos') sos = signal.butter(9, 850, "hp", fs=22050, output="sos")
y = torch.from_numpy(signal.sosfilt(sos, x.cpu().numpy())).to(self.dtype).to(self.device) y = torch.from_numpy(signal.sosfilt(sos, x.cpu().numpy())).to(self.dtype).to(self.device)
# get lfilter coefficients # get lfilter coefficients
b, a = signal.butter(9, 850, 'hp', fs=22050, output='ba') b, a = signal.butter(9, 850, "hp", fs=22050, output="ba")
b, a = torch.from_numpy(b).to(self.dtype).to(self.device), torch.from_numpy( b, a = torch.from_numpy(b).to(self.dtype).to(self.device), torch.from_numpy(a).to(self.dtype).to(self.device)
a).to(self.dtype).to(self.device)
# predict impulse response # predict impulse response
yhat = F.lfilter(x, a, b, False) yhat = F.lfilter(x, a, b, False)
...@@ -126,14 +148,10 @@ class Functional(TestBaseMixin): ...@@ -126,14 +148,10 @@ class Functional(TestBaseMixin):
Check that, for an arbitrary signal, applying filtfilt with filter coefficients Check that, for an arbitrary signal, applying filtfilt with filter coefficients
corresponding to a pure delay filter imparts no time delay. corresponding to a pure delay filter imparts no time delay.
""" """
waveform = get_whitenoise(sample_rate=8000, n_channels=2, dtype=self.dtype).to( waveform = get_whitenoise(sample_rate=8000, n_channels=2, dtype=self.dtype).to(device=self.device)
device=self.device
)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device) b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device) a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device)
padded_waveform = torch.cat( padded_waveform = torch.cat((waveform, torch.zeros(2, 3, dtype=self.dtype, device=self.device)), axis=1)
(waveform, torch.zeros(2, 3, dtype=self.dtype, device=self.device)), axis=1
)
output_waveform = F.filtfilt(padded_waveform, a_coeffs, b_coeffs) output_waveform = F.filtfilt(padded_waveform, a_coeffs, b_coeffs)
self.assertEqual(output_waveform, padded_waveform, atol=1e-5, rtol=1e-5) self.assertEqual(output_waveform, padded_waveform, atol=1e-5, rtol=1e-5)
...@@ -147,9 +165,9 @@ class Functional(TestBaseMixin): ...@@ -147,9 +165,9 @@ class Functional(TestBaseMixin):
T = 1.0 T = 1.0
samples = 1000 samples = 1000
waveform_k0 = get_sinusoid( waveform_k0 = get_sinusoid(frequency=5, sample_rate=samples // T, dtype=self.dtype, device=self.device).squeeze(
frequency=5, sample_rate=samples // T, dtype=self.dtype, device=self.device 0
).squeeze(0) )
waveform_k1 = get_sinusoid( waveform_k1 = get_sinusoid(
frequency=200, frequency=200,
sample_rate=samples // T, sample_rate=samples // T,
...@@ -202,13 +220,13 @@ class Functional(TestBaseMixin): ...@@ -202,13 +220,13 @@ class Functional(TestBaseMixin):
# Remove padding from output waveform; confirm that result # Remove padding from output waveform; confirm that result
# closely matches waveform_k0. # closely matches waveform_k0.
self.assertEqual( self.assertEqual(
output_waveform[samples - 1: 2 * samples - 1], output_waveform[samples - 1 : 2 * samples - 1],
waveform_k0, waveform_k0,
atol=1e-3, atol=1e-3,
rtol=1e-3, rtol=1e-3,
) )
@parameterized.expand([(0., ), (1., ), (2., ), (3., )]) @parameterized.expand([(0.0,), (1.0,), (2.0,), (3.0,)])
def test_spectrogram_grad_at_zero(self, power): def test_spectrogram_grad_at_zero(self, power):
"""The gradient of power spectrogram should not be nan but zero near x=0 """The gradient of power spectrogram should not be nan but zero near x=0
...@@ -235,19 +253,15 @@ class Functional(TestBaseMixin): ...@@ -235,19 +253,15 @@ class Functional(TestBaseMixin):
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
def test_compute_deltas_two_channels(self): def test_compute_deltas_two_channels(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0], specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device) expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], [0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
computed = F.compute_deltas(specgram, win_length=3) computed = F.compute_deltas(specgram, win_length=3)
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
@parameterized.expand([(100,), (440,)]) @parameterized.expand([(100,), (440,)])
def test_detect_pitch_frequency_pitch(self, frequency): def test_detect_pitch_frequency_pitch(self, frequency):
sample_rate = 44100 sample_rate = 44100
test_sine_waveform = get_sinusoid( test_sine_waveform = get_sinusoid(frequency=frequency, sample_rate=sample_rate, duration=5)
frequency=frequency, sample_rate=sample_rate, duration=5
)
freq = F.detect_pitch_frequency(test_sine_waveform, sample_rate) freq = F.detect_pitch_frequency(test_sine_waveform, sample_rate)
...@@ -262,8 +276,8 @@ class Functional(TestBaseMixin): ...@@ -262,8 +276,8 @@ class Functional(TestBaseMixin):
This implicitly also tests `DB_to_amplitude`. This implicitly also tests `DB_to_amplitude`.
""" """
amplitude_mult = 20. amplitude_mult = 20.0
power_mult = 10. power_mult = 10.0
amin = 1e-10 amin = 1e-10
ref = 1.0 ref = 1.0
db_mult = math.log10(max(amin, ref)) db_mult = math.log10(max(amin, ref))
...@@ -279,18 +293,18 @@ class Functional(TestBaseMixin): ...@@ -279,18 +293,18 @@ class Functional(TestBaseMixin):
# Spectrogram power -> DB -> power # Spectrogram power -> DB -> power
db = F.amplitude_to_DB(spec, power_mult, amin, db_mult, top_db=None) db = F.amplitude_to_DB(spec, power_mult, amin, db_mult, top_db=None)
x2 = F.DB_to_amplitude(db, ref, 1.) x2 = F.DB_to_amplitude(db, ref, 1.0)
self.assertEqual(x2, spec) self.assertEqual(x2, spec)
@parameterized.expand([([100, 100],), ([2, 100, 100],), ([2, 2, 100, 100],)]) @parameterized.expand([([100, 100],), ([2, 100, 100],), ([2, 2, 100, 100],)])
def test_amplitude_to_DB_top_db_clamp(self, shape): def test_amplitude_to_DB_top_db_clamp(self, shape):
"""Ensure values are properly clamped when `top_db` is supplied.""" """Ensure values are properly clamped when `top_db` is supplied."""
amplitude_mult = 20. amplitude_mult = 20.0
amin = 1e-10 amin = 1e-10
ref = 1.0 ref = 1.0
db_mult = math.log10(max(amin, ref)) db_mult = math.log10(max(amin, ref))
top_db = 40. top_db = 40.0
torch.manual_seed(0) torch.manual_seed(0)
# A random tensor is used for increased entropy, but the max and min for # A random tensor is used for increased entropy, but the max and min for
...@@ -304,24 +318,17 @@ class Functional(TestBaseMixin): ...@@ -304,24 +318,17 @@ class Functional(TestBaseMixin):
# Expand the range to (0, 200) - wide enough to properly test clamping. # Expand the range to (0, 200) - wide enough to properly test clamping.
spec *= 200 spec *= 200
decibels = F.amplitude_to_DB(spec, amplitude_mult, amin, decibels = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=top_db)
db_mult, top_db=top_db)
# Ensure the clamp was applied # Ensure the clamp was applied
below_limit = decibels < 6.0205 below_limit = decibels < 6.0205
assert not below_limit.any(), ( assert not below_limit.any(), "{} decibel values were below the expected cutoff:\n{}".format(
"{} decibel values were below the expected cutoff:\n{}".format( below_limit.sum().item(), decibels
below_limit.sum().item(), decibels
)
) )
# Ensure it didn't over-clamp # Ensure it didn't over-clamp
close_to_limit = decibels < 6.0207 close_to_limit = decibels < 6.0207
assert close_to_limit.any(), ( assert close_to_limit.any(), f"No values were close to the limit. Did it over-clamp?\n{decibels}"
f"No values were close to the limit. Did it over-clamp?\n{decibels}"
)
@parameterized.expand( @parameterized.expand(list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2])))
list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0., 30.], [1, 2]))
)
def test_mask_along_axis(self, shape, mask_param, mask_value, axis): def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
torch.random.manual_seed(42) torch.random.manual_seed(42)
specgram = torch.randn(*shape, dtype=self.dtype, device=self.device) specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
...@@ -331,13 +338,12 @@ class Functional(TestBaseMixin): ...@@ -331,13 +338,12 @@ class Functional(TestBaseMixin):
masked_columns = (mask_specgram == mask_value).sum(other_axis) masked_columns = (mask_specgram == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum() num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
num_masked_columns = torch.div( num_masked_columns = torch.div(num_masked_columns, mask_specgram.size(0), rounding_mode="floor")
num_masked_columns, mask_specgram.size(0), rounding_mode='floor')
assert mask_specgram.size() == specgram.size() assert mask_specgram.size() == specgram.size()
assert num_masked_columns < mask_param assert num_masked_columns < mask_param
@parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3]))) @parameterized.expand(list(itertools.product([100], [0.0, 30.0], [2, 3])))
def test_mask_along_axis_iid(self, mask_param, mask_value, axis): def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
torch.random.manual_seed(42) torch.random.manual_seed(42)
specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device) specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
...@@ -352,9 +358,7 @@ class Functional(TestBaseMixin): ...@@ -352,9 +358,7 @@ class Functional(TestBaseMixin):
assert mask_specgrams.size() == specgrams.size() assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel() assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
@parameterized.expand( @parameterized.expand(list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2])))
list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0., 30.], [1, 2]))
)
def test_mask_along_axis_preserve(self, shape, mask_param, mask_value, axis): def test_mask_along_axis_preserve(self, shape, mask_param, mask_value, axis):
"""mask_along_axis should not alter original input Tensor """mask_along_axis should not alter original input Tensor
...@@ -369,7 +373,7 @@ class Functional(TestBaseMixin): ...@@ -369,7 +373,7 @@ class Functional(TestBaseMixin):
self.assertEqual(specgram, specgram_copy) self.assertEqual(specgram, specgram_copy)
@parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3]))) @parameterized.expand(list(itertools.product([100], [0.0, 30.0], [2, 3])))
def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis): def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis):
"""mask_along_axis_iid should not alter original input Tensor """mask_along_axis_iid should not alter original input Tensor
...@@ -384,10 +388,14 @@ class Functional(TestBaseMixin): ...@@ -384,10 +388,14 @@ class Functional(TestBaseMixin):
self.assertEqual(specgrams, specgrams_copy) self.assertEqual(specgrams, specgrams_copy)
@parameterized.expand(list(itertools.product( @parameterized.expand(
["sinc_interpolation", "kaiser_window"], list(
[16000, 44100], itertools.product(
))) ["sinc_interpolation", "kaiser_window"],
[16000, 44100],
)
)
)
def test_resample_identity(self, resampling_method, sample_rate): def test_resample_identity(self, resampling_method, sample_rate):
waveform = get_whitenoise(sample_rate=sample_rate, duration=1) waveform = get_whitenoise(sample_rate=sample_rate, duration=1)
...@@ -397,35 +405,52 @@ class Functional(TestBaseMixin): ...@@ -397,35 +405,52 @@ class Functional(TestBaseMixin):
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_size(self, resampling_method): def test_resample_waveform_upsample_size(self, resampling_method):
sr = 16000 sr = 16000
waveform = get_whitenoise(sample_rate=sr, duration=0.5,) waveform = get_whitenoise(
sample_rate=sr,
duration=0.5,
)
upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method) upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method)
assert upsampled.size(-1) == waveform.size(-1) * 2 assert upsampled.size(-1) == waveform.size(-1) * 2
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_size(self, resampling_method): def test_resample_waveform_downsample_size(self, resampling_method):
sr = 16000 sr = 16000
waveform = get_whitenoise(sample_rate=sr, duration=0.5,) waveform = get_whitenoise(
sample_rate=sr,
duration=0.5,
)
downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method) downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method)
assert downsampled.size(-1) == waveform.size(-1) // 2 assert downsampled.size(-1) == waveform.size(-1) // 2
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_identity_size(self, resampling_method): def test_resample_waveform_identity_size(self, resampling_method):
sr = 16000 sr = 16000
waveform = get_whitenoise(sample_rate=sr, duration=0.5,) waveform = get_whitenoise(
sample_rate=sr,
duration=0.5,
)
resampled = F.resample(waveform, sr, sr, resampling_method=resampling_method) resampled = F.resample(waveform, sr, sr, resampling_method=resampling_method)
assert resampled.size(-1) == waveform.size(-1) assert resampled.size(-1) == waveform.size(-1)
@parameterized.expand(list(itertools.product( @parameterized.expand(
["sinc_interpolation", "kaiser_window"], list(
list(range(1, 20)), itertools.product(
))) ["sinc_interpolation", "kaiser_window"],
list(range(1, 20)),
)
)
)
def test_resample_waveform_downsample_accuracy(self, resampling_method, i): def test_resample_waveform_downsample_accuracy(self, resampling_method, i):
self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method) self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method)
@parameterized.expand(list(itertools.product( @parameterized.expand(
["sinc_interpolation", "kaiser_window"], list(
list(range(1, 20)), itertools.product(
))) ["sinc_interpolation", "kaiser_window"],
list(range(1, 20)),
)
)
)
def test_resample_waveform_upsample_accuracy(self, resampling_method, i): def test_resample_waveform_upsample_accuracy(self, resampling_method, i):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method) self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)
...@@ -438,14 +463,9 @@ class Functional(TestBaseMixin): ...@@ -438,14 +463,9 @@ class Functional(TestBaseMixin):
batch_size = 2 batch_size = 2
torch.random.manual_seed(42) torch.random.manual_seed(42)
spec = torch.randn( spec = torch.randn(batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
phase_advance = torch.linspace( phase_advance = torch.linspace(0, np.pi * hop_length, num_freq, dtype=self.dtype, device=self.device)[..., None]
0,
np.pi * hop_length,
num_freq,
dtype=self.dtype, device=self.device)[..., None]
spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance) spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
...@@ -460,32 +480,31 @@ class Functional(TestBaseMixin): ...@@ -460,32 +480,31 @@ class Functional(TestBaseMixin):
["", "", 0], # equal ["", "", 0], # equal
["abc", "abc", 0], ["abc", "abc", 0],
["ᑌᑎIᑕO", "ᑌᑎIᑕO", 0], ["ᑌᑎIᑕO", "ᑌᑎIᑕO", 0],
["abc", "", 3], # deletion ["abc", "", 3], # deletion
["aa", "aaa", 1], ["aa", "aaa", 1],
["aaa", "aa", 1], ["aaa", "aa", 1],
["ᑌᑎI", "ᑌᑎIᑕO", 2], ["ᑌᑎI", "ᑌᑎIᑕO", 2],
["aaa", "aba", 1], # substitution ["aaa", "aba", 1], # substitution
["aba", "aaa", 1], ["aba", "aaa", 1],
["aba", " ", 3], ["aba", " ", 3],
["abc", "bcd", 2], # mix deletion and substitution ["abc", "bcd", 2], # mix deletion and substitution
["0ᑌᑎI", "ᑌᑎIᑕO", 3], ["0ᑌᑎI", "ᑌᑎIᑕO", 3],
# sentences # sentences
[["hello", "", "Tᕮ᙭T"], ["hello", "", "Tᕮ᙭T"], 0], # equal [["hello", "", "Tᕮ᙭T"], ["hello", "", "Tᕮ᙭T"], 0], # equal
[[], [], 0], [[], [], 0],
[["hello", "world"], ["hello", "world", "!"], 1], # deletion [["hello", "world"], ["hello", "world", "!"], 1], # deletion
[["hello", "world"], ["world"], 1], [["hello", "world"], ["world"], 1],
[["hello", "world"], [], 2], [["hello", "world"], [], 2],
[
[["Tᕮ᙭T", ], ["world"], 1], # substitution [
"Tᕮ᙭T",
],
["world"],
1,
], # substitution
[["Tᕮ᙭T", "XD"], ["world", "hello"], 2], [["Tᕮ᙭T", "XD"], ["world", "hello"], 2],
[["", "XD"], ["world", ""], 2], [["", "XD"], ["world", ""], 2],
["aba", " ", 3], ["aba", " ", 3],
[["hello", "world"], ["world", "hello", "!"], 2], # mix deletion and substitution [["hello", "world"], ["world", "hello", "!"], 2], # mix deletion and substitution
[["Tᕮ᙭T", "world", "LOL", "XD"], ["world", "hello", "ʕ•́ᴥ•̀ʔっ"], 3], [["Tᕮ᙭T", "world", "LOL", "XD"], ["world", "hello", "ʕ•́ᴥ•̀ʔっ"], 3],
] ]
...@@ -520,12 +539,14 @@ class Functional(TestBaseMixin): ...@@ -520,12 +539,14 @@ class Functional(TestBaseMixin):
logits.requires_grad_(False) logits.requires_grad_(False)
F.rnnt_loss(logits, targets, logit_lengths, target_lengths) F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
@parameterized.expand([ @parameterized.expand(
(rnnt_utils.get_B1_T2_U3_D5_data, torch.float32, 1e-6, 1e-2), [
(rnnt_utils.get_B2_T4_U3_D3_data, torch.float32, 1e-6, 1e-2), (rnnt_utils.get_B1_T2_U3_D5_data, torch.float32, 1e-6, 1e-2),
(rnnt_utils.get_B1_T2_U3_D5_data, torch.float16, 1e-3, 1e-2), (rnnt_utils.get_B2_T4_U3_D3_data, torch.float32, 1e-6, 1e-2),
(rnnt_utils.get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2), (rnnt_utils.get_B1_T2_U3_D5_data, torch.float16, 1e-3, 1e-2),
]) (rnnt_utils.get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2),
]
)
def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol): def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
data, ref_costs, ref_gradients = data_func( data, ref_costs, ref_gradients = data_func(
dtype=dtype, dtype=dtype,
...@@ -544,9 +565,7 @@ class Functional(TestBaseMixin): ...@@ -544,9 +565,7 @@ class Functional(TestBaseMixin):
for i in range(5): for i in range(5):
data = rnnt_utils.get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i)) data = rnnt_utils.get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data) ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients( self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)
class FunctionalCPUOnly(TestBaseMixin): class FunctionalCPUOnly(TestBaseMixin):
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .kaldi_compatibility_test_impl import Kaldi, KaldiCPUOnly from .kaldi_compatibility_test_impl import Kaldi, KaldiCPUOnly
class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase): class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device("cpu")
class TestKaldiFloat32(Kaldi, PytorchTestCase): class TestKaldiFloat32(Kaldi, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device("cpu")
class TestKaldiFloat64(Kaldi, PytorchTestCase): class TestKaldiFloat64(Kaldi, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device("cpu")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment