Commit 5859923a authored by Joao Gomes's avatar Joao Gomes Committed by Facebook GitHub Bot
Browse files

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .kaldi_compatibility_test_impl import Kaldi from .kaldi_compatibility_test_impl import Kaldi
@skipIfNoCuda @skipIfNoCuda
class TestKaldiFloat32(Kaldi, PytorchTestCase): class TestKaldiFloat32(Kaldi, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cuda') device = torch.device("cuda")
@skipIfNoCuda @skipIfNoCuda
class TestKaldiFloat64(Kaldi, PytorchTestCase): class TestKaldiFloat64(Kaldi, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cuda') device = torch.device("cuda")
from parameterized import parameterized
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
get_sinusoid, get_sinusoid,
load_params, load_params,
...@@ -21,20 +20,20 @@ class Kaldi(TempDirMixin, TestBaseMixin): ...@@ -21,20 +20,20 @@ class Kaldi(TempDirMixin, TestBaseMixin):
expected = expected.to(dtype=self.dtype, device=self.device) expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol) self.assertEqual(output, expected, rtol=rtol, atol=atol)
@skipIfNoExec('apply-cmvn-sliding') @skipIfNoExec("apply-cmvn-sliding")
def test_sliding_window_cmn(self): def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding""" """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = { kwargs = {
'cmn_window': 600, "cmn_window": 600,
'min_cmn_window': 100, "min_cmn_window": 100,
'center': False, "center": False,
'norm_vars': False, "norm_vars": False,
} }
tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device) tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs) result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-'] command = ["apply-cmvn-sliding"] + convert_args(**kwargs) + ["ark:-", "ark:-"]
kaldi_result = run_kaldi(command, 'ark', tensor) kaldi_result = run_kaldi(command, "ark", tensor)
self.assert_equal(result, expected=kaldi_result) self.assert_equal(result, expected=kaldi_result)
...@@ -43,18 +42,18 @@ class KaldiCPUOnly(TempDirMixin, TestBaseMixin): ...@@ -43,18 +42,18 @@ class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
expected = expected.to(dtype=self.dtype, device=self.device) expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol) self.assertEqual(output, expected, rtol=rtol, atol=atol)
@parameterized.expand(load_params('kaldi_test_pitch_args.jsonl')) @parameterized.expand(load_params("kaldi_test_pitch_args.jsonl"))
@skipIfNoExec('compute-kaldi-pitch-feats') @skipIfNoExec("compute-kaldi-pitch-feats")
def test_pitch_feats(self, kwargs): def test_pitch_feats(self, kwargs):
"""compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats""" """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
sample_rate = kwargs['sample_rate'] sample_rate = kwargs["sample_rate"]
waveform = get_sinusoid(dtype='float32', sample_rate=sample_rate) waveform = get_sinusoid(dtype="float32", sample_rate=sample_rate)
result = F.compute_kaldi_pitch(waveform[0], **kwargs) result = F.compute_kaldi_pitch(waveform[0], **kwargs)
waveform = get_sinusoid(dtype='int16', sample_rate=sample_rate) waveform = get_sinusoid(dtype="int16", sample_rate=sample_rate)
wave_file = self.get_temp_path('test.wav') wave_file = self.get_temp_path("test.wav")
save_wav(wave_file, waveform, sample_rate) save_wav(wave_file, waveform, sample_rate)
command = ['compute-kaldi-pitch-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ["compute-kaldi-pitch-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
kaldi_result = run_kaldi(command, 'scp', wave_file) kaldi_result = run_kaldi(command, "scp", wave_file)
self.assert_equal(result, expected=kaldi_result) self.assert_equal(result, expected=kaldi_result)
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .librosa_compatibility_test_impl import Functional, FunctionalComplex from .librosa_compatibility_test_impl import Functional, FunctionalComplex
class TestFunctionalCPU(Functional, PytorchTestCase): class TestFunctionalCPU(Functional, PytorchTestCase):
device = 'cpu' device = "cpu"
class TestFunctionalComplexCPU(FunctionalComplex, PytorchTestCase): class TestFunctionalComplexCPU(FunctionalComplex, PytorchTestCase):
device = 'cpu' device = "cpu"
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .librosa_compatibility_test_impl import Functional, FunctionalComplex from .librosa_compatibility_test_impl import Functional, FunctionalComplex
@skipIfNoCuda @skipIfNoCuda
class TestFunctionalCUDA(Functional, PytorchTestCase): class TestFunctionalCUDA(Functional, PytorchTestCase):
device = 'cuda' device = "cuda"
@skipIfNoCuda @skipIfNoCuda
class TestFunctionalComplexCUDA(FunctionalComplex, PytorchTestCase): class TestFunctionalComplexCUDA(FunctionalComplex, PytorchTestCase):
device = 'cuda' device = "cuda"
...@@ -2,16 +2,15 @@ import unittest ...@@ -2,16 +2,15 @@ import unittest
from distutils.version import StrictVersion from distutils.version import StrictVersion
import torch import torch
from parameterized import param
import torchaudio.functional as F import torchaudio.functional as F
from parameterized import param
from torchaudio._internal.module_utils import is_module_available from torchaudio._internal.module_utils import is_module_available
LIBROSA_AVAILABLE = is_module_available('librosa') LIBROSA_AVAILABLE = is_module_available("librosa")
if LIBROSA_AVAILABLE: if LIBROSA_AVAILABLE:
import numpy as np
import librosa import librosa
import numpy as np
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
...@@ -25,6 +24,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -25,6 +24,7 @@ from torchaudio_unittest.common_utils import (
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class Functional(TestBaseMixin): class Functional(TestBaseMixin):
"""Test suite for functions in `functional` module.""" """Test suite for functions in `functional` module."""
dtype = torch.float64 dtype = torch.float64
@nested_params([0, 0.99]) @nested_params([0, 0.99])
...@@ -40,8 +40,8 @@ class Functional(TestBaseMixin): ...@@ -40,8 +40,8 @@ class Functional(TestBaseMixin):
waveform = get_whitenoise(device=self.device, dtype=self.dtype) waveform = get_whitenoise(device=self.device, dtype=self.dtype)
specgram = get_spectrogram( specgram = get_spectrogram(
waveform, n_fft=n_fft, hop_length=hop_length, power=power, waveform, n_fft=n_fft, hop_length=hop_length, power=power, win_length=win_length, window=window
win_length=win_length, window=window) )
result = F.griffinlim( result = F.griffinlim(
specgram, specgram,
...@@ -53,14 +53,16 @@ class Functional(TestBaseMixin): ...@@ -53,14 +53,16 @@ class Functional(TestBaseMixin):
n_iter=n_iter, n_iter=n_iter,
momentum=momentum, momentum=momentum,
length=waveform.size(1), length=waveform.size(1),
rand_init=False) rand_init=False,
)
expected = librosa.griffinlim( expected = librosa.griffinlim(
specgram[0].cpu().numpy(), specgram[0].cpu().numpy(),
n_iter=n_iter, n_iter=n_iter,
hop_length=hop_length, hop_length=hop_length,
momentum=momentum, momentum=momentum,
init=None, init=None,
length=waveform.size(1))[None, ...] length=waveform.size(1),
)[None, ...]
self.assertEqual(result, torch.from_numpy(expected), atol=5e-5, rtol=1e-07) self.assertEqual(result, torch.from_numpy(expected), atol=5e-5, rtol=1e-07)
@nested_params( @nested_params(
...@@ -73,24 +75,20 @@ class Functional(TestBaseMixin): ...@@ -73,24 +75,20 @@ class Functional(TestBaseMixin):
param(n_mels=56, fmin=1900.0, fmax=900.0), param(n_mels=56, fmin=1900.0, fmax=900.0),
param(n_mels=10, fmin=1900.0, fmax=900.0), param(n_mels=10, fmin=1900.0, fmax=900.0),
], ],
[param(norm=n) for n in [None, 'slaney']], [param(norm=n) for n in [None, "slaney"]],
[param(mel_scale=s) for s in ['htk', 'slaney']], [param(mel_scale=s) for s in ["htk", "slaney"]],
) )
def test_create_mel_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, def test_create_mel_fb(
fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"): self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"
if (norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2")): ):
self.skipTest('Test is known to fail with older versions of librosa.') if norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2"):
if self.device != 'cpu': self.skipTest("Test is known to fail with older versions of librosa.")
self.skipTest('No need to run this test on CUDA') if self.device != "cpu":
self.skipTest("No need to run this test on CUDA")
expected = librosa.filters.mel( expected = librosa.filters.mel(
sr=sample_rate, sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmax=fmax, fmin=fmin, htk=mel_scale == "htk", norm=norm
n_fft=n_fft, ).T
n_mels=n_mels,
fmax=fmax,
fmin=fmin,
htk=mel_scale == "htk",
norm=norm).T
result = F.melscale_fbanks( result = F.melscale_fbanks(
sample_rate=sample_rate, sample_rate=sample_rate,
n_mels=n_mels, n_mels=n_mels,
...@@ -98,7 +96,8 @@ class Functional(TestBaseMixin): ...@@ -98,7 +96,8 @@ class Functional(TestBaseMixin):
f_min=fmin, f_min=fmin,
n_freqs=(n_fft // 2 + 1), n_freqs=(n_fft // 2 + 1),
norm=norm, norm=norm,
mel_scale=mel_scale) mel_scale=mel_scale,
)
self.assertEqual(result, torch.from_numpy(expected), atol=7e-5, rtol=1.3e-6) self.assertEqual(result, torch.from_numpy(expected), atol=7e-5, rtol=1.3e-6)
def test_amplitude_to_DB_power(self): def test_amplitude_to_DB_power(self):
...@@ -137,18 +136,12 @@ class FunctionalComplex(TestBaseMixin): ...@@ -137,18 +136,12 @@ class FunctionalComplex(TestBaseMixin):
# result in bottom right values of the stretched sectrogram to not # result in bottom right values of the stretched sectrogram to not
# match with librosa. # match with librosa.
spec = torch.randn(num_freq, num_frames, device=self.device, dtype=torch.complex128) spec = torch.randn(num_freq, num_frames, device=self.device, dtype=torch.complex128)
phase_advance = torch.linspace( phase_advance = torch.linspace(0, np.pi * hop_length, num_freq, device=self.device, dtype=torch.float64)[
0, ..., None
np.pi * hop_length, ]
num_freq,
device=self.device,
dtype=torch.float64)[..., None]
stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance) stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
expected_stretched = librosa.phase_vocoder( expected_stretched = librosa.phase_vocoder(spec.cpu().numpy(), rate=rate, hop_length=hop_length)
spec.cpu().numpy(),
rate=rate,
hop_length=hop_length)
self.assertEqual(stretched, torch.from_numpy(expected_stretched)) self.assertEqual(stretched, torch.from_numpy(expected_stretched))
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
skipIfNoSox, skipIfNoSox,
skipIfNoExec, skipIfNoExec,
...@@ -15,10 +14,10 @@ from torchaudio_unittest.common_utils import ( ...@@ -15,10 +14,10 @@ from torchaudio_unittest.common_utils import (
@skipIfNoSox @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec("sox")
class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
def run_sox_effect(self, input_file, effect): def run_sox_effect(self, input_file, effect):
output_file = self.get_temp_path('expected.wav') output_file = self.get_temp_path("expected.wav")
sox_utils.run_sox_effect(input_file, output_file, [str(e) for e in effect]) sox_utils.run_sox_effect(input_file, output_file, [str(e) for e in effect])
return load_wav(output_file) return load_wav(output_file)
...@@ -28,29 +27,31 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -28,29 +27,31 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
def get_whitenoise(self, sample_rate=8000): def get_whitenoise(self, sample_rate=8000):
noise = get_whitenoise( noise = get_whitenoise(
sample_rate=sample_rate, duration=3, scale_factor=0.9, sample_rate=sample_rate,
duration=3,
scale_factor=0.9,
) )
path = self.get_temp_path("whitenoise.wav") path = self.get_temp_path("whitenoise.wav")
save_wav(path, noise, sample_rate) save_wav(path, noise, sample_rate)
return noise, path return noise, path
def test_gain(self): def test_gain(self):
path = get_asset_path('steam-train-whistle-daniel_simon.wav') path = get_asset_path("steam-train-whistle-daniel_simon.wav")
data, _ = load_wav(path) data, _ = load_wav(path)
result = F.gain(data, 3) result = F.gain(data, 3)
self.assert_sox_effect(result, path, ['gain', 3]) self.assert_sox_effect(result, path, ["gain", 3])
def test_dither(self): def test_dither(self):
path = get_asset_path('steam-train-whistle-daniel_simon.wav') path = get_asset_path("steam-train-whistle-daniel_simon.wav")
data, _ = load_wav(path) data, _ = load_wav(path)
result = F.dither(data) result = F.dither(data)
self.assert_sox_effect(result, path, ['dither']) self.assert_sox_effect(result, path, ["dither"])
def test_dither_noise(self): def test_dither_noise(self):
path = get_asset_path('steam-train-whistle-daniel_simon.wav') path = get_asset_path("steam-train-whistle-daniel_simon.wav")
data, _ = load_wav(path) data, _ = load_wav(path)
result = F.dither(data, noise_shaping=True) result = F.dither(data, noise_shaping=True)
self.assert_sox_effect(result, path, ['dither', '-s'], atol=1.5e-4) self.assert_sox_effect(result, path, ["dither", "-s"], atol=1.5e-4)
def test_lowpass(self): def test_lowpass(self):
cutoff_freq = 3000 cutoff_freq = 3000
...@@ -58,7 +59,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -58,7 +59,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.lowpass_biquad(data, sample_rate, cutoff_freq) result = F.lowpass_biquad(data, sample_rate, cutoff_freq)
self.assert_sox_effect(result, path, ['lowpass', cutoff_freq], atol=1.5e-4) self.assert_sox_effect(result, path, ["lowpass", cutoff_freq], atol=1.5e-4)
def test_highpass(self): def test_highpass(self):
cutoff_freq = 2000 cutoff_freq = 2000
...@@ -66,7 +67,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -66,7 +67,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.highpass_biquad(data, sample_rate, cutoff_freq) result = F.highpass_biquad(data, sample_rate, cutoff_freq)
self.assert_sox_effect(result, path, ['highpass', cutoff_freq], atol=1.5e-4) self.assert_sox_effect(result, path, ["highpass", cutoff_freq], atol=1.5e-4)
def test_allpass(self): def test_allpass(self):
central_freq = 1000 central_freq = 1000
...@@ -75,7 +76,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -75,7 +76,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.allpass_biquad(data, sample_rate, central_freq, q) result = F.allpass_biquad(data, sample_rate, central_freq, q)
self.assert_sox_effect(result, path, ['allpass', central_freq, f'{q}q']) self.assert_sox_effect(result, path, ["allpass", central_freq, f"{q}q"])
def test_bandpass_with_csg(self): def test_bandpass_with_csg(self):
central_freq = 1000 central_freq = 1000
...@@ -85,7 +86,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -85,7 +86,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.bandpass_biquad(data, sample_rate, central_freq, q, const_skirt_gain) result = F.bandpass_biquad(data, sample_rate, central_freq, q, const_skirt_gain)
self.assert_sox_effect(result, path, ['bandpass', '-c', central_freq, f'{q}q']) self.assert_sox_effect(result, path, ["bandpass", "-c", central_freq, f"{q}q"])
def test_bandpass_without_csg(self): def test_bandpass_without_csg(self):
central_freq = 1000 central_freq = 1000
...@@ -95,7 +96,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -95,7 +96,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.bandpass_biquad(data, sample_rate, central_freq, q, const_skirt_gain) result = F.bandpass_biquad(data, sample_rate, central_freq, q, const_skirt_gain)
self.assert_sox_effect(result, path, ['bandpass', central_freq, f'{q}q']) self.assert_sox_effect(result, path, ["bandpass", central_freq, f"{q}q"])
def test_bandreject(self): def test_bandreject(self):
central_freq = 1000 central_freq = 1000
...@@ -104,7 +105,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -104,7 +105,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.bandreject_biquad(data, sample_rate, central_freq, q) result = F.bandreject_biquad(data, sample_rate, central_freq, q)
self.assert_sox_effect(result, path, ['bandreject', central_freq, f'{q}q']) self.assert_sox_effect(result, path, ["bandreject", central_freq, f"{q}q"])
def test_band_with_noise(self): def test_band_with_noise(self):
central_freq = 1000 central_freq = 1000
...@@ -114,7 +115,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -114,7 +115,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.band_biquad(data, sample_rate, central_freq, q, noise) result = F.band_biquad(data, sample_rate, central_freq, q, noise)
self.assert_sox_effect(result, path, ['band', '-n', central_freq, f'{q}q']) self.assert_sox_effect(result, path, ["band", "-n", central_freq, f"{q}q"])
def test_band_without_noise(self): def test_band_without_noise(self):
central_freq = 1000 central_freq = 1000
...@@ -124,7 +125,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -124,7 +125,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.band_biquad(data, sample_rate, central_freq, q, noise) result = F.band_biquad(data, sample_rate, central_freq, q, noise)
self.assert_sox_effect(result, path, ['band', central_freq, f'{q}q']) self.assert_sox_effect(result, path, ["band", central_freq, f"{q}q"])
def test_treble(self): def test_treble(self):
central_freq = 1000 central_freq = 1000
...@@ -134,7 +135,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -134,7 +135,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.treble_biquad(data, sample_rate, gain, central_freq, q) result = F.treble_biquad(data, sample_rate, gain, central_freq, q)
self.assert_sox_effect(result, path, ['treble', gain, central_freq, f'{q}q']) self.assert_sox_effect(result, path, ["treble", gain, central_freq, f"{q}q"])
def test_bass(self): def test_bass(self):
central_freq = 1000 central_freq = 1000
...@@ -144,26 +145,26 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -144,26 +145,26 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.bass_biquad(data, sample_rate, gain, central_freq, q) result = F.bass_biquad(data, sample_rate, gain, central_freq, q)
self.assert_sox_effect(result, path, ['bass', gain, central_freq, f'{q}q'], atol=1.5e-4) self.assert_sox_effect(result, path, ["bass", gain, central_freq, f"{q}q"], atol=1.5e-4)
def test_deemph(self): def test_deemph(self):
sample_rate = 44100 sample_rate = 44100
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.deemph_biquad(data, sample_rate) result = F.deemph_biquad(data, sample_rate)
self.assert_sox_effect(result, path, ['deemph']) self.assert_sox_effect(result, path, ["deemph"])
def test_riaa(self): def test_riaa(self):
sample_rate = 44100 sample_rate = 44100
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.riaa_biquad(data, sample_rate) result = F.riaa_biquad(data, sample_rate)
self.assert_sox_effect(result, path, ['riaa']) self.assert_sox_effect(result, path, ["riaa"])
def test_contrast(self): def test_contrast(self):
enhancement_amount = 80. enhancement_amount = 80.0
data, path = self.get_whitenoise() data, path = self.get_whitenoise()
result = F.contrast(data, enhancement_amount) result = F.contrast(data, enhancement_amount)
self.assert_sox_effect(result, path, ['contrast', enhancement_amount]) self.assert_sox_effect(result, path, ["contrast", enhancement_amount])
def test_dcshift_with_limiter(self): def test_dcshift_with_limiter(self):
shift = 0.5 shift = 0.5
...@@ -171,14 +172,14 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -171,14 +172,14 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise() data, path = self.get_whitenoise()
result = F.dcshift(data, shift, limiter_gain) result = F.dcshift(data, shift, limiter_gain)
self.assert_sox_effect(result, path, ['dcshift', shift, limiter_gain]) self.assert_sox_effect(result, path, ["dcshift", shift, limiter_gain])
def test_dcshift_without_limiter(self): def test_dcshift_without_limiter(self):
shift = 0.6 shift = 0.6
data, path = self.get_whitenoise() data, path = self.get_whitenoise()
result = F.dcshift(data, shift) result = F.dcshift(data, shift)
self.assert_sox_effect(result, path, ['dcshift', shift]) self.assert_sox_effect(result, path, ["dcshift", shift])
def test_overdrive(self): def test_overdrive(self):
gain = 30 gain = 30
...@@ -186,7 +187,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -186,7 +187,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise() data, path = self.get_whitenoise()
result = F.overdrive(data, gain, colour) result = F.overdrive(data, gain, colour)
self.assert_sox_effect(result, path, ['overdrive', gain, colour]) self.assert_sox_effect(result, path, ["overdrive", gain, colour])
def test_phaser_sine(self): def test_phaser_sine(self):
gain_in = 0.5 gain_in = 0.5
...@@ -198,7 +199,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -198,7 +199,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.phaser(data, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=True) result = F.phaser(data, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=True)
self.assert_sox_effect(result, path, ['phaser', gain_in, gain_out, delay_ms, decay, speed, '-s']) self.assert_sox_effect(result, path, ["phaser", gain_in, gain_out, delay_ms, decay, speed, "-s"])
def test_phaser_triangle(self): def test_phaser_triangle(self):
gain_in = 0.5 gain_in = 0.5
...@@ -210,7 +211,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -210,7 +211,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.phaser(data, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=False) result = F.phaser(data, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=False)
self.assert_sox_effect(result, path, ['phaser', gain_in, gain_out, delay_ms, decay, speed, '-t']) self.assert_sox_effect(result, path, ["phaser", gain_in, gain_out, delay_ms, decay, speed, "-t"])
def test_flanger_triangle_linear(self): def test_flanger_triangle_linear(self):
delay = 0.6 delay = 0.6
...@@ -223,10 +224,11 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -223,10 +224,11 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.flanger( result = F.flanger(
data, sample_rate, delay, depth, regen, width, speed, phase, data, sample_rate, delay, depth, regen, width, speed, phase, modulation="triangular", interpolation="linear"
modulation='triangular', interpolation='linear') )
self.assert_sox_effect( self.assert_sox_effect(
result, path, ['flanger', delay, depth, regen, width, speed, 'triangle', phase, 'linear']) result, path, ["flanger", delay, depth, regen, width, speed, "triangle", phase, "linear"]
)
def test_flanger_triangle_quad(self): def test_flanger_triangle_quad(self):
delay = 0.8 delay = 0.8
...@@ -239,10 +241,20 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -239,10 +241,20 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.flanger( result = F.flanger(
data, sample_rate, delay, depth, regen, width, speed, phase, data,
modulation='triangular', interpolation='quadratic') sample_rate,
delay,
depth,
regen,
width,
speed,
phase,
modulation="triangular",
interpolation="quadratic",
)
self.assert_sox_effect( self.assert_sox_effect(
result, path, ['flanger', delay, depth, regen, width, speed, 'triangle', phase, 'quadratic']) result, path, ["flanger", delay, depth, regen, width, speed, "triangle", phase, "quadratic"]
)
def test_flanger_sine_linear(self): def test_flanger_sine_linear(self):
delay = 0.8 delay = 0.8
...@@ -255,10 +267,9 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -255,10 +267,9 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.flanger( result = F.flanger(
data, sample_rate, delay, depth, regen, width, speed, phase, data, sample_rate, delay, depth, regen, width, speed, phase, modulation="sinusoidal", interpolation="linear"
modulation='sinusoidal', interpolation='linear') )
self.assert_sox_effect( self.assert_sox_effect(result, path, ["flanger", delay, depth, regen, width, speed, "sine", phase, "linear"])
result, path, ['flanger', delay, depth, regen, width, speed, 'sine', phase, 'linear'])
def test_flanger_sine_quad(self): def test_flanger_sine_quad(self):
delay = 0.9 delay = 0.9
...@@ -271,10 +282,18 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -271,10 +282,18 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.flanger( result = F.flanger(
data, sample_rate, delay, depth, regen, width, speed, phase, data,
modulation='sinusoidal', interpolation='quadratic') sample_rate,
self.assert_sox_effect( delay,
result, path, ['flanger', delay, depth, regen, width, speed, 'sine', phase, 'quadratic']) depth,
regen,
width,
speed,
phase,
modulation="sinusoidal",
interpolation="quadratic",
)
self.assert_sox_effect(result, path, ["flanger", delay, depth, regen, width, speed, "sine", phase, "quadratic"])
def test_equalizer(self): def test_equalizer(self):
center_freq = 300 center_freq = 300
...@@ -284,7 +303,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -284,7 +303,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate) data, path = self.get_whitenoise(sample_rate)
result = F.equalizer_biquad(data, sample_rate, center_freq, gain, q) result = F.equalizer_biquad(data, sample_rate, center_freq, gain, q)
self.assert_sox_effect(result, path, ['equalizer', center_freq, q, gain]) self.assert_sox_effect(result, path, ["equalizer", center_freq, q, gain])
def test_perf_biquad_filtering(self): def test_perf_biquad_filtering(self):
b0 = 0.4 b0 = 0.4
...@@ -296,4 +315,4 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -296,4 +315,4 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise() data, path = self.get_whitenoise()
result = F.lfilter(data, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])) result = F.lfilter(data, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2]))
self.assert_sox_effect(result, path, ['biquad', b0, b1, b2, a0, a1, a2]) self.assert_sox_effect(result, path, ["biquad", b0, b1, b2, a0, a1, a2])
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Functional, FunctionalFloat32Only from .torchscript_consistency_impl import Functional, FunctionalFloat32Only
class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase): class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device("cpu")
class TestFunctionalFloat64(Functional, PytorchTestCase): class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device("cpu")
import torch import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Functional, FunctionalFloat32Only from .torchscript_consistency_impl import Functional, FunctionalFloat32Only
@skipIfNoCuda @skipIfNoCuda
class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase): class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cuda') device = torch.device("cuda")
@skipIfNoCuda @skipIfNoCuda
class TestFunctionalFloat64(Functional, PytorchTestCase): class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cuda') device = torch.device("cuda")
...@@ -3,7 +3,6 @@ import unittest ...@@ -3,7 +3,6 @@ import unittest
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
...@@ -15,6 +14,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -15,6 +14,7 @@ from torchaudio_unittest.common_utils import (
class Functional(TempDirMixin, TestBaseMixin): class Functional(TempDirMixin, TestBaseMixin):
"""Implements test for `functional` module that are performed for different devices""" """Implements test for `functional` module that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False): def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype) tensor = tensor.to(device=self.device, dtype=self.dtype)
ts_func = torch_script(func) ts_func = torch_script(func)
...@@ -79,7 +79,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -79,7 +79,7 @@ class Functional(TempDirMixin, TestBaseMixin):
ws = 400 ws = 400
hop = 200 hop = 200
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2. power = 2.0
momentum = 0.99 momentum = 0.99
n_iter = 32 n_iter = 32
length = 1000 length = 1000
...@@ -110,8 +110,8 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -110,8 +110,8 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_melscale_fbanks(self): def test_melscale_fbanks(self):
if self.device != torch.device('cpu'): if self.device != torch.device("cpu"):
raise unittest.SkipTest('No need to perform test on device other than CPU') raise unittest.SkipTest("No need to perform test on device other than CPU")
def func(_): def func(_):
n_stft = 100 n_stft = 100
...@@ -126,8 +126,8 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -126,8 +126,8 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(func, dummy) self._assert_consistency(func, dummy)
def test_linear_fbanks(self): def test_linear_fbanks(self):
if self.device != torch.device('cpu'): if self.device != torch.device("cpu"):
raise unittest.SkipTest('No need to perform test on device other than CPU') raise unittest.SkipTest("No need to perform test on device other than CPU")
def func(_): def func(_):
n_stft = 100 n_stft = 100
...@@ -153,16 +153,16 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -153,16 +153,16 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_DB_to_amplitude(self): def test_DB_to_amplitude(self):
def func(tensor): def func(tensor):
ref = 1. ref = 1.0
power = 1. power = 1.0
return F.DB_to_amplitude(tensor, ref, power) return F.DB_to_amplitude(tensor, ref, power)
tensor = torch.rand((1, 100)) tensor = torch.rand((1, 100))
self._assert_consistency(func, tensor) self._assert_consistency(func, tensor)
def test_create_dct(self): def test_create_dct(self):
if self.device != torch.device('cpu'): if self.device != torch.device("cpu"):
raise unittest.SkipTest('No need to perform test on device other than CPU') raise unittest.SkipTest("No need to perform test on device other than CPU")
def func(_): def func(_):
n_mfcc = 40 n_mfcc = 40
...@@ -192,7 +192,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -192,7 +192,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_mask_along_axis(self): def test_mask_along_axis(self):
def func(tensor): def func(tensor):
mask_param = 100 mask_param = 100
mask_value = 30. mask_value = 30.0
axis = 2 axis = 2
return F.mask_along_axis(tensor, mask_param, mask_value, axis) return F.mask_along_axis(tensor, mask_param, mask_value, axis)
...@@ -202,7 +202,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -202,7 +202,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_mask_along_axis_iid(self): def test_mask_along_axis_iid(self):
def func(tensor): def func(tensor):
mask_param = 100 mask_param = 100
mask_value = 30. mask_value = 30.0
axis = 2 axis = 2
return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis) return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis)
...@@ -219,21 +219,21 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -219,21 +219,21 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_dither_TPDF(self): def test_dither_TPDF(self):
def func(tensor): def func(tensor):
return F.dither(tensor, 'TPDF') return F.dither(tensor, "TPDF")
tensor = common_utils.get_whitenoise(n_channels=2) tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True) self._assert_consistency(func, tensor, shape_only=True)
def test_dither_RPDF(self): def test_dither_RPDF(self):
def func(tensor): def func(tensor):
return F.dither(tensor, 'RPDF') return F.dither(tensor, "RPDF")
tensor = common_utils.get_whitenoise(n_channels=2) tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True) self._assert_consistency(func, tensor, shape_only=True)
def test_dither_GPDF(self): def test_dither_GPDF(self):
def func(tensor): def func(tensor):
return F.dither(tensor, 'GPDF') return F.dither(tensor, "GPDF")
tensor = common_utils.get_whitenoise(n_channels=2) tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True) self._assert_consistency(func, tensor, shape_only=True)
...@@ -306,7 +306,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -306,7 +306,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
cutoff_freq = 3000. cutoff_freq = 3000.0
return F.lowpass_biquad(tensor, sample_rate, cutoff_freq) return F.lowpass_biquad(tensor, sample_rate, cutoff_freq)
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
...@@ -319,7 +319,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -319,7 +319,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
cutoff_freq = 2000. cutoff_freq = 2000.0
return F.highpass_biquad(tensor, sample_rate, cutoff_freq) return F.highpass_biquad(tensor, sample_rate, cutoff_freq)
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
...@@ -332,7 +332,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -332,7 +332,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
central_freq = 1000. central_freq = 1000.0
q = 0.707 q = 0.707
return F.allpass_biquad(tensor, sample_rate, central_freq, q) return F.allpass_biquad(tensor, sample_rate, central_freq, q)
...@@ -346,7 +346,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -346,7 +346,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
central_freq = 1000. central_freq = 1000.0
q = 0.707 q = 0.707
const_skirt_gain = True const_skirt_gain = True
return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain) return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)
...@@ -361,7 +361,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -361,7 +361,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
central_freq = 1000. central_freq = 1000.0
q = 0.707 q = 0.707
const_skirt_gain = True const_skirt_gain = True
return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain) return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)
...@@ -376,7 +376,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -376,7 +376,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
central_freq = 1000. central_freq = 1000.0
q = 0.707 q = 0.707
return F.bandreject_biquad(tensor, sample_rate, central_freq, q) return F.bandreject_biquad(tensor, sample_rate, central_freq, q)
...@@ -390,7 +390,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -390,7 +390,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
central_freq = 1000. central_freq = 1000.0
q = 0.707 q = 0.707
noise = True noise = True
return F.band_biquad(tensor, sample_rate, central_freq, q, noise) return F.band_biquad(tensor, sample_rate, central_freq, q, noise)
...@@ -405,7 +405,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -405,7 +405,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
central_freq = 1000. central_freq = 1000.0
q = 0.707 q = 0.707
noise = False noise = False
return F.band_biquad(tensor, sample_rate, central_freq, q, noise) return F.band_biquad(tensor, sample_rate, central_freq, q, noise)
...@@ -420,8 +420,8 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -420,8 +420,8 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
gain = 40. gain = 40.0
central_freq = 1000. central_freq = 1000.0
q = 0.707 q = 0.707
return F.treble_biquad(tensor, sample_rate, gain, central_freq, q) return F.treble_biquad(tensor, sample_rate, gain, central_freq, q)
...@@ -435,8 +435,8 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -435,8 +435,8 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
gain = 40. gain = 40.0
central_freq = 1000. central_freq = 1000.0
q = 0.707 q = 0.707
return F.bass_biquad(tensor, sample_rate, gain, central_freq, q) return F.bass_biquad(tensor, sample_rate, gain, central_freq, q)
...@@ -474,8 +474,8 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -474,8 +474,8 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
center_freq = 300. center_freq = 300.0
gain = 1. gain = 1.0
q = 0.707 q = 0.707
return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q) return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q)
...@@ -501,39 +501,20 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -501,39 +501,20 @@ class Functional(TempDirMixin, TestBaseMixin):
center = False center = False
norm_vars = False norm_vars = False
a = torch.tensor( a = torch.tensor(
[ [[-1.915875792503357, 1.147700309753418], [1.8242558240890503, 1.3869990110397339]],
[
-1.915875792503357,
1.147700309753418
],
[
1.8242558240890503,
1.3869990110397339
]
],
device=tensor.device, device=tensor.device,
dtype=tensor.dtype dtype=tensor.dtype,
) )
return F.sliding_window_cmn(a, cmn_window, min_cmn_window, center, norm_vars) return F.sliding_window_cmn(a, cmn_window, min_cmn_window, center, norm_vars)
b = torch.tensor(
[ b = torch.tensor([[-1.8701, -0.1196], [1.8701, 0.1196]])
[
-1.8701,
-0.1196
],
[
1.8701,
0.1196
]
]
)
self._assert_consistency(func, b) self._assert_consistency(func, b)
def test_contrast(self): def test_contrast(self):
waveform = common_utils.get_whitenoise() waveform = common_utils.get_whitenoise()
def func(tensor): def func(tensor):
enhancement_amount = 80. enhancement_amount = 80.0
return F.contrast(tensor, enhancement_amount) return F.contrast(tensor, enhancement_amount)
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
...@@ -552,8 +533,8 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -552,8 +533,8 @@ class Functional(TempDirMixin, TestBaseMixin):
waveform = common_utils.get_whitenoise() waveform = common_utils.get_whitenoise()
def func(tensor): def func(tensor):
gain = 30. gain = 30.0
colour = 50. colour = 50.0
return F.overdrive(tensor, gain, colour) return F.overdrive(tensor, gain, colour)
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
...@@ -582,15 +563,24 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -582,15 +563,24 @@ class Functional(TempDirMixin, TestBaseMixin):
regen = 3.0 regen = 3.0
width = 0.23 width = 0.23
speed = 1.3 speed = 1.3
phase = 60. phase = 60.0
sample_rate = 44100 sample_rate = 44100
return F.flanger(tensor, sample_rate, delay, depth, regen, width, speed, return F.flanger(
phase, modulation='sinusoidal', interpolation='linear') tensor,
sample_rate,
delay,
depth,
regen,
width,
speed,
phase,
modulation="sinusoidal",
interpolation="linear",
)
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_spectral_centroid(self): def test_spectral_centroid(self):
def func(tensor): def func(tensor):
sample_rate = 44100 sample_rate = 44100
n_fft = 400 n_fft = 400
...@@ -605,11 +595,11 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -605,11 +595,11 @@ class Functional(TempDirMixin, TestBaseMixin):
@common_utils.skipIfNoKaldi @common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self): def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'): if self.dtype != torch.float32 or self.device != torch.device("cpu"):
raise unittest.SkipTest("Only float32, cpu is supported.") raise unittest.SkipTest("Only float32, cpu is supported.")
def func(tensor): def func(tensor):
sample_rate: float = 44100. sample_rate: float = 44100.0
return F.compute_kaldi_pitch(tensor, sample_rate) return F.compute_kaldi_pitch(tensor, sample_rate)
tensor = common_utils.get_whitenoise(sample_rate=44100) tensor = common_utils.get_whitenoise(sample_rate=44100)
...@@ -630,7 +620,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -630,7 +620,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func_beta(tensor): def func_beta(tensor):
sr1, sr2 = 16000, 8000 sr1, sr2 = 16000, 8000
beta = 6. beta = 6.0
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta) return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta)
tensor = common_utils.get_whitenoise(sample_rate=16000) tensor = common_utils.get_whitenoise(sample_rate=16000)
...@@ -663,11 +653,13 @@ class FunctionalFloat32Only(TestBaseMixin): ...@@ -663,11 +653,13 @@ class FunctionalFloat32Only(TestBaseMixin):
target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32) target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
return F.rnnt_loss(tensor, targets, logit_lengths, target_lengths) return F.rnnt_loss(tensor, targets, logit_lengths, target_lengths)
logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], logits = torch.tensor(
[0.1, 0.1, 0.6, 0.1, 0.1], [
[0.1, 0.1, 0.2, 0.8, 0.1]], [
[[0.1, 0.6, 0.1, 0.1, 0.1], [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]],
[0.1, 0.1, 0.2, 0.1, 0.1], [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]],
[0.7, 0.1, 0.2, 0.1, 0.1]]]]) ]
]
)
tensor = logits.to(device=self.device, dtype=torch.float32) tensor = logits.to(device=self.device, dtype=torch.float32)
self._assert_consistency(func, tensor) self._assert_consistency(func, tensor)
import torch import torch
import torchaudio.kaldi_io as kio import torchaudio.kaldi_io as kio
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
...@@ -9,13 +8,14 @@ class Test_KaldiIO(common_utils.TorchaudioTestCase): ...@@ -9,13 +8,14 @@ class Test_KaldiIO(common_utils.TorchaudioTestCase):
data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]] data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]]
def _test_helper(self, file_name, expected_data, fn, expected_dtype): def _test_helper(self, file_name, expected_data, fn, expected_dtype):
""" Takes a file_name to the input data and a function fn to extract the """Takes a file_name to the input data and a function fn to extract the
data. It compares the extracted data to the expected_data. The expected_dtype data. It compares the extracted data to the expected_data. The expected_dtype
will be used to check that the extracted data is of the right type. will be used to check that the extracted data is of the right type.
""" """
test_filepath = common_utils.get_asset_path(file_name) test_filepath = common_utils.get_asset_path(file_name)
expected_output = {'key' + str(idx + 1): torch.tensor(val, dtype=expected_dtype) expected_output = {
for idx, val in enumerate(expected_data)} "key" + str(idx + 1): torch.tensor(val, dtype=expected_dtype) for idx, val in enumerate(expected_data)
}
for key, vec in fn(test_filepath): for key, vec in fn(test_filepath):
self.assertTrue(key in expected_output) self.assertTrue(key in expected_output)
......
...@@ -10,7 +10,6 @@ from torchaudio_unittest.common_utils import torch_script ...@@ -10,7 +10,6 @@ from torchaudio_unittest.common_utils import torch_script
class TestWav2Letter(common_utils.TorchaudioTestCase): class TestWav2Letter(common_utils.TorchaudioTestCase):
def test_waveform(self): def test_waveform(self):
batch_size = 2 batch_size = 2
num_features = 1 num_features = 1
...@@ -39,10 +38,8 @@ class TestWav2Letter(common_utils.TorchaudioTestCase): ...@@ -39,10 +38,8 @@ class TestWav2Letter(common_utils.TorchaudioTestCase):
class TestMelResNet(common_utils.TorchaudioTestCase): class TestMelResNet(common_utils.TorchaudioTestCase):
def test_waveform(self): def test_waveform(self):
"""Validate the output dimensions of a MelResNet block. """Validate the output dimensions of a MelResNet block."""
"""
n_batch = 2 n_batch = 2
n_time = 200 n_time = 200
...@@ -61,10 +58,8 @@ class TestMelResNet(common_utils.TorchaudioTestCase): ...@@ -61,10 +58,8 @@ class TestMelResNet(common_utils.TorchaudioTestCase):
class TestUpsampleNetwork(common_utils.TorchaudioTestCase): class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
def test_waveform(self): def test_waveform(self):
"""Validate the output dimensions of a UpsampleNetwork block. """Validate the output dimensions of a UpsampleNetwork block."""
"""
upsample_scales = [5, 5, 8] upsample_scales = [5, 5, 8]
n_batch = 2 n_batch = 2
...@@ -79,12 +74,7 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase): ...@@ -79,12 +74,7 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
for upsample_scale in upsample_scales: for upsample_scale in upsample_scales:
total_scale *= upsample_scale total_scale *= upsample_scale
model = UpsampleNetwork(upsample_scales, model = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
x = torch.rand(n_batch, n_freq, n_time) x = torch.rand(n_batch, n_freq, n_time)
out1, out2 = model(x) out1, out2 = model(x)
...@@ -94,10 +84,8 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase): ...@@ -94,10 +84,8 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
class TestWaveRNN(common_utils.TorchaudioTestCase): class TestWaveRNN(common_utils.TorchaudioTestCase):
def test_waveform(self): def test_waveform(self):
"""Validate the output dimensions of a WaveRNN model. """Validate the output dimensions of a WaveRNN model."""
"""
upsample_scales = [5, 5, 8] upsample_scales = [5, 5, 8]
n_rnn = 512 n_rnn = 512
...@@ -112,8 +100,9 @@ class TestWaveRNN(common_utils.TorchaudioTestCase): ...@@ -112,8 +100,9 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
n_hidden = 128 n_hidden = 128
kernel_size = 5 kernel_size = 5
model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, model = WaveRNN(
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output) upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output
)
x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1)) x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time) mels = torch.rand(n_batch, 1, n_freq, n_time)
...@@ -122,8 +111,7 @@ class TestWaveRNN(common_utils.TorchaudioTestCase): ...@@ -122,8 +111,7 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes) assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)
def test_infer_waveform(self): def test_infer_waveform(self):
"""Validate the output dimensions of a WaveRNN model's infer method. """Validate the output dimensions of a WaveRNN model's infer method."""
"""
upsample_scales = [5, 5, 8] upsample_scales = [5, 5, 8]
n_rnn = 128 n_rnn = 128
...@@ -138,8 +126,9 @@ class TestWaveRNN(common_utils.TorchaudioTestCase): ...@@ -138,8 +126,9 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
n_hidden = 32 n_hidden = 32
kernel_size = 5 kernel_size = 5
model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, model = WaveRNN(
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output) upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output
)
x = torch.rand(n_batch, n_freq, n_time) x = torch.rand(n_batch, n_freq, n_time)
lengths = torch.tensor([n_time, n_time // 2]) lengths = torch.tensor([n_time, n_time // 2])
...@@ -165,8 +154,9 @@ class TestWaveRNN(common_utils.TorchaudioTestCase): ...@@ -165,8 +154,9 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
n_hidden = 32 n_hidden = 32
kernel_size = 5 kernel_size = 5
model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, model = WaveRNN(
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output) upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output
)
model.eval() model.eval()
x = torch.rand(n_batch, n_freq, n_time) x = torch.rand(n_batch, n_freq, n_time)
torch.random.manual_seed(0) torch.random.manual_seed(0)
...@@ -177,39 +167,43 @@ class TestWaveRNN(common_utils.TorchaudioTestCase): ...@@ -177,39 +167,43 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
_ConvTasNetParams = namedtuple( _ConvTasNetParams = namedtuple(
'_ConvTasNetParams', "_ConvTasNetParams",
[ [
'enc_num_feats', "enc_num_feats",
'enc_kernel_size', "enc_kernel_size",
'msk_num_feats', "msk_num_feats",
'msk_num_hidden_feats', "msk_num_hidden_feats",
'msk_kernel_size', "msk_kernel_size",
'msk_num_layers', "msk_num_layers",
'msk_num_stacks', "msk_num_stacks",
] ],
) )
class TestConvTasNet(common_utils.TorchaudioTestCase): class TestConvTasNet(common_utils.TorchaudioTestCase):
@parameterized.expand(list(itertools.product( @parameterized.expand(
[2, 3], list(
[ itertools.product(
_ConvTasNetParams(128, 40, 128, 256, 3, 7, 2), [2, 3],
_ConvTasNetParams(256, 40, 128, 256, 3, 7, 2), [
_ConvTasNetParams(512, 40, 128, 256, 3, 7, 2), _ConvTasNetParams(128, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 256, 3, 7, 2), _ConvTasNetParams(256, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 7, 2), _ConvTasNetParams(512, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 7, 2), _ConvTasNetParams(512, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 256, 256, 3, 7, 2), _ConvTasNetParams(512, 40, 128, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 256, 512, 3, 7, 2), _ConvTasNetParams(512, 40, 128, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 256, 512, 3, 7, 2), _ConvTasNetParams(512, 40, 256, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 6, 4), _ConvTasNetParams(512, 40, 256, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 4, 6), _ConvTasNetParams(512, 40, 256, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 8, 3), _ConvTasNetParams(512, 40, 128, 512, 3, 6, 4),
_ConvTasNetParams(512, 32, 128, 512, 3, 8, 3), _ConvTasNetParams(512, 40, 128, 512, 3, 4, 6),
_ConvTasNetParams(512, 16, 128, 512, 3, 8, 3), _ConvTasNetParams(512, 40, 128, 512, 3, 8, 3),
], _ConvTasNetParams(512, 32, 128, 512, 3, 8, 3),
))) _ConvTasNetParams(512, 16, 128, 512, 3, 8, 3),
],
)
)
)
def test_paper_configuration(self, num_sources, model_params): def test_paper_configuration(self, num_sources, model_params):
"""ConvTasNet model works on the valid configurations in the paper""" """ConvTasNet model works on the valid configurations in the paper"""
batch_size = 32 batch_size = 32
...@@ -232,7 +226,6 @@ class TestConvTasNet(common_utils.TorchaudioTestCase): ...@@ -232,7 +226,6 @@ class TestConvTasNet(common_utils.TorchaudioTestCase):
class TestDeepSpeech(common_utils.TorchaudioTestCase): class TestDeepSpeech(common_utils.TorchaudioTestCase):
def test_deepspeech(self): def test_deepspeech(self):
n_batch = 2 n_batch = 2
n_feature = 1 n_feature = 1
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .model_test_impl import ( from .model_test_impl import (
Tacotron2EncoderTests, Tacotron2EncoderTests,
Tacotron2DecoderTests, Tacotron2DecoderTests,
......
import torch import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .model_test_impl import ( from .model_test_impl import (
Tacotron2EncoderTests, Tacotron2EncoderTests,
Tacotron2DecoderTests, Tacotron2DecoderTests,
......
from typing import Tuple from typing import Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from torchaudio.models import Tacotron2 from torchaudio.models import Tacotron2
...@@ -7,7 +8,6 @@ from torchaudio_unittest.common_utils import TestBaseMixin, torch_script ...@@ -7,7 +8,6 @@ from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class Tacotron2InferenceWrapper(torch.nn.Module): class Tacotron2InferenceWrapper(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.model = model self.model = model
...@@ -17,7 +17,6 @@ class Tacotron2InferenceWrapper(torch.nn.Module): ...@@ -17,7 +17,6 @@ class Tacotron2InferenceWrapper(torch.nn.Module):
class Tacotron2DecoderInferenceWrapper(torch.nn.Module): class Tacotron2DecoderInferenceWrapper(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.model = model self.model = model
...@@ -42,21 +41,18 @@ class TorchscriptConsistencyMixin(TestBaseMixin): ...@@ -42,21 +41,18 @@ class TorchscriptConsistencyMixin(TestBaseMixin):
class Tacotron2EncoderTests(TorchscriptConsistencyMixin): class Tacotron2EncoderTests(TorchscriptConsistencyMixin):
def test_tacotron2_torchscript_consistency(self): def test_tacotron2_torchscript_consistency(self):
r"""Validate the torchscript consistency of a Encoder.""" r"""Validate the torchscript consistency of a Encoder."""
n_batch, n_seq, encoder_embedding_dim = 16, 64, 512 n_batch, n_seq, encoder_embedding_dim = 16, 64, 512
model = _Encoder(encoder_embedding_dim=encoder_embedding_dim, model = (
encoder_n_convolution=3, _Encoder(encoder_embedding_dim=encoder_embedding_dim, encoder_n_convolution=3, encoder_kernel_size=5)
encoder_kernel_size=5).to(self.device).eval() .to(self.device)
.eval()
x = torch.rand(
n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype
)
input_lengths = (
torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq
) )
x = torch.rand(n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype)
input_lengths = torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq
self._assert_torchscript_consistency(model, (x, input_lengths)) self._assert_torchscript_consistency(model, (x, input_lengths))
def test_encoder_output_shape(self): def test_encoder_output_shape(self):
...@@ -64,23 +60,20 @@ class Tacotron2EncoderTests(TorchscriptConsistencyMixin): ...@@ -64,23 +60,20 @@ class Tacotron2EncoderTests(TorchscriptConsistencyMixin):
that it outputs with a tensor with expected shape. that it outputs with a tensor with expected shape.
""" """
n_batch, n_seq, encoder_embedding_dim = 16, 64, 512 n_batch, n_seq, encoder_embedding_dim = 16, 64, 512
model = _Encoder(encoder_embedding_dim=encoder_embedding_dim, model = (
encoder_n_convolution=3, _Encoder(encoder_embedding_dim=encoder_embedding_dim, encoder_n_convolution=3, encoder_kernel_size=5)
encoder_kernel_size=5).to(self.device).eval() .to(self.device)
.eval()
x = torch.rand(
n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype
)
input_lengths = (
torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq
) )
x = torch.rand(n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype)
input_lengths = torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq
out = model(x, input_lengths) out = model(x, input_lengths)
assert out.size() == (n_batch, n_seq, encoder_embedding_dim) assert out.size() == (n_batch, n_seq, encoder_embedding_dim)
def _get_decoder_model(n_mels=80, encoder_embedding_dim=512, def _get_decoder_model(n_mels=80, encoder_embedding_dim=512, decoder_max_step=2000, gate_threshold=0.5):
decoder_max_step=2000, gate_threshold=0.5):
model = _Decoder( model = _Decoder(
n_mels=n_mels, n_mels=n_mels,
n_frames_per_step=1, n_frames_per_step=1,
...@@ -101,7 +94,6 @@ def _get_decoder_model(n_mels=80, encoder_embedding_dim=512, ...@@ -101,7 +94,6 @@ def _get_decoder_model(n_mels=80, encoder_embedding_dim=512,
class Tacotron2DecoderTests(TorchscriptConsistencyMixin): class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
def test_decoder_torchscript_consistency(self): def test_decoder_torchscript_consistency(self):
r"""Validate the torchscript consistency of a Decoder.""" r"""Validate the torchscript consistency of a Decoder."""
n_batch = 16 n_batch = 16
...@@ -113,17 +105,11 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin): ...@@ -113,17 +105,11 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim) model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim)
model = model.to(self.device).eval() model = model.to(self.device).eval()
memory = torch.rand( memory = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device)
n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device decoder_inputs = torch.rand(n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device)
)
decoder_inputs = torch.rand(
n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device
)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device) memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
self._assert_torchscript_consistency( self._assert_torchscript_consistency(model, (memory, decoder_inputs, memory_lengths))
model, (memory, decoder_inputs, memory_lengths)
)
def test_decoder_output_shape(self): def test_decoder_output_shape(self):
r"""Feed tensors with specific shape to Tacotron2 Decoder and validate r"""Feed tensors with specific shape to Tacotron2 Decoder and validate
...@@ -138,17 +124,11 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin): ...@@ -138,17 +124,11 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim) model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim)
model = model.to(self.device).eval() model = model.to(self.device).eval()
memory = torch.rand( memory = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device)
n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device decoder_inputs = torch.rand(n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device)
)
decoder_inputs = torch.rand(
n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device
)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device) memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
mel_specgram, gate_outputs, alignments = model( mel_specgram, gate_outputs, alignments = model(memory, decoder_inputs, memory_lengths)
memory, decoder_inputs, memory_lengths
)
assert mel_specgram.size() == (n_batch, n_mels, n_time_steps) assert mel_specgram.size() == (n_batch, n_mels, n_time_steps)
assert gate_outputs.size() == (n_batch, n_time_steps) assert gate_outputs.size() == (n_batch, n_time_steps)
...@@ -171,9 +151,7 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin): ...@@ -171,9 +151,7 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
) )
model = model.to(self.device).eval() model = model.to(self.device).eval()
memory = torch.rand( memory = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device)
n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device
)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device) memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
model_wrapper = Tacotron2DecoderInferenceWrapper(model) model_wrapper = Tacotron2DecoderInferenceWrapper(model)
...@@ -197,17 +175,16 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin): ...@@ -197,17 +175,16 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
) )
model = model.to(self.device).eval() model = model.to(self.device).eval()
memory = torch.rand( memory = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device)
n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device
)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device) memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
mel_specgram, mel_specgram_lengths, gate_outputs, alignments = model.infer( mel_specgram, mel_specgram_lengths, gate_outputs, alignments = model.infer(memory, memory_lengths)
memory, memory_lengths
)
assert len(mel_specgram.size()) == 3 assert len(mel_specgram.size()) == 3
assert mel_specgram.size()[:-1] == (n_batch, n_mels, ) assert mel_specgram.size()[:-1] == (
n_batch,
n_mels,
)
assert mel_specgram.size()[2] == mel_specgram_lengths.max().item() assert mel_specgram.size()[2] == mel_specgram_lengths.max().item()
assert len(mel_specgram_lengths.size()) == 1 assert len(mel_specgram_lengths.size()) == 1
assert mel_specgram_lengths.size()[0] == n_batch assert mel_specgram_lengths.size()[0] == n_batch
...@@ -248,16 +225,9 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5): ...@@ -248,16 +225,9 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
class Tacotron2Tests(TorchscriptConsistencyMixin): class Tacotron2Tests(TorchscriptConsistencyMixin):
def _get_inputs(self, n_mels: int, n_batch: int, max_mel_specgram_length: int, max_text_length: int):
def _get_inputs( text = torch.randint(0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device)
self, n_mels: int, n_batch: int, max_mel_specgram_length: int, max_text_length: int text_lengths = max_text_length * torch.ones((n_batch,), dtype=torch.int32, device=self.device)
):
text = torch.randint(
0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device
)
text_lengths = max_text_length * torch.ones(
(n_batch,), dtype=torch.int32, device=self.device
)
mel_specgram = torch.rand( mel_specgram = torch.rand(
n_batch, n_batch,
n_mels, n_mels,
...@@ -265,9 +235,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin): ...@@ -265,9 +235,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
dtype=self.dtype, dtype=self.dtype,
device=self.device, device=self.device,
) )
mel_specgram_lengths = max_mel_specgram_length * torch.ones( mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,), dtype=torch.int32, device=self.device)
(n_batch,), dtype=torch.int32, device=self.device
)
return text, text_lengths, mel_specgram, mel_specgram_lengths return text, text_lengths, mel_specgram, mel_specgram_lengths
def test_tacotron2_torchscript_consistency(self): def test_tacotron2_torchscript_consistency(self):
...@@ -278,9 +246,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin): ...@@ -278,9 +246,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
max_text_length = 100 max_text_length = 100
model = _get_tacotron2_model(n_mels).to(self.device).eval() model = _get_tacotron2_model(n_mels).to(self.device).eval()
inputs = self._get_inputs( inputs = self._get_inputs(n_mels, n_batch, max_mel_specgram_length, max_text_length)
n_mels, n_batch, max_mel_specgram_length, max_text_length
)
self._assert_torchscript_consistency(model, inputs) self._assert_torchscript_consistency(model, inputs)
...@@ -294,9 +260,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin): ...@@ -294,9 +260,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
max_text_length = 100 max_text_length = 100
model = _get_tacotron2_model(n_mels).to(self.device).eval() model = _get_tacotron2_model(n_mels).to(self.device).eval()
inputs = self._get_inputs( inputs = self._get_inputs(n_mels, n_batch, max_mel_specgram_length, max_text_length)
n_mels, n_batch, max_mel_specgram_length, max_text_length
)
mel_out, mel_out_postnet, gate_outputs, alignments = model(*inputs) mel_out, mel_out_postnet, gate_outputs, alignments = model(*inputs)
assert mel_out.size() == (n_batch, n_mels, max_mel_specgram_length) assert mel_out.size() == (n_batch, n_mels, max_mel_specgram_length)
...@@ -315,9 +279,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin): ...@@ -315,9 +279,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
max_text_length = 100 max_text_length = 100
model = _get_tacotron2_model(n_mels).to(self.device) model = _get_tacotron2_model(n_mels).to(self.device)
inputs = self._get_inputs( inputs = self._get_inputs(n_mels, n_batch, max_mel_specgram_length, max_text_length)
n_mels, n_batch, max_mel_specgram_length, max_text_length
)
mel_out, mel_out_postnet, gate_outputs, _ = model(*inputs) mel_out, mel_out_postnet, gate_outputs, _ = model(*inputs)
mel_out.sum().backward(retain_graph=True) mel_out.sum().backward(retain_graph=True)
...@@ -325,12 +287,8 @@ class Tacotron2Tests(TorchscriptConsistencyMixin): ...@@ -325,12 +287,8 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
gate_outputs.sum().backward() gate_outputs.sum().backward()
def _get_inference_inputs(self, n_batch: int, max_text_length: int): def _get_inference_inputs(self, n_batch: int, max_text_length: int):
text = torch.randint( text = torch.randint(0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device)
0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device text_lengths = max_text_length * torch.ones((n_batch,), dtype=torch.int32, device=self.device)
)
text_lengths = max_text_length * torch.ones(
(n_batch,), dtype=torch.int32, device=self.device
)
return text, text_lengths return text, text_lengths
def test_tacotron2_inference_torchscript_consistency(self): def test_tacotron2_inference_torchscript_consistency(self):
...@@ -341,9 +299,11 @@ class Tacotron2Tests(TorchscriptConsistencyMixin): ...@@ -341,9 +299,11 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
decoder_max_step = 200 # make inference more efficient decoder_max_step = 200 # make inference more efficient
gate_threshold = 0.51 # if set to 0.5, the model will only run one step gate_threshold = 0.51 # if set to 0.5, the model will only run one step
model = _get_tacotron2_model( model = (
n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold _get_tacotron2_model(n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold)
).to(self.device).eval() .to(self.device)
.eval()
)
inputs = self._get_inference_inputs(n_batch, max_text_length) inputs = self._get_inference_inputs(n_batch, max_text_length)
model_wrapper = Tacotron2InferenceWrapper(model) model_wrapper = Tacotron2InferenceWrapper(model)
...@@ -360,9 +320,11 @@ class Tacotron2Tests(TorchscriptConsistencyMixin): ...@@ -360,9 +320,11 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
decoder_max_step = 200 # make inference more efficient decoder_max_step = 200 # make inference more efficient
gate_threshold = 0.51 # if set to 0.5, the model will only run one step gate_threshold = 0.51 # if set to 0.5, the model will only run one step
model = _get_tacotron2_model( model = (
n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold _get_tacotron2_model(n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold)
).to(self.device).eval() .to(self.device)
.eval()
)
inputs = self._get_inference_inputs(n_batch, max_text_length) inputs = self._get_inference_inputs(n_batch, max_text_length)
mel_out, mel_specgram_lengths, alignments = model.infer(*inputs) mel_out, mel_specgram_lengths, alignments = model.infer(*inputs)
...@@ -370,7 +332,10 @@ class Tacotron2Tests(TorchscriptConsistencyMixin): ...@@ -370,7 +332,10 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
# There is no guarantee on exactly what max_mel_specgram_length should be # There is no guarantee on exactly what max_mel_specgram_length should be
# We only know that it should be smaller than model.decoder.decoder_max_step # We only know that it should be smaller than model.decoder.decoder_max_step
assert len(mel_out.size()) == 3 assert len(mel_out.size()) == 3
assert mel_out.size()[:2] == (n_batch, n_mels, ) assert mel_out.size()[:2] == (
n_batch,
n_mels,
)
assert mel_out.size()[2] == mel_specgram_lengths.max().item() assert mel_out.size()[2] == mel_specgram_lengths.max().item()
assert len(mel_specgram_lengths.size()) == 1 assert len(mel_specgram_lengths.size()) == 1
assert mel_specgram_lengths.size()[0] == n_batch assert mel_specgram_lengths.size()[0] == n_batch
......
import json import json
import torch import torch
from parameterized import parameterized
from torchaudio.models.wav2vec2 import ( from torchaudio.models.wav2vec2 import (
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
...@@ -12,8 +13,6 @@ from torchaudio.models.wav2vec2 import ( ...@@ -12,8 +13,6 @@ from torchaudio.models.wav2vec2 import (
from torchaudio.models.wav2vec2.utils import ( from torchaudio.models.wav2vec2.utils import (
import_fairseq_model, import_fairseq_model,
) )
from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
get_asset_path, get_asset_path,
skipIfNoModule, skipIfNoModule,
...@@ -22,63 +21,75 @@ from torchaudio_unittest.common_utils import ( ...@@ -22,63 +21,75 @@ from torchaudio_unittest.common_utils import (
def _load_config(*paths): def _load_config(*paths):
with open(f'{get_asset_path("wav2vec2", "fairseq", *paths)}.json', 'r') as file_: with open(f'{get_asset_path("wav2vec2", "fairseq", *paths)}.json', "r") as file_:
return json.load(file_) return json.load(file_)
def _name_func(testcase_func, i, param): def _name_func(testcase_func, i, param):
return f'{testcase_func.__name__}_{i}_{param[0][1].__name__}' return f"{testcase_func.__name__}_{i}_{param[0][1].__name__}"
# Pretraining models # Pretraining models
WAV2VEC2_BASE = _load_config('wav2vec_small') WAV2VEC2_BASE = _load_config("wav2vec_small")
WAV2VEC2_LARGE = _load_config('libri960_big') WAV2VEC2_LARGE = _load_config("libri960_big")
WAV2VEC2_LARGE_LV60K = _load_config('wav2vec_vox_new') WAV2VEC2_LARGE_LV60K = _load_config("wav2vec_vox_new")
WAV2VEC2_XLSR_53_56K = _load_config('xlsr_53_56k') WAV2VEC2_XLSR_53_56K = _load_config("xlsr_53_56k")
HUBERT_BASE = _load_config('hubert_base_ls960') HUBERT_BASE = _load_config("hubert_base_ls960")
HUBERT_LARGE_LL60K = _load_config('hubert_large_ll60k') HUBERT_LARGE_LL60K = _load_config("hubert_large_ll60k")
HUBERT_XLARGE_LL60K = _load_config('hubert_xtralarge_ll60k') HUBERT_XLARGE_LL60K = _load_config("hubert_xtralarge_ll60k")
# Finetuning models # Finetuning models
WAV2VEC2_BASE_960H = _load_config('wav2vec_small_960h') WAV2VEC2_BASE_960H = _load_config("wav2vec_small_960h")
WAV2VEC2_LARGE_960H = _load_config('wav2vec_large_960h') WAV2VEC2_LARGE_960H = _load_config("wav2vec_large_960h")
WAV2VEC2_LARGE_LV60K_960H = _load_config('wav2vec_large_lv60k_960h') WAV2VEC2_LARGE_LV60K_960H = _load_config("wav2vec_large_lv60k_960h")
WAV2VEC2_LARGE_LV60K_SELF_960H = _load_config('wav2vec_large_lv60k_self_960h') WAV2VEC2_LARGE_LV60K_SELF_960H = _load_config("wav2vec_large_lv60k_self_960h")
HUBERT_LARGE = _load_config('hubert_large_ll60k_finetune_ls960') HUBERT_LARGE = _load_config("hubert_large_ll60k_finetune_ls960")
HUBERT_XLARGE = _load_config('hubert_xtralarge_ll60k_finetune_ls960') HUBERT_XLARGE = _load_config("hubert_xtralarge_ll60k_finetune_ls960")
# Config and corresponding factory functions # Config and corresponding factory functions
WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand([ WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand(
(WAV2VEC2_BASE, wav2vec2_base), [
(WAV2VEC2_LARGE, wav2vec2_large), (WAV2VEC2_BASE, wav2vec2_base),
(WAV2VEC2_LARGE_LV60K, wav2vec2_large_lv60k), (WAV2VEC2_LARGE, wav2vec2_large),
(WAV2VEC2_XLSR_53_56K, wav2vec2_large_lv60k), (WAV2VEC2_LARGE_LV60K, wav2vec2_large_lv60k),
], name_func=_name_func) (WAV2VEC2_XLSR_53_56K, wav2vec2_large_lv60k),
HUBERT_PRETRAINING_CONFIGS = parameterized.expand([ ],
(HUBERT_BASE, hubert_base), name_func=_name_func,
(HUBERT_LARGE_LL60K, hubert_large), )
(HUBERT_XLARGE_LL60K, hubert_xlarge), HUBERT_PRETRAINING_CONFIGS = parameterized.expand(
], name_func=_name_func) [
ALL_PRETRAINING_CONFIGS = parameterized.expand([ (HUBERT_BASE, hubert_base),
(WAV2VEC2_BASE, wav2vec2_base), (HUBERT_LARGE_LL60K, hubert_large),
(WAV2VEC2_LARGE, wav2vec2_large), (HUBERT_XLARGE_LL60K, hubert_xlarge),
(WAV2VEC2_LARGE_LV60K, wav2vec2_large_lv60k), ],
(WAV2VEC2_XLSR_53_56K, wav2vec2_large_lv60k), name_func=_name_func,
(HUBERT_BASE, hubert_base), )
(HUBERT_LARGE_LL60K, hubert_large), ALL_PRETRAINING_CONFIGS = parameterized.expand(
(HUBERT_XLARGE_LL60K, hubert_xlarge), [
], name_func=_name_func) (WAV2VEC2_BASE, wav2vec2_base),
FINETUNING_CONFIGS = parameterized.expand([ (WAV2VEC2_LARGE, wav2vec2_large),
(WAV2VEC2_BASE_960H, wav2vec2_base), (WAV2VEC2_LARGE_LV60K, wav2vec2_large_lv60k),
(WAV2VEC2_LARGE_960H, wav2vec2_large), (WAV2VEC2_XLSR_53_56K, wav2vec2_large_lv60k),
(WAV2VEC2_LARGE_LV60K_960H, wav2vec2_large_lv60k), (HUBERT_BASE, hubert_base),
(WAV2VEC2_LARGE_LV60K_SELF_960H, wav2vec2_large_lv60k), (HUBERT_LARGE_LL60K, hubert_large),
(HUBERT_LARGE, hubert_large), (HUBERT_XLARGE_LL60K, hubert_xlarge),
(HUBERT_XLARGE, hubert_xlarge), ],
], name_func=_name_func) name_func=_name_func,
)
FINETUNING_CONFIGS = parameterized.expand(
@skipIfNoModule('fairseq') [
(WAV2VEC2_BASE_960H, wav2vec2_base),
(WAV2VEC2_LARGE_960H, wav2vec2_large),
(WAV2VEC2_LARGE_LV60K_960H, wav2vec2_large_lv60k),
(WAV2VEC2_LARGE_LV60K_SELF_960H, wav2vec2_large_lv60k),
(HUBERT_LARGE, hubert_large),
(HUBERT_XLARGE, hubert_xlarge),
],
name_func=_name_func,
)
@skipIfNoModule("fairseq")
class TestFairseqIntegration(TorchaudioTestCase): class TestFairseqIntegration(TorchaudioTestCase):
"""Test the process of importing the models from fairseq. """Test the process of importing the models from fairseq.
...@@ -86,9 +97,18 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -86,9 +97,18 @@ class TestFairseqIntegration(TorchaudioTestCase):
1. Models loaded with fairseq cane be imported. 1. Models loaded with fairseq cane be imported.
2. The same model can be recreated without fairseq. 2. The same model can be recreated without fairseq.
""" """
def _get_model(self, config, num_out=None): def _get_model(self, config, num_out=None):
import copy import copy
from omegaconf import OmegaConf
from fairseq.models.hubert.hubert import (
HubertModel,
HubertConfig,
)
from fairseq.models.hubert.hubert_asr import (
HubertCtcConfig,
HubertEncoder,
)
from fairseq.models.wav2vec.wav2vec2 import ( from fairseq.models.wav2vec.wav2vec2 import (
Wav2Vec2Config, Wav2Vec2Config,
Wav2Vec2Model, Wav2Vec2Model,
...@@ -97,32 +117,25 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -97,32 +117,25 @@ class TestFairseqIntegration(TorchaudioTestCase):
Wav2VecEncoder, Wav2VecEncoder,
Wav2Vec2CtcConfig, Wav2Vec2CtcConfig,
) )
from fairseq.models.hubert.hubert_asr import (
HubertCtcConfig,
HubertEncoder,
)
from fairseq.models.hubert.hubert import (
HubertModel,
HubertConfig,
)
from fairseq.tasks.hubert_pretraining import HubertPretrainingConfig from fairseq.tasks.hubert_pretraining import HubertPretrainingConfig
from omegaconf import OmegaConf
if config['_name'] == 'wav2vec_ctc': if config["_name"] == "wav2vec_ctc":
config = copy.deepcopy(config) config = copy.deepcopy(config)
config['w2v_args'] = OmegaConf.create(config['w2v_args']) config["w2v_args"] = OmegaConf.create(config["w2v_args"])
return Wav2VecEncoder(Wav2Vec2CtcConfig(**config), num_out) return Wav2VecEncoder(Wav2Vec2CtcConfig(**config), num_out)
if config['_name'] == 'wav2vec2': if config["_name"] == "wav2vec2":
return Wav2Vec2Model(Wav2Vec2Config(**config)) return Wav2Vec2Model(Wav2Vec2Config(**config))
if config['_name'] == 'hubert_ctc': if config["_name"] == "hubert_ctc":
config = copy.deepcopy(config) config = copy.deepcopy(config)
config['w2v_args'] = OmegaConf.create(config['w2v_args']) config["w2v_args"] = OmegaConf.create(config["w2v_args"])
ctc_cfg = HubertCtcConfig(**config) ctc_cfg = HubertCtcConfig(**config)
return HubertEncoder(ctc_cfg, tgt_dict=range(num_out)) return HubertEncoder(ctc_cfg, tgt_dict=range(num_out))
if config['_name'] == 'hubert': if config["_name"] == "hubert":
dicts = [list(range(i)) for i in config['num_classes']] dicts = [list(range(i)) for i in config["num_classes"]]
return HubertModel( return HubertModel(
HubertConfig(**config['model']), HubertConfig(**config["model"]),
HubertPretrainingConfig(**config['task']), HubertPretrainingConfig(**config["task"]),
dicts, dicts,
) )
raise ValueError(f'Unexpected configuration: {config["_name"]}') raise ValueError(f'Unexpected configuration: {config["_name"]}')
...@@ -139,7 +152,7 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -139,7 +152,7 @@ class TestFairseqIntegration(TorchaudioTestCase):
x = torch.randn(batch_size, num_frames) x = torch.randn(batch_size, num_frames)
hyp, _ = imported.extract_features(x) hyp, _ = imported.extract_features(x)
refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1) refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1)
for i, (ref, _) in enumerate(refs['layer_results']): for i, (ref, _) in enumerate(refs["layer_results"]):
self.assertEqual(hyp[i], ref.transpose(0, 1)) self.assertEqual(hyp[i], ref.transpose(0, 1))
@HUBERT_PRETRAINING_CONFIGS @HUBERT_PRETRAINING_CONFIGS
...@@ -177,7 +190,13 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -177,7 +190,13 @@ class TestFairseqIntegration(TorchaudioTestCase):
reloaded.eval() reloaded.eval()
x = torch.randn(batch_size, num_frames) x = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
# Without mask # Without mask
ref, _ = imported(x) ref, _ = imported(x)
hyp, _ = reloaded(x) hyp, _ = reloaded(x)
...@@ -200,14 +219,20 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -200,14 +219,20 @@ class TestFairseqIntegration(TorchaudioTestCase):
# Without mask # Without mask
x = torch.randn(batch_size, num_frames) x = torch.randn(batch_size, num_frames)
ref = original(x, torch.zeros_like(x))['encoder_out'].transpose(0, 1) ref = original(x, torch.zeros_like(x))["encoder_out"].transpose(0, 1)
hyp, _ = imported(x) hyp, _ = imported(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# With mask # With mask
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
mask = torch.arange(num_frames).expand(batch_size, num_frames) >= lengths[:, None] mask = torch.arange(num_frames).expand(batch_size, num_frames) >= lengths[:, None]
ref = original(x, mask)['encoder_out'].transpose(0, 1) ref = original(x, mask)["encoder_out"].transpose(0, 1)
hyp, output_lengths = imported(x, lengths) hyp, output_lengths = imported(x, lengths)
for i, l in enumerate(output_lengths): for i, l in enumerate(output_lengths):
self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...]) self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...])
...@@ -233,7 +258,13 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -233,7 +258,13 @@ class TestFairseqIntegration(TorchaudioTestCase):
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# With mask # With mask
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
ref, ref_lengths = imported(x, lengths) ref, ref_lengths = imported(x, lengths)
hyp, hyp_lengths = reloaded(x, lengths) hyp, hyp_lengths = reloaded(x, lengths)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
......
import json import json
import torch import torch
from parameterized import parameterized
from torchaudio.models.wav2vec2 import ( from torchaudio.models.wav2vec2 import (
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
wav2vec2_large_lv60k, wav2vec2_large_lv60k,
) )
from torchaudio.models.wav2vec2.utils import import_huggingface_model from torchaudio.models.wav2vec2.utils import import_huggingface_model
from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
get_asset_path, get_asset_path,
skipIfNoModule, skipIfNoModule,
...@@ -17,7 +16,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -17,7 +16,7 @@ from torchaudio_unittest.common_utils import (
def _load_config(*paths): def _load_config(*paths):
with open(f'{get_asset_path("wav2vec2", "huggingface", *paths)}.json', 'r') as file_: with open(f'{get_asset_path("wav2vec2", "huggingface", *paths)}.json', "r") as file_:
return json.load(file_) return json.load(file_)
...@@ -26,36 +25,42 @@ def _name_func(testcase_func, i, param): ...@@ -26,36 +25,42 @@ def _name_func(testcase_func, i, param):
# Pretrained # Pretrained
HF_BASE = _load_config('wav2vec2-base') HF_BASE = _load_config("wav2vec2-base")
HF_LARGE = _load_config('wav2vec2-large') HF_LARGE = _load_config("wav2vec2-large")
HF_LARGE_LV60 = _load_config('wav2vec2-large-lv60') HF_LARGE_LV60 = _load_config("wav2vec2-large-lv60")
HF_LARGE_XLSR_53 = _load_config('wav2vec2-large-xlsr-53') HF_LARGE_XLSR_53 = _load_config("wav2vec2-large-xlsr-53")
HF_BASE_10K_VOXPOPULI = _load_config('wav2vec2-base-10k-voxpopuli') HF_BASE_10K_VOXPOPULI = _load_config("wav2vec2-base-10k-voxpopuli")
# Finetuned # Finetuned
HF_BASE_960H = _load_config('wav2vec2-base-960h') HF_BASE_960H = _load_config("wav2vec2-base-960h")
HF_LARGE_960H = _load_config('wav2vec2-large-960h') HF_LARGE_960H = _load_config("wav2vec2-large-960h")
HF_LARGE_LV60_960H = _load_config('wav2vec2-large-960h-lv60') HF_LARGE_LV60_960H = _load_config("wav2vec2-large-960h-lv60")
HF_LARGE_LV60_SELF_960H = _load_config('wav2vec2-large-960h-lv60-self') HF_LARGE_LV60_SELF_960H = _load_config("wav2vec2-large-960h-lv60-self")
HF_LARGE_XLSR_DE = _load_config('wav2vec2-large-xlsr-53-german') HF_LARGE_XLSR_DE = _load_config("wav2vec2-large-xlsr-53-german")
# Config and corresponding factory functions # Config and corresponding factory functions
PRETRAIN_CONFIGS = parameterized.expand([ PRETRAIN_CONFIGS = parameterized.expand(
(HF_BASE, wav2vec2_base), [
(HF_LARGE, wav2vec2_large), (HF_BASE, wav2vec2_base),
(HF_LARGE_LV60, wav2vec2_large_lv60k), (HF_LARGE, wav2vec2_large),
(HF_LARGE_XLSR_53, wav2vec2_large_lv60k), (HF_LARGE_LV60, wav2vec2_large_lv60k),
(HF_BASE_10K_VOXPOPULI, wav2vec2_base), (HF_LARGE_XLSR_53, wav2vec2_large_lv60k),
], name_func=_name_func) (HF_BASE_10K_VOXPOPULI, wav2vec2_base),
FINETUNE_CONFIGS = parameterized.expand([ ],
(HF_BASE_960H, wav2vec2_base), name_func=_name_func,
(HF_LARGE_960H, wav2vec2_large), )
(HF_LARGE_LV60_960H, wav2vec2_large_lv60k), FINETUNE_CONFIGS = parameterized.expand(
(HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k), [
(HF_LARGE_XLSR_DE, wav2vec2_large_lv60k), (HF_BASE_960H, wav2vec2_base),
], name_func=_name_func) (HF_LARGE_960H, wav2vec2_large),
(HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
(HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k),
@skipIfNoModule('transformers') (HF_LARGE_XLSR_DE, wav2vec2_large_lv60k),
],
name_func=_name_func,
)
@skipIfNoModule("transformers")
class TestHFIntegration(TorchaudioTestCase): class TestHFIntegration(TorchaudioTestCase):
"""Test the process of importing the models from Hugging Face Transformers """Test the process of importing the models from Hugging Face Transformers
...@@ -63,6 +68,7 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -63,6 +68,7 @@ class TestHFIntegration(TorchaudioTestCase):
1. Models loaded with Hugging Face Transformers cane be imported. 1. Models loaded with Hugging Face Transformers cane be imported.
2. The same model can be recreated without Hugging Face Transformers. 2. The same model can be recreated without Hugging Face Transformers.
""" """
def _get_model(self, config): def _get_model(self, config):
# Helper function to avoid importing transformers on module scope. # Helper function to avoid importing transformers on module scope.
# Normally, we use `is_module_available` helper function to check if # Normally, we use `is_module_available` helper function to check if
...@@ -75,9 +81,10 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -75,9 +81,10 @@ class TestHFIntegration(TorchaudioTestCase):
Wav2Vec2Model, Wav2Vec2Model,
Wav2Vec2ForCTC, Wav2Vec2ForCTC,
) )
if config['architectures'] == ['Wav2Vec2Model']:
if config["architectures"] == ["Wav2Vec2Model"]:
return Wav2Vec2Model(Wav2Vec2Config(**config)) return Wav2Vec2Model(Wav2Vec2Config(**config))
if config['architectures'] == ['Wav2Vec2ForCTC']: if config["architectures"] == ["Wav2Vec2ForCTC"]:
return Wav2Vec2ForCTC(Wav2Vec2Config(**config)) return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
raise ValueError(f'Unexpected arch: {config["architectures"]}') raise ValueError(f'Unexpected arch: {config["architectures"]}')
...@@ -89,12 +96,12 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -89,12 +96,12 @@ class TestHFIntegration(TorchaudioTestCase):
hyp, _ = imported.feature_extractor(x, None) hyp, _ = imported.feature_extractor(x, None)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# Feature projection # Feature projection
x = torch.randn(3, 10, config['conv_dim'][-1]) x = torch.randn(3, 10, config["conv_dim"][-1])
ref = original.feature_projection(x)[0] ref = original.feature_projection(x)[0]
hyp = imported.encoder.feature_projection(x) hyp = imported.encoder.feature_projection(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# Convolutional Positional Encoder # Convolutional Positional Encoder
x = torch.randn(3, 256, config['hidden_size']) x = torch.randn(3, 256, config["hidden_size"])
ref = original.encoder.pos_conv_embed(x) ref = original.encoder.pos_conv_embed(x)
hyp = imported.encoder.transformer.pos_conv_embed(x) hyp = imported.encoder.transformer.pos_conv_embed(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
...@@ -104,7 +111,7 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -104,7 +111,7 @@ class TestHFIntegration(TorchaudioTestCase):
x = torch.randn(b, l, e) x = torch.randn(b, l, e)
mask = torch.randn(b, 1, l, l) mask = torch.randn(b, 1, l, l)
ref, = original_(x, attention_mask=mask, output_attentions=False) (ref,) = original_(x, attention_mask=mask, output_attentions=False)
hyp = imported_(x, mask) hyp = imported_(x, mask)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# The whole Encoder Transformer # The whole Encoder Transformer
...@@ -135,7 +142,13 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -135,7 +142,13 @@ class TestHFIntegration(TorchaudioTestCase):
# The whole model with mask # The whole model with mask
batch_size, num_frames = 3, 1024 batch_size, num_frames = 3, 1024
x = torch.randn(batch_size, num_frames) x = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
mask = torch.arange(num_frames).expand(batch_size, num_frames) < lengths[:, None] mask = torch.arange(num_frames).expand(batch_size, num_frames) < lengths[:, None]
ref = original(x, attention_mask=mask).logits ref = original(x, attention_mask=mask).logits
...@@ -167,12 +180,12 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -167,12 +180,12 @@ class TestHFIntegration(TorchaudioTestCase):
hyp, _ = reloaded.feature_extractor(x, None) hyp, _ = reloaded.feature_extractor(x, None)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# Feature projection # Feature projection
x = torch.randn(3, 10, config['conv_dim'][-1]) x = torch.randn(3, 10, config["conv_dim"][-1])
ref = imported.encoder.feature_projection(x) ref = imported.encoder.feature_projection(x)
hyp = reloaded.encoder.feature_projection(x) hyp = reloaded.encoder.feature_projection(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# Convolutional Positional Encoder # Convolutional Positional Encoder
x = torch.randn(3, 256, config['hidden_size']) x = torch.randn(3, 256, config["hidden_size"])
ref = imported.encoder.transformer.pos_conv_embed(x) ref = imported.encoder.transformer.pos_conv_embed(x)
hyp = reloaded.encoder.transformer.pos_conv_embed(x) hyp = reloaded.encoder.transformer.pos_conv_embed(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
......
import os import os
from typing import Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from typing import Tuple from parameterized import parameterized
from torchaudio.models.wav2vec2 import ( from torchaudio.models.wav2vec2 import (
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
...@@ -18,7 +18,6 @@ from torchaudio_unittest.common_utils import ( ...@@ -18,7 +18,6 @@ from torchaudio_unittest.common_utils import (
skipIfNoCuda, skipIfNoCuda,
torch_script, torch_script,
) )
from parameterized import parameterized
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 10): if TORCH_VERSION >= (1, 10):
...@@ -31,14 +30,17 @@ def _name_func(testcase_func, i, param): ...@@ -31,14 +30,17 @@ def _name_func(testcase_func, i, param):
return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}" return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"
factory_funcs = parameterized.expand([ factory_funcs = parameterized.expand(
(wav2vec2_base, ), [
(wav2vec2_large, ), (wav2vec2_base,),
(wav2vec2_large_lv60k, ), (wav2vec2_large,),
(hubert_base, ), (wav2vec2_large_lv60k,),
(hubert_large, ), (hubert_base,),
(hubert_xlarge, ), (hubert_large,),
], name_func=_name_func) (hubert_xlarge,),
],
name_func=_name_func,
)
class TestWav2Vec2Model(TorchaudioTestCase): class TestWav2Vec2Model(TorchaudioTestCase):
...@@ -49,27 +51,32 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -49,27 +51,32 @@ class TestWav2Vec2Model(TorchaudioTestCase):
torch.manual_seed(0) torch.manual_seed(0)
batch_size, num_frames = 3, 1024 batch_size, num_frames = 3, 1024
waveforms = torch.randn( waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
batch_size, num_frames, device=device, dtype=dtype)
lengths = torch.randint( lengths = torch.randint(
low=0, high=num_frames, size=[batch_size, ], device=device) low=0,
high=num_frames,
size=[
batch_size,
],
device=device,
)
model(waveforms, lengths) model(waveforms, lengths)
@parameterized.expand([(torch.float32, ), (torch.float64, )]) @parameterized.expand([(torch.float32,), (torch.float64,)])
def test_cpu_smoke_test(self, dtype): def test_cpu_smoke_test(self, dtype):
model = wav2vec2_base() model = wav2vec2_base()
self._smoke_test(model, torch.device('cpu'), dtype) self._smoke_test(model, torch.device("cpu"), dtype)
model = wav2vec2_base(aux_num_out=32) model = wav2vec2_base(aux_num_out=32)
self._smoke_test(model, torch.device('cpu'), dtype) self._smoke_test(model, torch.device("cpu"), dtype)
@parameterized.expand([(torch.float32, ), (torch.float64, )]) @parameterized.expand([(torch.float32,), (torch.float64,)])
@skipIfNoCuda @skipIfNoCuda
def test_cuda_smoke_test(self, dtype): def test_cuda_smoke_test(self, dtype):
model = wav2vec2_base() model = wav2vec2_base()
self._smoke_test(model, torch.device('cuda'), dtype) self._smoke_test(model, torch.device("cuda"), dtype)
model = wav2vec2_base(aux_num_out=32) model = wav2vec2_base(aux_num_out=32)
self._smoke_test(model, torch.device('cuda'), dtype) self._smoke_test(model, torch.device("cuda"), dtype)
def _feature_extractor_test(self, model): def _feature_extractor_test(self, model):
batch_size, num_frames = 3, 1024 batch_size, num_frames = 3, 1024
...@@ -79,7 +86,13 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -79,7 +86,13 @@ class TestWav2Vec2Model(TorchaudioTestCase):
torch.manual_seed(0) torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames) waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
# Not providing num_layers returns all the intermediate features from # Not providing num_layers returns all the intermediate features from
# tranformer layers # tranformer layers
...@@ -114,8 +127,8 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -114,8 +127,8 @@ class TestWav2Vec2Model(TorchaudioTestCase):
batch_logits, output_lengths = model(waveforms, input_lengths) batch_logits, output_lengths = model(waveforms, input_lengths)
for i in range(batch_size): for i in range(batch_size):
# Par-sample process without feeding length # Par-sample process without feeding length
single_logit, _ = model(waveforms[i:i + 1, :input_lengths[i]], None) single_logit, _ = model(waveforms[i : i + 1, : input_lengths[i]], None)
batch_logit = batch_logits[i:i + 1, :output_lengths[i]] batch_logit = batch_logits[i : i + 1, : output_lengths[i]]
# Convert to probability so that it's easier to interpretate the diff # Convert to probability so that it's easier to interpretate the diff
single_prob = F.softmax(single_logit, dim=2) single_prob = F.softmax(single_logit, dim=2)
...@@ -125,14 +138,12 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -125,14 +138,12 @@ class TestWav2Vec2Model(TorchaudioTestCase):
@factory_funcs @factory_funcs
def test_pretrain_batch_consistency(self, factory_func): def test_pretrain_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close """Results from single process and batched process should be reasonably close"""
"""
self._test_batch_consistency(factory_func()) self._test_batch_consistency(factory_func())
@factory_funcs @factory_funcs
def test_finetune_batch_consistency(self, factory_func): def test_finetune_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close """Results from single process and batched process should be reasonably close"""
"""
self._test_batch_consistency(factory_func(aux_num_out=32)) self._test_batch_consistency(factory_func(aux_num_out=32))
def _test_zero_length(self, model): def _test_zero_length(self, model):
...@@ -163,7 +174,13 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -163,7 +174,13 @@ class TestWav2Vec2Model(TorchaudioTestCase):
torch.manual_seed(0) torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames) waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
ref_out, ref_len = model(waveforms, lengths) ref_out, ref_len = model(waveforms, lengths)
...@@ -177,19 +194,19 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -177,19 +194,19 @@ class TestWav2Vec2Model(TorchaudioTestCase):
@factory_funcs @factory_funcs
def test_pretrain_torchscript(self, factory_func): def test_pretrain_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable""" """Wav2Vec2Model should be scriptable"""
if factory_func is hubert_xlarge and os.name == 'nt' and os.environ.get('CI') == 'true': if factory_func is hubert_xlarge and os.name == "nt" and os.environ.get("CI") == "true":
self.skipTest( self.skipTest(
'hubert_xlarge is known to fail on Windows CI. ' "hubert_xlarge is known to fail on Windows CI. " "See https://github.com/pytorch/pytorch/issues/65776"
'See https://github.com/pytorch/pytorch/issues/65776') )
self._test_torchscript(factory_func()) self._test_torchscript(factory_func())
@factory_funcs @factory_funcs
def test_finetune_torchscript(self, factory_func): def test_finetune_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable""" """Wav2Vec2Model should be scriptable"""
if factory_func is hubert_xlarge and os.name == 'nt' and os.environ.get('CI') == 'true': if factory_func is hubert_xlarge and os.name == "nt" and os.environ.get("CI") == "true":
self.skipTest( self.skipTest(
'hubert_xlarge is known to fail on Windows CI. ' "hubert_xlarge is known to fail on Windows CI. " "See https://github.com/pytorch/pytorch/issues/65776"
'See https://github.com/pytorch/pytorch/issues/65776') )
self._test_torchscript(factory_func(aux_num_out=32)) self._test_torchscript(factory_func(aux_num_out=32))
def _test_quantize_smoke_test(self, model): def _test_quantize_smoke_test(self, model):
...@@ -198,15 +215,20 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -198,15 +215,20 @@ class TestWav2Vec2Model(TorchaudioTestCase):
# Remove the weight normalization forward hook # Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__() model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = tq.quantize_dynamic( quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different # A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module." assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
torch.manual_seed(0) torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames) waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
_, _ = quantized(waveforms, lengths) _, _ = quantized(waveforms, lengths)
...@@ -223,15 +245,20 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -223,15 +245,20 @@ class TestWav2Vec2Model(TorchaudioTestCase):
# Remove the weight normalization forward hook # Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__() model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = tq.quantize_dynamic( quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different # A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module." assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
torch.manual_seed(0) torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames) waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
ref_out, ref_len = quantized(waveforms, lengths) ref_out, ref_len = quantized(waveforms, lengths)
......
import torch import torch
from torchaudio_unittest.prototype.conformer_test_impl import ConformerTestImpl
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.conformer_test_impl import ConformerTestImpl
class ConformerFloat32CPUTest(ConformerTestImpl, PytorchTestCase): class ConformerFloat32CPUTest(ConformerTestImpl, PytorchTestCase):
......
import torch import torch
from torchaudio_unittest.prototype.conformer_test_impl import ConformerTestImpl
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from torchaudio_unittest.prototype.conformer_test_impl import ConformerTestImpl
@skipIfNoCuda @skipIfNoCuda
......
import torch import torch
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
from torchaudio.prototype import Conformer from torchaudio.prototype import Conformer
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class ConformerTestImpl(TestBaseMixin): class ConformerTestImpl(TestBaseMixin):
...@@ -24,12 +24,8 @@ class ConformerTestImpl(TestBaseMixin): ...@@ -24,12 +24,8 @@ class ConformerTestImpl(TestBaseMixin):
return conformer return conformer
def _gen_inputs(self, input_dim, batch_size, num_frames): def _gen_inputs(self, input_dim, batch_size, num_frames):
lengths = torch.randint(1, num_frames, (batch_size,)).to( lengths = torch.randint(1, num_frames, (batch_size,)).to(device=self.device, dtype=self.dtype)
device=self.device, dtype=self.dtype input = torch.rand(batch_size, int(lengths.max()), input_dim).to(device=self.device, dtype=self.dtype)
)
input = torch.rand(batch_size, int(lengths.max()), input_dim).to(
device=self.device, dtype=self.dtype
)
return input, lengths return input, lengths
def setUp(self): def setUp(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment