Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
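The changes below are mechanical. For reference, this is the kind of rewrite the linter applies (an illustrative sketch assuming a black-style formatter; the exact toolchain is whatever arc lint is configured to run for this repository):

import torch

# String literals are normalized to double quotes: torch.device('cuda') -> torch.device("cuda")
device = torch.device("cuda")

# Bare float literals gain an explicit zero: 80. -> 80.0
enhancement_amount = 80.0

# Call sites that exceed the line-length limit are exploded one argument
# per line, with a trailing comma after the last argument.
noise = torch.randn(
    40,
    10,
    dtype=torch.float64,
)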
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .kaldi_compatibility_test_impl import Kaldi
@skipIfNoCuda
class TestKaldiFloat32(Kaldi, PytorchTestCase):
dtype = torch.float32
    device = torch.device("cuda")
@skipIfNoCuda
class TestKaldiFloat64(Kaldi, PytorchTestCase):
dtype = torch.float64
    device = torch.device("cuda")
import torch
import torchaudio.functional as F
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
get_sinusoid,
load_params,
@@ -21,20 +20,20 @@ class Kaldi(TempDirMixin, TestBaseMixin):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)
    @skipIfNoExec("apply-cmvn-sliding")
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
        kwargs = {
            "cmn_window": 600,
            "min_cmn_window": 100,
            "center": False,
            "norm_vars": False,
        }
tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs)
        command = ["apply-cmvn-sliding"] + convert_args(**kwargs) + ["ark:-", "ark:-"]
        kaldi_result = run_kaldi(command, "ark", tensor)
self.assert_equal(result, expected=kaldi_result)
@@ -43,18 +42,18 @@ class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)
    @parameterized.expand(load_params("kaldi_test_pitch_args.jsonl"))
    @skipIfNoExec("compute-kaldi-pitch-feats")
def test_pitch_feats(self, kwargs):
"""compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
        sample_rate = kwargs["sample_rate"]
        waveform = get_sinusoid(dtype="float32", sample_rate=sample_rate)
result = F.compute_kaldi_pitch(waveform[0], **kwargs)
        waveform = get_sinusoid(dtype="int16", sample_rate=sample_rate)
        wave_file = self.get_temp_path("test.wav")
save_wav(wave_file, waveform, sample_rate)
        command = ["compute-kaldi-pitch-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
        kaldi_result = run_kaldi(command, "scp", wave_file)
self.assert_equal(result, expected=kaldi_result)
from torchaudio_unittest.common_utils import PytorchTestCase
from .librosa_compatibility_test_impl import Functional, FunctionalComplex
class TestFunctionalCPU(Functional, PytorchTestCase):
    device = "cpu"
class TestFunctionalComplexCPU(FunctionalComplex, PytorchTestCase):
    device = "cpu"
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .librosa_compatibility_test_impl import Functional, FunctionalComplex
@skipIfNoCuda
class TestFunctionalCUDA(Functional, PytorchTestCase):
    device = "cuda"
@skipIfNoCuda
class TestFunctionalComplexCUDA(FunctionalComplex, PytorchTestCase):
    device = "cuda"
@@ -2,16 +2,15 @@ import unittest
from distutils.version import StrictVersion
import torch
import torchaudio.functional as F
from parameterized import param
from torchaudio._internal.module_utils import is_module_available
LIBROSA_AVAILABLE = is_module_available("librosa")
if LIBROSA_AVAILABLE:
import librosa
import numpy as np
from torchaudio_unittest.common_utils import (
@@ -25,6 +24,7 @@ from torchaudio_unittest.common_utils import (
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class Functional(TestBaseMixin):
"""Test suite for functions in `functional` module."""
dtype = torch.float64
@nested_params([0, 0.99])
@@ -40,8 +40,8 @@ class Functional(TestBaseMixin):
waveform = get_whitenoise(device=self.device, dtype=self.dtype)
        specgram = get_spectrogram(
            waveform, n_fft=n_fft, hop_length=hop_length, power=power, win_length=win_length, window=window
        )
result = F.griffinlim(
specgram,
@@ -53,14 +53,16 @@ class Functional(TestBaseMixin):
n_iter=n_iter,
momentum=momentum,
length=waveform.size(1),
            rand_init=False,
        )
expected = librosa.griffinlim(
specgram[0].cpu().numpy(),
n_iter=n_iter,
hop_length=hop_length,
momentum=momentum,
init=None,
            length=waveform.size(1),
        )[None, ...]
self.assertEqual(result, torch.from_numpy(expected), atol=5e-5, rtol=1e-07)
@nested_params(
@@ -73,24 +75,20 @@ class Functional(TestBaseMixin):
param(n_mels=56, fmin=1900.0, fmax=900.0),
param(n_mels=10, fmin=1900.0, fmax=900.0),
],
        [param(norm=n) for n in [None, "slaney"]],
        [param(mel_scale=s) for s in ["htk", "slaney"]],
)
    def test_create_mel_fb(
        self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"
    ):
        if norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2"):
            self.skipTest("Test is known to fail with older versions of librosa.")
        if self.device != "cpu":
            self.skipTest("No need to run this test on CUDA")
        expected = librosa.filters.mel(
            sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmax=fmax, fmin=fmin, htk=mel_scale == "htk", norm=norm
        ).T
result = F.melscale_fbanks(
sample_rate=sample_rate,
n_mels=n_mels,
@@ -98,7 +96,8 @@ class Functional(TestBaseMixin):
f_min=fmin,
n_freqs=(n_fft // 2 + 1),
norm=norm,
            mel_scale=mel_scale,
        )
self.assertEqual(result, torch.from_numpy(expected), atol=7e-5, rtol=1.3e-6)
def test_amplitude_to_DB_power(self):
@@ -137,18 +136,12 @@ class FunctionalComplex(TestBaseMixin):
        # result in bottom right values of the stretched spectrogram to not
# match with librosa.
spec = torch.randn(num_freq, num_frames, device=self.device, dtype=torch.complex128)
        phase_advance = torch.linspace(0, np.pi * hop_length, num_freq, device=self.device, dtype=torch.float64)[
            ..., None
        ]
stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
        expected_stretched = librosa.phase_vocoder(spec.cpu().numpy(), rate=rate, hop_length=hop_length)
self.assertEqual(stretched, torch.from_numpy(expected_stretched))
import torch
import torchaudio.functional as F
from torchaudio_unittest.common_utils import (
skipIfNoSox,
skipIfNoExec,
@@ -15,10 +14,10 @@ from torchaudio_unittest.common_utils import (
@skipIfNoSox
@skipIfNoExec("sox")
class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
    def run_sox_effect(self, input_file, effect):
        output_file = self.get_temp_path("expected.wav")
sox_utils.run_sox_effect(input_file, output_file, [str(e) for e in effect])
return load_wav(output_file)
@@ -28,29 +27,31 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
def get_whitenoise(self, sample_rate=8000):
noise = get_whitenoise(
            sample_rate=sample_rate,
            duration=3,
            scale_factor=0.9,
        )
path = self.get_temp_path("whitenoise.wav")
save_wav(path, noise, sample_rate)
return noise, path
def test_gain(self):
        path = get_asset_path("steam-train-whistle-daniel_simon.wav")
        data, _ = load_wav(path)
        result = F.gain(data, 3)
        self.assert_sox_effect(result, path, ["gain", 3])
def test_dither(self):
        path = get_asset_path("steam-train-whistle-daniel_simon.wav")
        data, _ = load_wav(path)
        result = F.dither(data)
        self.assert_sox_effect(result, path, ["dither"])
def test_dither_noise(self):
        path = get_asset_path("steam-train-whistle-daniel_simon.wav")
        data, _ = load_wav(path)
        result = F.dither(data, noise_shaping=True)
        self.assert_sox_effect(result, path, ["dither", "-s"], atol=1.5e-4)
def test_lowpass(self):
cutoff_freq = 3000
@@ -58,7 +59,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.lowpass_biquad(data, sample_rate, cutoff_freq)
        self.assert_sox_effect(result, path, ["lowpass", cutoff_freq], atol=1.5e-4)
def test_highpass(self):
cutoff_freq = 2000
@@ -66,7 +67,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.highpass_biquad(data, sample_rate, cutoff_freq)
        self.assert_sox_effect(result, path, ["highpass", cutoff_freq], atol=1.5e-4)
def test_allpass(self):
central_freq = 1000
@@ -75,7 +76,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.allpass_biquad(data, sample_rate, central_freq, q)
        self.assert_sox_effect(result, path, ["allpass", central_freq, f"{q}q"])
def test_bandpass_with_csg(self):
central_freq = 1000
@@ -85,7 +86,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.bandpass_biquad(data, sample_rate, central_freq, q, const_skirt_gain)
        self.assert_sox_effect(result, path, ["bandpass", "-c", central_freq, f"{q}q"])
def test_bandpass_without_csg(self):
central_freq = 1000
@@ -95,7 +96,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.bandpass_biquad(data, sample_rate, central_freq, q, const_skirt_gain)
        self.assert_sox_effect(result, path, ["bandpass", central_freq, f"{q}q"])
def test_bandreject(self):
central_freq = 1000
@@ -104,7 +105,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.bandreject_biquad(data, sample_rate, central_freq, q)
        self.assert_sox_effect(result, path, ["bandreject", central_freq, f"{q}q"])
def test_band_with_noise(self):
central_freq = 1000
@@ -114,7 +115,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.band_biquad(data, sample_rate, central_freq, q, noise)
        self.assert_sox_effect(result, path, ["band", "-n", central_freq, f"{q}q"])
def test_band_without_noise(self):
central_freq = 1000
@@ -124,7 +125,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.band_biquad(data, sample_rate, central_freq, q, noise)
        self.assert_sox_effect(result, path, ["band", central_freq, f"{q}q"])
def test_treble(self):
central_freq = 1000
@@ -134,7 +135,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.treble_biquad(data, sample_rate, gain, central_freq, q)
        self.assert_sox_effect(result, path, ["treble", gain, central_freq, f"{q}q"])
def test_bass(self):
central_freq = 1000
@@ -144,26 +145,26 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.bass_biquad(data, sample_rate, gain, central_freq, q)
        self.assert_sox_effect(result, path, ["bass", gain, central_freq, f"{q}q"], atol=1.5e-4)
def test_deemph(self):
sample_rate = 44100
data, path = self.get_whitenoise(sample_rate)
result = F.deemph_biquad(data, sample_rate)
        self.assert_sox_effect(result, path, ["deemph"])
def test_riaa(self):
sample_rate = 44100
data, path = self.get_whitenoise(sample_rate)
result = F.riaa_biquad(data, sample_rate)
        self.assert_sox_effect(result, path, ["riaa"])
def test_contrast(self):
        enhancement_amount = 80.0
data, path = self.get_whitenoise()
result = F.contrast(data, enhancement_amount)
        self.assert_sox_effect(result, path, ["contrast", enhancement_amount])
def test_dcshift_with_limiter(self):
shift = 0.5
@@ -171,14 +172,14 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise()
result = F.dcshift(data, shift, limiter_gain)
        self.assert_sox_effect(result, path, ["dcshift", shift, limiter_gain])
def test_dcshift_without_limiter(self):
shift = 0.6
data, path = self.get_whitenoise()
result = F.dcshift(data, shift)
        self.assert_sox_effect(result, path, ["dcshift", shift])
def test_overdrive(self):
gain = 30
@@ -186,7 +187,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise()
result = F.overdrive(data, gain, colour)
        self.assert_sox_effect(result, path, ["overdrive", gain, colour])
def test_phaser_sine(self):
gain_in = 0.5
@@ -198,7 +199,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.phaser(data, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=True)
        self.assert_sox_effect(result, path, ["phaser", gain_in, gain_out, delay_ms, decay, speed, "-s"])
def test_phaser_triangle(self):
gain_in = 0.5
@@ -210,7 +211,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.phaser(data, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=False)
        self.assert_sox_effect(result, path, ["phaser", gain_in, gain_out, delay_ms, decay, speed, "-t"])
def test_flanger_triangle_linear(self):
delay = 0.6
@@ -223,10 +224,11 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.flanger(
            data, sample_rate, delay, depth, regen, width, speed, phase, modulation="triangular", interpolation="linear"
        )
        self.assert_sox_effect(
            result, path, ["flanger", delay, depth, regen, width, speed, "triangle", phase, "linear"]
        )
def test_flanger_triangle_quad(self):
delay = 0.8
@@ -239,10 +241,20 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.flanger(
            data,
            sample_rate,
            delay,
            depth,
            regen,
            width,
            speed,
            phase,
            modulation="triangular",
            interpolation="quadratic",
        )
self.assert_sox_effect(
            result, path, ["flanger", delay, depth, regen, width, speed, "triangle", phase, "quadratic"]
        )
def test_flanger_sine_linear(self):
delay = 0.8
@@ -255,10 +267,9 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.flanger(
            data, sample_rate, delay, depth, regen, width, speed, phase, modulation="sinusoidal", interpolation="linear"
        )
        self.assert_sox_effect(result, path, ["flanger", delay, depth, regen, width, speed, "sine", phase, "linear"])
def test_flanger_sine_quad(self):
delay = 0.9
@@ -271,10 +282,18 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.flanger(
            data,
            sample_rate,
            delay,
            depth,
            regen,
            width,
            speed,
            phase,
            modulation="sinusoidal",
            interpolation="quadratic",
        )
        self.assert_sox_effect(result, path, ["flanger", delay, depth, regen, width, speed, "sine", phase, "quadratic"])
def test_equalizer(self):
center_freq = 300
@@ -284,7 +303,7 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise(sample_rate)
result = F.equalizer_biquad(data, sample_rate, center_freq, gain, q)
        self.assert_sox_effect(result, path, ["equalizer", center_freq, q, gain])
def test_perf_biquad_filtering(self):
b0 = 0.4
@@ -296,4 +315,4 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, path = self.get_whitenoise()
result = F.lfilter(data, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2]))
        self.assert_sox_effect(result, path, ["biquad", b0, b1, b2, a0, a1, a2])
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Functional, FunctionalFloat32Only
class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase):
dtype = torch.float32
    device = torch.device("cpu")
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Functional, FunctionalFloat32Only
@skipIfNoCuda
class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase):
dtype = torch.float32
    device = torch.device("cuda")
@skipIfNoCuda
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
    device = torch.device("cuda")
@@ -3,7 +3,6 @@ import unittest
import torch
import torchaudio.functional as F
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
TempDirMixin,
@@ -15,6 +14,7 @@ from torchaudio_unittest.common_utils import (
class Functional(TempDirMixin, TestBaseMixin):
"""Implements test for `functional` module that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype)
ts_func = torch_script(func)
@@ -79,7 +79,7 @@ class Functional(TempDirMixin, TestBaseMixin):
ws = 400
hop = 200
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
        power = 2.0
momentum = 0.99
n_iter = 32
length = 1000
@@ -110,8 +110,8 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(func, waveform)
def test_melscale_fbanks(self):
        if self.device != torch.device("cpu"):
            raise unittest.SkipTest("No need to perform test on device other than CPU")
def func(_):
n_stft = 100
@@ -126,8 +126,8 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(func, dummy)
def test_linear_fbanks(self):
        if self.device != torch.device("cpu"):
            raise unittest.SkipTest("No need to perform test on device other than CPU")
def func(_):
n_stft = 100
@@ -153,16 +153,16 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_DB_to_amplitude(self):
def func(tensor):
            ref = 1.0
            power = 1.0
return F.DB_to_amplitude(tensor, ref, power)
tensor = torch.rand((1, 100))
self._assert_consistency(func, tensor)
def test_create_dct(self):
        if self.device != torch.device("cpu"):
            raise unittest.SkipTest("No need to perform test on device other than CPU")
def func(_):
n_mfcc = 40
@@ -192,7 +192,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_mask_along_axis(self):
def func(tensor):
mask_param = 100
            mask_value = 30.0
axis = 2
return F.mask_along_axis(tensor, mask_param, mask_value, axis)
@@ -202,7 +202,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_mask_along_axis_iid(self):
def func(tensor):
mask_param = 100
            mask_value = 30.0
axis = 2
return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis)
@@ -219,21 +219,21 @@ class Functional(TempDirMixin, TestBaseMixin):
def test_dither_TPDF(self):
def func(tensor):
            return F.dither(tensor, "TPDF")
tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True)
def test_dither_RPDF(self):
def func(tensor):
            return F.dither(tensor, "RPDF")
tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True)
def test_dither_GPDF(self):
def func(tensor):
            return F.dither(tensor, "GPDF")
tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True)
@@ -306,7 +306,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            cutoff_freq = 3000.0
return F.lowpass_biquad(tensor, sample_rate, cutoff_freq)
self._assert_consistency(func, waveform)
@@ -319,7 +319,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            cutoff_freq = 2000.0
return F.highpass_biquad(tensor, sample_rate, cutoff_freq)
self._assert_consistency(func, waveform)
@@ -332,7 +332,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            central_freq = 1000.0
q = 0.707
return F.allpass_biquad(tensor, sample_rate, central_freq, q)
@@ -346,7 +346,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            central_freq = 1000.0
q = 0.707
const_skirt_gain = True
return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)
@@ -361,7 +361,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            central_freq = 1000.0
q = 0.707
const_skirt_gain = True
return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)
@@ -376,7 +376,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            central_freq = 1000.0
q = 0.707
return F.bandreject_biquad(tensor, sample_rate, central_freq, q)
@@ -390,7 +390,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            central_freq = 1000.0
q = 0.707
noise = True
return F.band_biquad(tensor, sample_rate, central_freq, q, noise)
@@ -405,7 +405,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            central_freq = 1000.0
q = 0.707
noise = False
return F.band_biquad(tensor, sample_rate, central_freq, q, noise)
@@ -420,8 +420,8 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            gain = 40.0
            central_freq = 1000.0
q = 0.707
return F.treble_biquad(tensor, sample_rate, gain, central_freq, q)
@@ -435,8 +435,8 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            gain = 40.0
            central_freq = 1000.0
q = 0.707
return F.bass_biquad(tensor, sample_rate, gain, central_freq, q)
@@ -474,8 +474,8 @@ class Functional(TempDirMixin, TestBaseMixin):
def func(tensor):
sample_rate = 44100
            center_freq = 300.0
            gain = 1.0
q = 0.707
return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q)
@@ -501,39 +501,20 @@ class Functional(TempDirMixin, TestBaseMixin):
center = False
norm_vars = False
            a = torch.tensor(
                [[-1.915875792503357, 1.147700309753418], [1.8242558240890503, 1.3869990110397339]],
                device=tensor.device,
                dtype=tensor.dtype,
            )
return F.sliding_window_cmn(a, cmn_window, min_cmn_window, center, norm_vars)
        b = torch.tensor([[-1.8701, -0.1196], [1.8701, 0.1196]])
self._assert_consistency(func, b)
def test_contrast(self):
waveform = common_utils.get_whitenoise()
def func(tensor):
            enhancement_amount = 80.0
return F.contrast(tensor, enhancement_amount)
self._assert_consistency(func, waveform)
@@ -552,8 +533,8 @@ class Functional(TempDirMixin, TestBaseMixin):
waveform = common_utils.get_whitenoise()
def func(tensor):
            gain = 30.0
            colour = 50.0
return F.overdrive(tensor, gain, colour)
self._assert_consistency(func, waveform)
@@ -582,15 +563,24 @@ class Functional(TempDirMixin, TestBaseMixin):
regen = 3.0
width = 0.23
speed = 1.3
            phase = 60.0
sample_rate = 44100
            return F.flanger(
                tensor,
                sample_rate,
                delay,
                depth,
                regen,
                width,
                speed,
                phase,
                modulation="sinusoidal",
                interpolation="linear",
            )
self._assert_consistency(func, waveform)
def test_spectral_centroid(self):
def func(tensor):
sample_rate = 44100
n_fft = 400
@@ -605,11 +595,11 @@
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
        if self.dtype != torch.float32 or self.device != torch.device("cpu"):
raise unittest.SkipTest("Only float32, cpu is supported.")
def func(tensor):
            sample_rate: float = 44100.0
return F.compute_kaldi_pitch(tensor, sample_rate)
tensor = common_utils.get_whitenoise(sample_rate=44100)
@@ -630,7 +620,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def func_beta(tensor):
sr1, sr2 = 16000, 8000
            beta = 6.0
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta)
tensor = common_utils.get_whitenoise(sample_rate=16000)
@@ -663,11 +653,13 @@ class FunctionalFloat32Only(TestBaseMixin):
target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
return F.rnnt_loss(tensor, targets, logit_lengths, target_lengths)
        logits = torch.tensor(
            [
                [
                    [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]],
                    [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]],
                ]
            ]
        )
tensor = logits.to(device=self.device, dtype=torch.float32)
self._assert_consistency(func, tensor)
import torch
import torchaudio.kaldi_io as kio
from torchaudio_unittest import common_utils
@@ -9,13 +8,14 @@ class Test_KaldiIO(common_utils.TorchaudioTestCase):
data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]]
def _test_helper(self, file_name, expected_data, fn, expected_dtype):
""" Takes a file_name to the input data and a function fn to extract the
"""Takes a file_name to the input data and a function fn to extract the
data. It compares the extracted data to the expected_data. The expected_dtype
will be used to check that the extracted data is of the right type.
"""
test_filepath = common_utils.get_asset_path(file_name)
        expected_output = {
            "key" + str(idx + 1): torch.tensor(val, dtype=expected_dtype) for idx, val in enumerate(expected_data)
        }
for key, vec in fn(test_filepath):
self.assertTrue(key in expected_output)
@@ -10,7 +10,6 @@ from torchaudio_unittest.common_utils import torch_script
class TestWav2Letter(common_utils.TorchaudioTestCase):
def test_waveform(self):
batch_size = 2
num_features = 1
@@ -39,10 +38,8 @@ class TestWav2Letter(common_utils.TorchaudioTestCase):
class TestMelResNet(common_utils.TorchaudioTestCase):
def test_waveform(self):
"""Validate the output dimensions of a MelResNet block.
"""
"""Validate the output dimensions of a MelResNet block."""
n_batch = 2
n_time = 200
@@ -61,10 +58,8 @@ class TestMelResNet(common_utils.TorchaudioTestCase):
class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
def test_waveform(self):
"""Validate the output dimensions of a UpsampleNetwork block.
"""
"""Validate the output dimensions of a UpsampleNetwork block."""
upsample_scales = [5, 5, 8]
n_batch = 2
@@ -79,12 +74,7 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
for upsample_scale in upsample_scales:
total_scale *= upsample_scale
        model = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
x = torch.rand(n_batch, n_freq, n_time)
out1, out2 = model(x)
@@ -94,10 +84,8 @@ class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
class TestWaveRNN(common_utils.TorchaudioTestCase):
def test_waveform(self):
"""Validate the output dimensions of a WaveRNN model.
"""
"""Validate the output dimensions of a WaveRNN model."""
upsample_scales = [5, 5, 8]
n_rnn = 512
@@ -112,8 +100,9 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
n_hidden = 128
kernel_size = 5
        model = WaveRNN(
            upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output
        )
x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
mels = torch.rand(n_batch, 1, n_freq, n_time)
@@ -122,8 +111,7 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)
def test_infer_waveform(self):
"""Validate the output dimensions of a WaveRNN model's infer method.
"""
"""Validate the output dimensions of a WaveRNN model's infer method."""
upsample_scales = [5, 5, 8]
n_rnn = 128
@@ -138,8 +126,9 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
n_hidden = 32
kernel_size = 5
        model = WaveRNN(
            upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output
        )
x = torch.rand(n_batch, n_freq, n_time)
lengths = torch.tensor([n_time, n_time // 2])
@@ -165,8 +154,9 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
n_hidden = 32
kernel_size = 5
        model = WaveRNN(
            upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output
        )
model.eval()
x = torch.rand(n_batch, n_freq, n_time)
torch.random.manual_seed(0)
@@ -177,39 +167,43 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
_ConvTasNetParams = namedtuple(
    "_ConvTasNetParams",
[
        "enc_num_feats",
        "enc_kernel_size",
        "msk_num_feats",
        "msk_num_hidden_feats",
        "msk_kernel_size",
        "msk_num_layers",
        "msk_num_stacks",
    ],
)
class TestConvTasNet(common_utils.TorchaudioTestCase):
@parameterized.expand(
list(
itertools.product(
[2, 3],
[
_ConvTasNetParams(128, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(256, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 256, 256, 3, 7, 2),
_ConvTasNetParams(512, 40, 256, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 256, 512, 3, 7, 2),
_ConvTasNetParams(512, 40, 128, 512, 3, 6, 4),
_ConvTasNetParams(512, 40, 128, 512, 3, 4, 6),
_ConvTasNetParams(512, 40, 128, 512, 3, 8, 3),
_ConvTasNetParams(512, 32, 128, 512, 3, 8, 3),
_ConvTasNetParams(512, 16, 128, 512, 3, 8, 3),
],
)
)
)
def test_paper_configuration(self, num_sources, model_params):
"""ConvTasNet model works on the valid configurations in the paper"""
batch_size = 32
@@ -232,7 +226,6 @@ class TestConvTasNet(common_utils.TorchaudioTestCase):
class TestDeepSpeech(common_utils.TorchaudioTestCase):
def test_deepspeech(self):
n_batch = 2
n_feature = 1
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .model_test_impl import (
Tacotron2EncoderTests,
Tacotron2DecoderTests,
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .model_test_impl import (
Tacotron2EncoderTests,
Tacotron2DecoderTests,
from typing import Tuple
import torch
from torch import Tensor
from torchaudio.models import Tacotron2
@@ -7,7 +8,6 @@ from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class Tacotron2InferenceWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
@@ -17,7 +17,6 @@ class Tacotron2InferenceWrapper(torch.nn.Module):
class Tacotron2DecoderInferenceWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
@@ -42,21 +41,18 @@ class TorchscriptConsistencyMixin(TestBaseMixin):
class Tacotron2EncoderTests(TorchscriptConsistencyMixin):
def test_tacotron2_torchscript_consistency(self):
r"""Validate the torchscript consistency of a Encoder."""
n_batch, n_seq, encoder_embedding_dim = 16, 64, 512
        model = (
            _Encoder(encoder_embedding_dim=encoder_embedding_dim, encoder_n_convolution=3, encoder_kernel_size=5)
            .to(self.device)
            .eval()
        )
        x = torch.rand(n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype)
        input_lengths = torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq
self._assert_torchscript_consistency(model, (x, input_lengths))
def test_encoder_output_shape(self):
@@ -64,23 +60,20 @@ class Tacotron2EncoderTests(TorchscriptConsistencyMixin):
that it outputs with a tensor with expected shape.
"""
n_batch, n_seq, encoder_embedding_dim = 16, 64, 512
        model = (
            _Encoder(encoder_embedding_dim=encoder_embedding_dim, encoder_n_convolution=3, encoder_kernel_size=5)
            .to(self.device)
            .eval()
        )
        x = torch.rand(n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype)
        input_lengths = torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq
out = model(x, input_lengths)
assert out.size() == (n_batch, n_seq, encoder_embedding_dim)
def _get_decoder_model(n_mels=80, encoder_embedding_dim=512, decoder_max_step=2000, gate_threshold=0.5):
model = _Decoder(
n_mels=n_mels,
n_frames_per_step=1,
@@ -101,7 +94,6 @@ def _get_decoder_model(n_mels=80, encoder_embedding_dim=512,
class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
def test_decoder_torchscript_consistency(self):
r"""Validate the torchscript consistency of a Decoder."""
n_batch = 16
......@@ -113,17 +105,11 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim)
model = model.to(self.device).eval()
        memory = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device)
        decoder_inputs = torch.rand(n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
        self._assert_torchscript_consistency(model, (memory, decoder_inputs, memory_lengths))
def test_decoder_output_shape(self):
r"""Feed tensors with specific shape to Tacotron2 Decoder and validate
@@ -138,17 +124,11 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim)
model = model.to(self.device).eval()
        memory = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device)
        decoder_inputs = torch.rand(n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
        mel_specgram, gate_outputs, alignments = model(memory, decoder_inputs, memory_lengths)
assert mel_specgram.size() == (n_batch, n_mels, n_time_steps)
assert gate_outputs.size() == (n_batch, n_time_steps)
@@ -171,9 +151,7 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
)
model = model.to(self.device).eval()
        memory = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
model_wrapper = Tacotron2DecoderInferenceWrapper(model)
@@ -197,17 +175,16 @@ class Tacotron2DecoderTests(TorchscriptConsistencyMixin):
)
model = model.to(self.device).eval()
        memory = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
        mel_specgram, mel_specgram_lengths, gate_outputs, alignments = model.infer(memory, memory_lengths)
assert len(mel_specgram.size()) == 3
        assert mel_specgram.size()[:-1] == (
            n_batch,
            n_mels,
        )
assert mel_specgram.size()[2] == mel_specgram_lengths.max().item()
assert len(mel_specgram_lengths.size()) == 1
assert mel_specgram_lengths.size()[0] == n_batch
@@ -248,16 +225,9 @@ def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
class Tacotron2Tests(TorchscriptConsistencyMixin):
    def _get_inputs(self, n_mels: int, n_batch: int, max_mel_specgram_length: int, max_text_length: int):
        text = torch.randint(0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device)
        text_lengths = max_text_length * torch.ones((n_batch,), dtype=torch.int32, device=self.device)
mel_specgram = torch.rand(
n_batch,
n_mels,
@@ -265,9 +235,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
dtype=self.dtype,
device=self.device,
)
        mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,), dtype=torch.int32, device=self.device)
return text, text_lengths, mel_specgram, mel_specgram_lengths
def test_tacotron2_torchscript_consistency(self):
@@ -278,9 +246,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
max_text_length = 100
model = _get_tacotron2_model(n_mels).to(self.device).eval()
        inputs = self._get_inputs(n_mels, n_batch, max_mel_specgram_length, max_text_length)
self._assert_torchscript_consistency(model, inputs)
@@ -294,9 +260,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
max_text_length = 100
model = _get_tacotron2_model(n_mels).to(self.device).eval()
        inputs = self._get_inputs(n_mels, n_batch, max_mel_specgram_length, max_text_length)
mel_out, mel_out_postnet, gate_outputs, alignments = model(*inputs)
assert mel_out.size() == (n_batch, n_mels, max_mel_specgram_length)
@@ -315,9 +279,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
max_text_length = 100
model = _get_tacotron2_model(n_mels).to(self.device)
        inputs = self._get_inputs(n_mels, n_batch, max_mel_specgram_length, max_text_length)
mel_out, mel_out_postnet, gate_outputs, _ = model(*inputs)
mel_out.sum().backward(retain_graph=True)
@@ -325,12 +287,8 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
gate_outputs.sum().backward()
def _get_inference_inputs(self, n_batch: int, max_text_length: int):
        text = torch.randint(0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device)
        text_lengths = max_text_length * torch.ones((n_batch,), dtype=torch.int32, device=self.device)
return text, text_lengths
def test_tacotron2_inference_torchscript_consistency(self):
@@ -341,9 +299,11 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
decoder_max_step = 200 # make inference more efficient
gate_threshold = 0.51 # if set to 0.5, the model will only run one step
        model = (
            _get_tacotron2_model(n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold)
            .to(self.device)
            .eval()
        )
inputs = self._get_inference_inputs(n_batch, max_text_length)
model_wrapper = Tacotron2InferenceWrapper(model)
@@ -360,9 +320,11 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
decoder_max_step = 200 # make inference more efficient
gate_threshold = 0.51 # if set to 0.5, the model will only run one step
        model = (
            _get_tacotron2_model(n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold)
            .to(self.device)
            .eval()
        )
inputs = self._get_inference_inputs(n_batch, max_text_length)
mel_out, mel_specgram_lengths, alignments = model.infer(*inputs)
@@ -370,7 +332,10 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
# There is no guarantee on exactly what max_mel_specgram_length should be
# We only know that it should be smaller than model.decoder.decoder_max_step
assert len(mel_out.size()) == 3
        assert mel_out.size()[:2] == (
            n_batch,
            n_mels,
        )
assert mel_out.size()[2] == mel_specgram_lengths.max().item()
assert len(mel_specgram_lengths.size()) == 1
assert mel_specgram_lengths.size()[0] == n_batch
import json
import torch
from parameterized import parameterized
from torchaudio.models.wav2vec2 import (
wav2vec2_base,
wav2vec2_large,
@@ -12,8 +13,6 @@ from torchaudio.models.wav2vec2 import (
from torchaudio.models.wav2vec2.utils import (
import_fairseq_model,
)
from torchaudio_unittest.common_utils import (
get_asset_path,
skipIfNoModule,
......@@ -22,63 +21,75 @@ from torchaudio_unittest.common_utils import (
def _load_config(*paths):
    with open(f'{get_asset_path("wav2vec2", "fairseq", *paths)}.json', "r") as file_:
return json.load(file_)
def _name_func(testcase_func, i, param):
    return f"{testcase_func.__name__}_{i}_{param[0][1].__name__}"
# Pretraining models
WAV2VEC2_BASE = _load_config("wav2vec_small")
WAV2VEC2_LARGE = _load_config("libri960_big")
WAV2VEC2_LARGE_LV60K = _load_config("wav2vec_vox_new")
WAV2VEC2_XLSR_53_56K = _load_config("xlsr_53_56k")
HUBERT_BASE = _load_config("hubert_base_ls960")
HUBERT_LARGE_LL60K = _load_config("hubert_large_ll60k")
HUBERT_XLARGE_LL60K = _load_config("hubert_xtralarge_ll60k")
# Finetuning models
WAV2VEC2_BASE_960H = _load_config("wav2vec_small_960h")
WAV2VEC2_LARGE_960H = _load_config("wav2vec_large_960h")
WAV2VEC2_LARGE_LV60K_960H = _load_config("wav2vec_large_lv60k_960h")
WAV2VEC2_LARGE_LV60K_SELF_960H = _load_config("wav2vec_large_lv60k_self_960h")
HUBERT_LARGE = _load_config("hubert_large_ll60k_finetune_ls960")
HUBERT_XLARGE = _load_config("hubert_xtralarge_ll60k_finetune_ls960")
# Config and corresponding factory functions
WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand(
[
(WAV2VEC2_BASE, wav2vec2_base),
(WAV2VEC2_LARGE, wav2vec2_large),
(WAV2VEC2_LARGE_LV60K, wav2vec2_large_lv60k),
(WAV2VEC2_XLSR_53_56K, wav2vec2_large_lv60k),
],
name_func=_name_func,
)
HUBERT_PRETRAINING_CONFIGS = parameterized.expand(
[
(HUBERT_BASE, hubert_base),
(HUBERT_LARGE_LL60K, hubert_large),
(HUBERT_XLARGE_LL60K, hubert_xlarge),
],
name_func=_name_func,
)
ALL_PRETRAINING_CONFIGS = parameterized.expand(
[
(WAV2VEC2_BASE, wav2vec2_base),
(WAV2VEC2_LARGE, wav2vec2_large),
(WAV2VEC2_LARGE_LV60K, wav2vec2_large_lv60k),
(WAV2VEC2_XLSR_53_56K, wav2vec2_large_lv60k),
(HUBERT_BASE, hubert_base),
(HUBERT_LARGE_LL60K, hubert_large),
(HUBERT_XLARGE_LL60K, hubert_xlarge),
],
name_func=_name_func,
)
FINETUNING_CONFIGS = parameterized.expand(
[
(WAV2VEC2_BASE_960H, wav2vec2_base),
(WAV2VEC2_LARGE_960H, wav2vec2_large),
(WAV2VEC2_LARGE_LV60K_960H, wav2vec2_large_lv60k),
(WAV2VEC2_LARGE_LV60K_SELF_960H, wav2vec2_large_lv60k),
(HUBERT_LARGE, hubert_large),
(HUBERT_XLARGE, hubert_xlarge),
],
name_func=_name_func,
)
@skipIfNoModule("fairseq")
class TestFairseqIntegration(TorchaudioTestCase):
"""Test the process of importing the models from fairseq.
@@ -86,9 +97,18 @@ class TestFairseqIntegration(TorchaudioTestCase):
    1. Models loaded with fairseq can be imported.
2. The same model can be recreated without fairseq.
"""
def _get_model(self, config, num_out=None):
import copy
from fairseq.models.hubert.hubert import (
HubertModel,
HubertConfig,
)
from fairseq.models.hubert.hubert_asr import (
HubertCtcConfig,
HubertEncoder,
)
from fairseq.models.wav2vec.wav2vec2 import (
Wav2Vec2Config,
Wav2Vec2Model,
@@ -97,32 +117,25 @@
Wav2VecEncoder,
Wav2Vec2CtcConfig,
)
from fairseq.tasks.hubert_pretraining import HubertPretrainingConfig
from omegaconf import OmegaConf
        if config["_name"] == "wav2vec_ctc":
            config = copy.deepcopy(config)
            config["w2v_args"] = OmegaConf.create(config["w2v_args"])
return Wav2VecEncoder(Wav2Vec2CtcConfig(**config), num_out)
        if config["_name"] == "wav2vec2":
return Wav2Vec2Model(Wav2Vec2Config(**config))
        if config["_name"] == "hubert_ctc":
            config = copy.deepcopy(config)
            config["w2v_args"] = OmegaConf.create(config["w2v_args"])
ctc_cfg = HubertCtcConfig(**config)
return HubertEncoder(ctc_cfg, tgt_dict=range(num_out))
        if config["_name"] == "hubert":
            dicts = [list(range(i)) for i in config["num_classes"]]
return HubertModel(
                HubertConfig(**config["model"]),
                HubertPretrainingConfig(**config["task"]),
dicts,
)
raise ValueError(f'Unexpected configuration: {config["_name"]}')
@@ -139,7 +152,7 @@ class TestFairseqIntegration(TorchaudioTestCase):
x = torch.randn(batch_size, num_frames)
hyp, _ = imported.extract_features(x)
refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1)
        for i, (ref, _) in enumerate(refs["layer_results"]):
self.assertEqual(hyp[i], ref.transpose(0, 1))
@HUBERT_PRETRAINING_CONFIGS
@@ -177,7 +190,13 @@ class TestFairseqIntegration(TorchaudioTestCase):
reloaded.eval()
x = torch.randn(batch_size, num_frames)
        lengths = torch.randint(
            low=0,
            high=num_frames,
            size=[
                batch_size,
            ],
        )
# Without mask
ref, _ = imported(x)
hyp, _ = reloaded(x)
@@ -200,14 +219,20 @@ class TestFairseqIntegration(TorchaudioTestCase):
# Without mask
x = torch.randn(batch_size, num_frames)
        ref = original(x, torch.zeros_like(x))["encoder_out"].transpose(0, 1)
hyp, _ = imported(x)
self.assertEqual(ref, hyp)
# With mask
        lengths = torch.randint(
            low=0,
            high=num_frames,
            size=[
                batch_size,
            ],
        )
mask = torch.arange(num_frames).expand(batch_size, num_frames) >= lengths[:, None]
        ref = original(x, mask)["encoder_out"].transpose(0, 1)
hyp, output_lengths = imported(x, lengths)
for i, l in enumerate(output_lengths):
self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...])
@@ -233,7 +258,13 @@ class TestFairseqIntegration(TorchaudioTestCase):
self.assertEqual(ref, hyp)
# With mask
        lengths = torch.randint(
            low=0,
            high=num_frames,
            size=[
                batch_size,
            ],
        )
ref, ref_lengths = imported(x, lengths)
hyp, hyp_lengths = reloaded(x, lengths)
self.assertEqual(ref, hyp)
import json
import torch
from parameterized import parameterized
from torchaudio.models.wav2vec2 import (
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
)
from torchaudio.models.wav2vec2.utils import import_huggingface_model
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
get_asset_path,
skipIfNoModule,
......@@ -17,7 +16,7 @@ from torchaudio_unittest.common_utils import (
def _load_config(*paths):
with open(f'{get_asset_path("wav2vec2", "huggingface", *paths)}.json', 'r') as file_:
with open(f'{get_asset_path("wav2vec2", "huggingface", *paths)}.json', "r") as file_:
return json.load(file_)
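# e.g. _load_config("wav2vec2-base") reads the bundled
# wav2vec2/huggingface/wav2vec2-base.json asset and returns it as a dict.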
......@@ -26,36 +25,42 @@ def _name_func(testcase_func, i, param):
# Pretrained
HF_BASE = _load_config('wav2vec2-base')
HF_LARGE = _load_config('wav2vec2-large')
HF_LARGE_LV60 = _load_config('wav2vec2-large-lv60')
HF_LARGE_XLSR_53 = _load_config('wav2vec2-large-xlsr-53')
HF_BASE_10K_VOXPOPULI = _load_config('wav2vec2-base-10k-voxpopuli')
HF_BASE = _load_config("wav2vec2-base")
HF_LARGE = _load_config("wav2vec2-large")
HF_LARGE_LV60 = _load_config("wav2vec2-large-lv60")
HF_LARGE_XLSR_53 = _load_config("wav2vec2-large-xlsr-53")
HF_BASE_10K_VOXPOPULI = _load_config("wav2vec2-base-10k-voxpopuli")
# Finetuned
HF_BASE_960H = _load_config('wav2vec2-base-960h')
HF_LARGE_960H = _load_config('wav2vec2-large-960h')
HF_LARGE_LV60_960H = _load_config('wav2vec2-large-960h-lv60')
HF_LARGE_LV60_SELF_960H = _load_config('wav2vec2-large-960h-lv60-self')
HF_LARGE_XLSR_DE = _load_config('wav2vec2-large-xlsr-53-german')
HF_BASE_960H = _load_config("wav2vec2-base-960h")
HF_LARGE_960H = _load_config("wav2vec2-large-960h")
HF_LARGE_LV60_960H = _load_config("wav2vec2-large-960h-lv60")
HF_LARGE_LV60_SELF_960H = _load_config("wav2vec2-large-960h-lv60-self")
HF_LARGE_XLSR_DE = _load_config("wav2vec2-large-xlsr-53-german")
# Config and corresponding factory functions
PRETRAIN_CONFIGS = parameterized.expand([
(HF_BASE, wav2vec2_base),
(HF_LARGE, wav2vec2_large),
(HF_LARGE_LV60, wav2vec2_large_lv60k),
(HF_LARGE_XLSR_53, wav2vec2_large_lv60k),
(HF_BASE_10K_VOXPOPULI, wav2vec2_base),
], name_func=_name_func)
FINETUNE_CONFIGS = parameterized.expand([
(HF_BASE_960H, wav2vec2_base),
(HF_LARGE_960H, wav2vec2_large),
(HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
(HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k),
(HF_LARGE_XLSR_DE, wav2vec2_large_lv60k),
], name_func=_name_func)
@skipIfNoModule('transformers')
PRETRAIN_CONFIGS = parameterized.expand(
[
(HF_BASE, wav2vec2_base),
(HF_LARGE, wav2vec2_large),
(HF_LARGE_LV60, wav2vec2_large_lv60k),
(HF_LARGE_XLSR_53, wav2vec2_large_lv60k),
(HF_BASE_10K_VOXPOPULI, wav2vec2_base),
],
name_func=_name_func,
)
FINETUNE_CONFIGS = parameterized.expand(
[
(HF_BASE_960H, wav2vec2_base),
(HF_LARGE_960H, wav2vec2_large),
(HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
(HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k),
(HF_LARGE_XLSR_DE, wav2vec2_large_lv60k),
],
name_func=_name_func,
)
@skipIfNoModule("transformers")
class TestHFIntegration(TorchaudioTestCase):
"""Test the process of importing the models from Hugging Face Transformers
......@@ -63,6 +68,7 @@ class TestHFIntegration(TorchaudioTestCase):
1. Models loaded with Hugging Face Transformers can be imported.
2. The same model can be recreated without Hugging Face Transformers.
"""
def _get_model(self, config):
# Helper function to avoid importing transformers on module scope.
# Normally, we use `is_module_available` helper function to check if
......@@ -75,9 +81,10 @@ class TestHFIntegration(TorchaudioTestCase):
Wav2Vec2Model,
Wav2Vec2ForCTC,
)
if config['architectures'] == ['Wav2Vec2Model']:
if config["architectures"] == ["Wav2Vec2Model"]:
return Wav2Vec2Model(Wav2Vec2Config(**config))
if config['architectures'] == ['Wav2Vec2ForCTC']:
if config["architectures"] == ["Wav2Vec2ForCTC"]:
return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
raise ValueError(f'Unexpected arch: {config["architectures"]}')
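A hedged standalone sketch of the import path these tests drive, assuming the transformers package is installed; a default random config stands in for the JSON fixtures:

from transformers import Wav2Vec2Config, Wav2Vec2Model
from torchaudio.models.wav2vec2.utils import import_huggingface_model

original = Wav2Vec2Model(Wav2Vec2Config()).eval()  # randomly initialized HF model
imported = import_huggingface_model(original).eval()  # torchaudio equivalent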
......@@ -89,12 +96,12 @@ class TestHFIntegration(TorchaudioTestCase):
hyp, _ = imported.feature_extractor(x, None)
self.assertEqual(ref, hyp)
# Feature projection
x = torch.randn(3, 10, config['conv_dim'][-1])
x = torch.randn(3, 10, config["conv_dim"][-1])
ref = original.feature_projection(x)[0]
hyp = imported.encoder.feature_projection(x)
self.assertEqual(ref, hyp)
# Convolutional Positional Encoder
x = torch.randn(3, 256, config['hidden_size'])
x = torch.randn(3, 256, config["hidden_size"])
ref = original.encoder.pos_conv_embed(x)
hyp = imported.encoder.transformer.pos_conv_embed(x)
self.assertEqual(ref, hyp)
......@@ -104,7 +111,7 @@ class TestHFIntegration(TorchaudioTestCase):
x = torch.randn(b, l, e)
mask = torch.randn(b, 1, l, l)
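# The HF encoder layer consumes an additive float mask broadcastable over
# attention heads, shape (batch, 1, seq, seq); random values suffice here
# because only numerical equivalence of the two layers is being checked.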
ref, = original_(x, attention_mask=mask, output_attentions=False)
(ref,) = original_(x, attention_mask=mask, output_attentions=False)
hyp = imported_(x, mask)
self.assertEqual(ref, hyp)
# The whole Encoder Transformer
......@@ -135,7 +142,13 @@ class TestHFIntegration(TorchaudioTestCase):
# The whole model with mask
batch_size, num_frames = 3, 1024
x = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
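# Note: Hugging Face's attention_mask flags *valid* frames (True = attend),
# the inverse of the fairseq-style padding mask used earlier, hence "<" below
# rather than ">=".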
mask = torch.arange(num_frames).expand(batch_size, num_frames) < lengths[:, None]
ref = original(x, attention_mask=mask).logits
......@@ -167,12 +180,12 @@ class TestHFIntegration(TorchaudioTestCase):
hyp, _ = reloaded.feature_extractor(x, None)
self.assertEqual(ref, hyp)
# Feature projection
x = torch.randn(3, 10, config['conv_dim'][-1])
x = torch.randn(3, 10, config["conv_dim"][-1])
ref = imported.encoder.feature_projection(x)
hyp = reloaded.encoder.feature_projection(x)
self.assertEqual(ref, hyp)
# Convolutional Positional Encoder
x = torch.randn(3, 256, config['hidden_size'])
x = torch.randn(3, 256, config["hidden_size"])
ref = imported.encoder.transformer.pos_conv_embed(x)
hyp = reloaded.encoder.transformer.pos_conv_embed(x)
self.assertEqual(ref, hyp)
......
import os
from typing import Tuple
import torch
import torch.nn.functional as F
from typing import Tuple
from parameterized import parameterized
from torchaudio.models.wav2vec2 import (
wav2vec2_base,
wav2vec2_large,
......@@ -18,7 +18,6 @@ from torchaudio_unittest.common_utils import (
skipIfNoCuda,
torch_script,
)
from parameterized import parameterized
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
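# e.g. torch.__version__ == "1.10.0+cu113" -> ["1", "10"] -> TORCH_VERSION == (1, 10)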
if TORCH_VERSION >= (1, 10):
......@@ -31,14 +30,17 @@ def _name_func(testcase_func, i, param):
return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}"
factory_funcs = parameterized.expand([
(wav2vec2_base, ),
(wav2vec2_large, ),
(wav2vec2_large_lv60k, ),
(hubert_base, ),
(hubert_large, ),
(hubert_xlarge, ),
], name_func=_name_func)
factory_funcs = parameterized.expand(
[
(wav2vec2_base,),
(wav2vec2_large,),
(wav2vec2_large_lv60k,),
(hubert_base,),
(hubert_large,),
(hubert_xlarge,),
],
name_func=_name_func,
)
class TestWav2Vec2Model(TorchaudioTestCase):
......@@ -49,27 +51,32 @@ class TestWav2Vec2Model(TorchaudioTestCase):
torch.manual_seed(0)
batch_size, num_frames = 3, 1024
waveforms = torch.randn(
batch_size, num_frames, device=device, dtype=dtype)
waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
lengths = torch.randint(
low=0, high=num_frames, size=[batch_size, ], device=device)
low=0,
high=num_frames,
size=[
batch_size,
],
device=device,
)
model(waveforms, lengths)
@parameterized.expand([(torch.float32, ), (torch.float64, )])
@parameterized.expand([(torch.float32,), (torch.float64,)])
def test_cpu_smoke_test(self, dtype):
model = wav2vec2_base()
self._smoke_test(model, torch.device('cpu'), dtype)
self._smoke_test(model, torch.device("cpu"), dtype)
model = wav2vec2_base(aux_num_out=32)
self._smoke_test(model, torch.device('cpu'), dtype)
self._smoke_test(model, torch.device("cpu"), dtype)
@parameterized.expand([(torch.float32, ), (torch.float64, )])
@parameterized.expand([(torch.float32,), (torch.float64,)])
@skipIfNoCuda
def test_cuda_smoke_test(self, dtype):
model = wav2vec2_base()
self._smoke_test(model, torch.device('cuda'), dtype)
self._smoke_test(model, torch.device("cuda"), dtype)
model = wav2vec2_base(aux_num_out=32)
self._smoke_test(model, torch.device('cuda'), dtype)
self._smoke_test(model, torch.device("cuda"), dtype)
def _feature_extractor_test(self, model):
batch_size, num_frames = 3, 1024
......@@ -79,7 +86,13 @@ class TestWav2Vec2Model(TorchaudioTestCase):
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
# Not providing num_layers returns all the intermediate features from
# transformer layers
......@@ -114,8 +127,8 @@ class TestWav2Vec2Model(TorchaudioTestCase):
batch_logits, output_lengths = model(waveforms, input_lengths)
for i in range(batch_size):
# Per-sample process without feeding length
single_logit, _ = model(waveforms[i:i + 1, :input_lengths[i]], None)
batch_logit = batch_logits[i:i + 1, :output_lengths[i]]
single_logit, _ = model(waveforms[i : i + 1, : input_lengths[i]], None)
batch_logit = batch_logits[i : i + 1, : output_lengths[i]]
# Convert to probability so that it's easier to interpret the diff
single_prob = F.softmax(single_logit, dim=2)
......@@ -125,14 +138,12 @@ class TestWav2Vec2Model(TorchaudioTestCase):
@factory_funcs
def test_pretrain_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close
"""
"""Results from single process and batched process should be reasonably close"""
self._test_batch_consistency(factory_func())
@factory_funcs
def test_finetune_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close
"""
"""Results from single process and batched process should be reasonably close"""
self._test_batch_consistency(factory_func(aux_num_out=32))
def _test_zero_length(self, model):
......@@ -163,7 +174,13 @@ class TestWav2Vec2Model(TorchaudioTestCase):
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
ref_out, ref_len = model(waveforms, lengths)
......@@ -177,19 +194,19 @@ class TestWav2Vec2Model(TorchaudioTestCase):
@factory_funcs
def test_pretrain_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable"""
if factory_func is hubert_xlarge and os.name == 'nt' and os.environ.get('CI') == 'true':
if factory_func is hubert_xlarge and os.name == "nt" and os.environ.get("CI") == "true":
self.skipTest(
'hubert_xlarge is known to fail on Windows CI. '
'See https://github.com/pytorch/pytorch/issues/65776')
"hubert_xlarge is known to fail on Windows CI. " "See https://github.com/pytorch/pytorch/issues/65776"
)
self._test_torchscript(factory_func())
@factory_funcs
def test_finetune_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable"""
if factory_func is hubert_xlarge and os.name == 'nt' and os.environ.get('CI') == 'true':
if factory_func is hubert_xlarge and os.name == "nt" and os.environ.get("CI") == "true":
self.skipTest(
'hubert_xlarge is known to fail on Windows CI. '
'See https://github.com/pytorch/pytorch/issues/65776')
"hubert_xlarge is known to fail on Windows CI. " "See https://github.com/pytorch/pytorch/issues/65776"
)
self._test_torchscript(factory_func(aux_num_out=32))
def _test_quantize_smoke_test(self, model):
......@@ -198,15 +215,20 @@ class TestWav2Vec2Model(TorchaudioTestCase):
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
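# weight_norm registers a forward pre-hook that TorchScript and dynamic
# quantization cannot handle; __prepare_scriptable__ folds the normalized
# weight back into plain parameters first.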
quantized = tq.quantize_dynamic(
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
_, _ = quantized(waveforms, lengths)
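For reference, a minimal standalone illustration of the dynamic-quantization call used above; it assumes only that torch is installed:

import torch
import torch.quantization as tq

model = torch.nn.Sequential(torch.nn.Linear(8, 8))
quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
print(quantized)  # the nn.Linear is replaced by a dynamically quantized variant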
......@@ -223,15 +245,20 @@ class TestWav2Vec2Model(TorchaudioTestCase):
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = tq.quantize_dynamic(
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
lengths = torch.randint(
low=0,
high=num_frames,
size=[
batch_size,
],
)
ref_out, ref_len = quantized(waveforms, lengths)
......
import torch
from torchaudio_unittest.prototype.conformer_test_impl import ConformerTestImpl
from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.conformer_test_impl import ConformerTestImpl
class ConformerFloat32CPUTest(ConformerTestImpl, PytorchTestCase):
......
import torch
from torchaudio_unittest.prototype.conformer_test_impl import ConformerTestImpl
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from torchaudio_unittest.prototype.conformer_test_impl import ConformerTestImpl
@skipIfNoCuda
......
import torch
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
from torchaudio.prototype import Conformer
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class ConformerTestImpl(TestBaseMixin):
......@@ -24,12 +24,8 @@ class ConformerTestImpl(TestBaseMixin):
return conformer
def _gen_inputs(self, input_dim, batch_size, num_frames):
lengths = torch.randint(1, num_frames, (batch_size,)).to(
device=self.device, dtype=self.dtype
)
input = torch.rand(batch_size, int(lengths.max()), input_dim).to(
device=self.device, dtype=self.dtype
)
lengths = torch.randint(1, num_frames, (batch_size,)).to(device=self.device, dtype=self.dtype)
input = torch.rand(batch_size, int(lengths.max()), input_dim).to(device=self.device, dtype=self.dtype)
return input, lengths
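# Hypothetical usage (values illustrative, not taken from the test suite):
# input, lengths = self._gen_inputs(input_dim=80, batch_size=10, num_frames=400)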
def setUp(self):
......