Commit 87d79889 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Refactor torchscript consistency test in functional (#2246)

Summary:
In the torchscript_consistency tests, the `func` in each test method accepts only a single `tensor` argument; every other argument of the `F.xyz` function under test has to be defined inside `func`. If `F.xyz` takes no `Tensor` argument at all, the test passes in a `dummy` tensor that is never used. This PR refactors `_assert_consistency` and `_assert_consistency_complex` to accept a tuple of inputs instead of a single `tensor`.
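
To illustrate the shape of the change, here is a minimal, self-contained sketch (not the exact test helper, which additionally handles device/dtype casting and TorchScript save/load): the old style wraps every non-tensor argument in a closure, while the new style scripts the `F.xyz` function directly and passes all positional arguments as a tuple. The `assert_consistency` helper below is a simplified stand-in.

```python
import torch
import torchaudio.functional as F


def assert_consistency(func, inputs):
    # Simplified stand-in for the refactored helper: script ``func`` itself
    # and call both the eager and scripted versions with the same inputs.
    ts_func = torch.jit.script(func)
    assert torch.equal(func(*inputs), ts_func(*inputs))


waveform = torch.rand(1, 1000)


# Old pattern: only a single tensor reaches the helper, so the extra
# argument (quantization channels) has to live inside a wrapper function.
def func(tensor):
    qc = 256
    return F.mu_law_encoding(tensor, qc)


assert_consistency(func, (waveform,))

# New pattern: the library function is scripted directly and non-tensor
# arguments simply ride along in the inputs tuple.
assert_consistency(F.mu_law_encoding, (waveform, 256))
```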

Pull Request resolved: https://github.com/pytorch/audio/pull/2246

Reviewed By: carolineechen

Differential Revision: D34273057

Pulled By: nateanl

fbshipit-source-id: a3900edb3b2c58638e513e1490279d771ebc3d0b
parent fdea0a7c
@@ -3,6 +3,7 @@ import unittest
import torch
import torchaudio.functional as F
from parameterized import parameterized
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
TempDirMixin,
@@ -15,90 +16,91 @@ from torchaudio_unittest.common_utils import (
class Functional(TempDirMixin, TestBaseMixin):
"""Implements test for `functional` module that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype)
def _assert_consistency(self, func, inputs, shape_only=False):
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(device=self.device, dtype=self.dtype)
inputs_.append(i)
ts_func = torch_script(func)
torch.random.manual_seed(40)
output = func(tensor)
output = func(*inputs_)
torch.random.manual_seed(40)
ts_output = ts_func(tensor)
ts_output = ts_func(*inputs_)
if shape_only:
ts_output = ts_output.shape
output = output.shape
self.assertEqual(ts_output, output)
def _assert_consistency_complex(self, func, tensor):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
def _assert_consistency_complex(self, func, inputs):
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=self.complex_dtype if i.is_complex() else self.dtype, device=self.device)
inputs_.append(i)
ts_func = torch_script(func)
torch.random.manual_seed(40)
output = func(tensor)
output = func(*inputs_)
torch.random.manual_seed(40)
ts_output = ts_func(tensor)
ts_output = ts_func(*inputs_)
self.assertEqual(ts_output, output)
def test_spectrogram(self):
def func(tensor):
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = None
normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)
tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)
waveform = common_utils.get_whitenoise()
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws, device=waveform.device, dtype=waveform.dtype)
power = None
normalize = False
self._assert_consistency(
F.spectrogram, (waveform, pad, window, n_fft, hop, ws, power, normalize, True, "reflect", True, True)
)
def test_inverse_spectrogram(self):
def func(tensor):
length = 400
n_fft = 400
hop = 200
ws = 400
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=torch.float64)
normalize = False
return F.inverse_spectrogram(tensor, length, pad, window, n_fft, hop, ws, normalize)
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=0.05)
tensor = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200)
self._assert_consistency_complex(func, tensor)
specgram = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200)
length = 400
n_fft = 400
hop = 200
ws = 400
pad = 0
window = torch.hann_window(ws, device=specgram.device, dtype=torch.float64)
normalize = False
self._assert_consistency_complex(
F.inverse_spectrogram, (specgram, length, pad, window, n_fft, hop, ws, normalize, True, "reflect", True)
)
@skipIfRocm
def test_griffinlim(self):
def func(tensor):
n_fft = 400
ws = 400
hop = 200
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2.0
momentum = 0.99
n_iter = 32
length = 1000
rand_int = False
return F.griffinlim(tensor, window, n_fft, hop, ws, power, n_iter, momentum, length, rand_int)
tensor = torch.rand((1, 201, 6))
self._assert_consistency(func, tensor)
n_fft = 400
ws = 400
hop = 200
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2.0
momentum = 0.99
n_iter = 32
length = 1000
rand_int = False
self._assert_consistency(
F.griffinlim, (tensor, window, n_fft, hop, ws, power, n_iter, momentum, length, rand_int)
)
def test_compute_deltas(self):
def func(tensor):
win_length = 2 * 7 + 1
return F.compute_deltas(tensor, win_length=win_length)
channel = 13
n_mfcc = channel * 3
time = 1021
tensor = torch.randn(channel, n_mfcc, time)
self._assert_consistency(func, tensor)
win_length = 2 * 7 + 1
self._assert_consistency(F.compute_deltas, (tensor, win_length, "replicate"))
def test_detect_pitch_frequency(self):
waveform = common_utils.get_sinusoid(sample_rate=44100)
@@ -107,71 +109,53 @@ class Functional(TempDirMixin, TestBaseMixin):
sample_rate = 44100
return F.detect_pitch_frequency(tensor, sample_rate)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_melscale_fbanks(self):
if self.device != torch.device("cpu"):
raise unittest.SkipTest("No need to perform test on device other than CPU")
def func(_):
n_stft = 100
f_min = 0.0
f_max = 20.0
n_mels = 10
sample_rate = 16000
norm = "slaney"
return F.melscale_fbanks(n_stft, f_min, f_max, n_mels, sample_rate, norm)
dummy = torch.zeros(1, 1)
self._assert_consistency(func, dummy)
n_stft = 100
f_min = 0.0
f_max = 20.0
n_mels = 10
sample_rate = 16000
norm = "slaney"
self._assert_consistency(F.melscale_fbanks, (n_stft, f_min, f_max, n_mels, sample_rate, norm, "htk"))
def test_linear_fbanks(self):
if self.device != torch.device("cpu"):
raise unittest.SkipTest("No need to perform test on device other than CPU")
def func(_):
n_stft = 100
f_min = 0.0
f_max = 20.0
n_filter = 10
sample_rate = 16000
return F.linear_fbanks(n_stft, f_min, f_max, n_filter, sample_rate)
dummy = torch.zeros(1, 1)
self._assert_consistency(func, dummy)
n_stft = 100
f_min = 0.0
f_max = 20.0
n_filter = 10
sample_rate = 16000
self._assert_consistency(F.linear_fbanks, (n_stft, f_min, f_max, n_filter, sample_rate))
def test_amplitude_to_DB(self):
def func(tensor):
multiplier = 10.0
amin = 1e-10
db_multiplier = 0.0
top_db = 80.0
return F.amplitude_to_DB(tensor, multiplier, amin, db_multiplier, top_db)
tensor = torch.rand((6, 201))
self._assert_consistency(func, tensor)
multiplier = 10.0
amin = 1e-10
db_multiplier = 0.0
top_db = 80.0
self._assert_consistency(F.amplitude_to_DB, (tensor, multiplier, amin, db_multiplier, top_db))
def test_DB_to_amplitude(self):
def func(tensor):
ref = 1.0
power = 1.0
return F.DB_to_amplitude(tensor, ref, power)
tensor = torch.rand((1, 100))
self._assert_consistency(func, tensor)
ref = 1.0
power = 1.0
self._assert_consistency(F.DB_to_amplitude, (tensor, ref, power))
def test_create_dct(self):
if self.device != torch.device("cpu"):
raise unittest.SkipTest("No need to perform test on device other than CPU")
def func(_):
n_mfcc = 40
n_mels = 128
norm = "ortho"
return F.create_dct(n_mfcc, n_mels, norm)
dummy = torch.zeros(1, 1)
self._assert_consistency(func, dummy)
n_mfcc = 40
n_mels = 128
norm = "ortho"
self._assert_consistency(F.create_dct, (n_mfcc, n_mels, norm))
def test_mu_law_encoding(self):
def func(tensor):
@@ -179,7 +163,7 @@ class Functional(TempDirMixin, TestBaseMixin):
return F.mu_law_encoding(tensor, qc)
waveform = common_utils.get_whitenoise()
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_mu_law_decoding(self):
def func(tensor):
@@ -187,7 +171,7 @@ class Functional(TempDirMixin, TestBaseMixin):
return F.mu_law_decoding(tensor, qc)
tensor = torch.rand((1, 10))
self._assert_consistency(func, tensor)
self._assert_consistency(func, (tensor,))
def test_mask_along_axis(self):
def func(tensor):
@@ -197,7 +181,7 @@ class Functional(TempDirMixin, TestBaseMixin):
return F.mask_along_axis(tensor, mask_param, mask_value, axis)
tensor = torch.randn(2, 1025, 400)
self._assert_consistency(func, tensor)
self._assert_consistency(func, (tensor,))
def test_mask_along_axis_iid(self):
def func(tensor):
@@ -207,7 +191,7 @@ class Functional(TempDirMixin, TestBaseMixin):
return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis)
tensor = torch.randn(4, 2, 1025, 400)
self._assert_consistency(func, tensor)
self._assert_consistency(func, (tensor,))
def test_gain(self):
def func(tensor):
@@ -215,88 +199,81 @@ class Functional(TempDirMixin, TestBaseMixin):
return F.gain(tensor, gainDB)
tensor = torch.rand((1, 1000))
self._assert_consistency(func, tensor)
self._assert_consistency(func, (tensor,))
def test_dither_TPDF(self):
def func(tensor):
return F.dither(tensor, "TPDF")
tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True)
self._assert_consistency(func, (tensor,), shape_only=True)
def test_dither_RPDF(self):
def func(tensor):
return F.dither(tensor, "RPDF")
tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True)
self._assert_consistency(func, (tensor,), shape_only=True)
def test_dither_GPDF(self):
def func(tensor):
return F.dither(tensor, "GPDF")
tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True)
self._assert_consistency(func, (tensor,), shape_only=True)
def test_dither_noise_shaping(self):
def func(tensor):
return F.dither(tensor, noise_shaping=True)
tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor)
self._assert_consistency(func, (tensor,))
def test_lfilter(self):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
waveform = common_utils.get_whitenoise()
def func(tensor):
# Design an IIR lowpass filter using scipy.signal filter design
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
#
# Example
# >>> from scipy.signal import iirdesign
# >>> b, a = iirdesign(0.2, 0.3, 1, 60)
b_coeffs = torch.tensor(
[
0.00299893,
-0.0051152,
0.00841964,
-0.00747802,
0.00841964,
-0.0051152,
0.00299893,
],
device=tensor.device,
dtype=tensor.dtype,
)
a_coeffs = torch.tensor(
[
1.0,
-4.8155751,
10.2217618,
-12.14481273,
8.49018171,
-3.3066882,
0.56088705,
],
device=tensor.device,
dtype=tensor.dtype,
)
return F.lfilter(tensor, a_coeffs, b_coeffs)
self._assert_consistency(func, waveform)
# Design an IIR lowpass filter using scipy.signal filter design
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
#
# Example
# >>> from scipy.signal import iirdesign
# >>> b, a = iirdesign(0.2, 0.3, 1, 60)
b_coeffs = torch.tensor(
[
0.00299893,
-0.0051152,
0.00841964,
-0.00747802,
0.00841964,
-0.0051152,
0.00299893,
],
device=waveform.device,
dtype=waveform.dtype,
)
a_coeffs = torch.tensor(
[
1.0,
-4.8155751,
10.2217618,
-12.14481273,
8.49018171,
-3.3066882,
0.56088705,
],
device=waveform.device,
dtype=waveform.dtype,
)
self._assert_consistency(F.lfilter, (waveform, a_coeffs, b_coeffs, True, True))
def test_filtfilt(self):
def func(tensor):
torch.manual_seed(296)
b_coeffs = torch.rand(4, device=tensor.device, dtype=tensor.dtype)
a_coeffs = torch.rand(4, device=tensor.device, dtype=tensor.dtype)
return F.filtfilt(tensor, a_coeffs, b_coeffs)
torch.manual_seed(296)
waveform = common_utils.get_whitenoise(sample_rate=8000)
self._assert_consistency(func, waveform)
b_coeffs = torch.rand(4, device=waveform.device, dtype=waveform.dtype)
a_coeffs = torch.rand(4, device=waveform.device, dtype=waveform.dtype)
self._assert_consistency(F.filtfilt, (waveform, a_coeffs, b_coeffs, True))
def test_lowpass(self):
if self.dtype == torch.float64:
@@ -309,7 +286,7 @@ class Functional(TempDirMixin, TestBaseMixin):
cutoff_freq = 3000.0
return F.lowpass_biquad(tensor, sample_rate, cutoff_freq)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_highpass(self):
if self.dtype == torch.float64:
@@ -322,7 +299,7 @@ class Functional(TempDirMixin, TestBaseMixin):
cutoff_freq = 2000.0
return F.highpass_biquad(tensor, sample_rate, cutoff_freq)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_allpass(self):
if self.dtype == torch.float64:
@@ -336,7 +313,7 @@ class Functional(TempDirMixin, TestBaseMixin):
q = 0.707
return F.allpass_biquad(tensor, sample_rate, central_freq, q)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_bandpass_with_csg(self):
if self.dtype == torch.float64:
@@ -351,7 +328,7 @@ class Functional(TempDirMixin, TestBaseMixin):
const_skirt_gain = True
return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_bandpass_without_csg(self):
if self.dtype == torch.float64:
@@ -366,7 +343,7 @@ class Functional(TempDirMixin, TestBaseMixin):
const_skirt_gain = True
return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_bandreject(self):
if self.dtype == torch.float64:
@@ -380,7 +357,7 @@ class Functional(TempDirMixin, TestBaseMixin):
q = 0.707
return F.bandreject_biquad(tensor, sample_rate, central_freq, q)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_band_with_noise(self):
if self.dtype == torch.float64:
@@ -395,7 +372,7 @@ class Functional(TempDirMixin, TestBaseMixin):
noise = True
return F.band_biquad(tensor, sample_rate, central_freq, q, noise)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_band_without_noise(self):
if self.dtype == torch.float64:
@@ -410,7 +387,7 @@ class Functional(TempDirMixin, TestBaseMixin):
noise = False
return F.band_biquad(tensor, sample_rate, central_freq, q, noise)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_treble(self):
if self.dtype == torch.float64:
@@ -425,7 +402,7 @@ class Functional(TempDirMixin, TestBaseMixin):
q = 0.707
return F.treble_biquad(tensor, sample_rate, gain, central_freq, q)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_bass(self):
if self.dtype == torch.float64:
@@ -440,7 +417,7 @@ class Functional(TempDirMixin, TestBaseMixin):
q = 0.707
return F.bass_biquad(tensor, sample_rate, gain, central_freq, q)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_deemph(self):
if self.dtype == torch.float64:
@@ -452,7 +429,7 @@ class Functional(TempDirMixin, TestBaseMixin):
sample_rate = 44100
return F.deemph_biquad(tensor, sample_rate)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_riaa(self):
if self.dtype == torch.float64:
@@ -464,7 +441,7 @@ class Functional(TempDirMixin, TestBaseMixin):
sample_rate = 44100
return F.riaa_biquad(tensor, sample_rate)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_equalizer(self):
if self.dtype == torch.float64:
@@ -479,7 +456,7 @@ class Functional(TempDirMixin, TestBaseMixin):
q = 0.707
return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_perf_biquad_filtering(self):
if self.dtype == torch.float64:
@@ -492,7 +469,7 @@ class Functional(TempDirMixin, TestBaseMixin):
b = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype)
return F.lfilter(tensor, a, b)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_sliding_window_cmn(self):
def func(tensor):
@@ -508,7 +485,7 @@ class Functional(TempDirMixin, TestBaseMixin):
return F.sliding_window_cmn(a, cmn_window, min_cmn_window, center, norm_vars)
b = torch.tensor([[-1.8701, -0.1196], [1.8701, 0.1196]])
self._assert_consistency(func, b)
self._assert_consistency(func, (b,))
def test_contrast(self):
waveform = common_utils.get_whitenoise()
@@ -517,7 +494,7 @@ class Functional(TempDirMixin, TestBaseMixin):
enhancement_amount = 80.0
return F.contrast(tensor, enhancement_amount)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_dcshift(self):
waveform = common_utils.get_whitenoise()
@@ -527,7 +504,7 @@ class Functional(TempDirMixin, TestBaseMixin):
limiter_gain = 0.05
return F.dcshift(tensor, shift, limiter_gain)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_overdrive(self):
waveform = common_utils.get_whitenoise()
@@ -537,7 +514,7 @@ class Functional(TempDirMixin, TestBaseMixin):
colour = 50.0
return F.overdrive(tensor, gain, colour)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_phaser(self):
waveform = common_utils.get_whitenoise(sample_rate=44100)
@@ -551,7 +528,7 @@ class Functional(TempDirMixin, TestBaseMixin):
sample_rate = 44100
return F.phaser(tensor, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=True)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_flanger(self):
torch.random.manual_seed(40)
@@ -578,7 +555,7 @@ class Functional(TempDirMixin, TestBaseMixin):
interpolation="linear",
)
self._assert_consistency(func, waveform)
self._assert_consistency(func, (waveform,))
def test_spectral_centroid(self):
def func(tensor):
@@ -591,7 +568,7 @@ class Functional(TempDirMixin, TestBaseMixin):
return F.spectral_centroid(tensor, sample_rate, pad, window, n_fft, hop, ws)
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)
self._assert_consistency(func, (tensor,))
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
@@ -603,7 +580,7 @@ class Functional(TempDirMixin, TestBaseMixin):
return F.compute_kaldi_pitch(tensor, sample_rate)
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)
self._assert_consistency(func, (tensor,))
def test_resample_sinc(self):
def func(tensor):
@@ -611,38 +588,34 @@ class Functional(TempDirMixin, TestBaseMixin):
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation")
tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, tensor)
def test_resample_kaiser(self):
def func(tensor):
sr1, sr2 = 16000, 8000
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window")
def func_beta(tensor):
sr1, sr2 = 16000, 8000
beta = 6.0
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta)
self._assert_consistency(func, (tensor,))
@parameterized.expand(
[
(None,),
(6.0,),
]
)
def test_resample_kaiser(self, beta):
tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, tensor)
self._assert_consistency(func_beta, tensor)
sr1, sr2 = 16000, 8000
lowpass_filter_width = 6
rolloff = 0.99
self._assert_consistency(F.resample, (tensor, sr1, sr2, lowpass_filter_width, rolloff, "kaiser_window", beta))
def test_phase_vocoder(self):
def func(tensor):
n_freq = tensor.size(-2)
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
n_freq,
dtype=torch.real(tensor).dtype,
device=tensor.device,
)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
self._assert_consistency_complex(func, tensor)
n_freq = tensor.size(-2)
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
n_freq,
dtype=torch.real(tensor).dtype,
device=tensor.device,
)[..., None]
self._assert_consistency_complex(F.phase_vocoder, (tensor, rate, phase_advance))
class FunctionalFloat32Only(TestBaseMixin):
@@ -662,4 +635,4 @@ class FunctionalFloat32Only(TestBaseMixin):
]
)
tensor = logits.to(device=self.device, dtype=torch.float32)
self._assert_consistency(func, tensor)
self._assert_consistency(func, (tensor,))