Commit 9cf59e75 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Refactor batch consistency test in functional (#2245)

Summary:
In batch_consistency tests, the `assert_batch_consistency` method only accepts single Tensor, which is not applicable to some methods. For example, `lfilter` and `filtfilt` requires three Tensors as the arguments, hence they don't follow `assert_batch_consistency` in the tests.
This PR refactors the test to accept a tuple of Tensors which have `batch` dimension. For the other arguments like `int` or `str`, they are given as `*args` after the tuple.

Pull Request resolved: https://github.com/pytorch/audio/pull/2245

Reviewed By: mthrok

Differential Revision: D34273035

Pulled By: nateanl

fbshipit-source-id: 0096b4f062fb4e983818e5374bed6efc7b15b056
parent 27a6dccc
"""Test numerical consistency among single input and batched input.""" """Test numerical consistency among single input and batched input."""
import itertools import itertools
import math import math
from functools import partial
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
...@@ -26,20 +27,20 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -26,20 +27,20 @@ class TestFunctional(common_utils.TorchaudioTestCase):
backend = "default" backend = "default"
def assert_batch_consistency(self, functional, batch, *args, atol=1e-8, rtol=1e-5, seed=42, **kwargs): def assert_batch_consistency(self, functional, inputs, atol=1e-8, rtol=1e-5, seed=42):
n = batch.size(0) n = inputs[0].size(0)
for i in range(1, len(inputs)):
self.assertEqual(inputs[i].size(0), n)
# Compute items separately, then batch the result # Compute items separately, then batch the result
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
items_input = batch.clone() items_input = [[ele[i].clone() for ele in inputs] for i in range(n)]
items_result = torch.stack([functional(items_input[i], *args, **kwargs) for i in range(n)]) items_result = torch.stack([functional(*items_input[i]) for i in range(n)])
# Batch the input and run # Batch the input and run
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
batch_input = batch.clone() batch_input = [ele.clone() for ele in inputs]
batch_result = functional(batch_input, *args, **kwargs) batch_result = functional(*batch_input)
self.assertEqual(items_input, batch_input, rtol=rtol, atol=atol)
self.assertEqual(items_result, batch_result, rtol=rtol, atol=atol) self.assertEqual(items_result, batch_result, rtol=rtol, atol=atol)
def test_griffinlim(self): def test_griffinlim(self):
...@@ -53,9 +54,19 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -53,9 +54,19 @@ class TestFunctional(common_utils.TorchaudioTestCase):
length = 1000 length = 1000
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch = torch.rand(self.batch_size, 1, 201, 6) batch = torch.rand(self.batch_size, 1, 201, 6)
self.assert_batch_consistency( kwargs = {
F.griffinlim, batch, window, n_fft, hop, ws, power, n_iter, momentum, length, 0, atol=5e-5 "window": window,
) "n_fft": n_fft,
"hop_length": hop,
"win_length": ws,
"power": power,
"n_iter": n_iter,
"momentum": momentum,
"length": length,
"rand_init": False,
}
func = partial(F.griffinlim, **kwargs)
self.assert_batch_consistency(func, inputs=(batch,), atol=5e-5)
@parameterized.expand( @parameterized.expand(
list( list(
...@@ -79,9 +90,19 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -79,9 +90,19 @@ class TestFunctional(common_utils.TorchaudioTestCase):
for frequency in frequencies for frequency in frequencies
] ]
) )
self.assert_batch_consistency(F.detect_pitch_frequency, waveforms, sample_rate) kwargs = {
"sample_rate": sample_rate,
}
func = partial(F.detect_pitch_frequency, **kwargs)
self.assert_batch_consistency(func, inputs=(waveforms,))
def test_amplitude_to_DB(self): @parameterized.expand(
[
(None,),
(40.0,),
]
)
def test_amplitude_to_DB(self, top_db):
torch.manual_seed(0) torch.manual_seed(0)
spec = torch.rand(self.batch_size, 2, 100, 100) * 200 spec = torch.rand(self.batch_size, 2, 100, 100) * 200
...@@ -89,10 +110,15 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -89,10 +110,15 @@ class TestFunctional(common_utils.TorchaudioTestCase):
amin = 1e-10 amin = 1e-10
ref = 1.0 ref = 1.0
db_mult = math.log10(max(amin, ref)) db_mult = math.log10(max(amin, ref))
kwargs = {
"multiplier": amplitude_mult,
"amin": amin,
"db_multiplier": db_mult,
"top_db": top_db,
}
func = partial(F.amplitude_to_DB, **kwargs)
# Test with & without a `top_db` clamp # Test with & without a `top_db` clamp
self.assert_batch_consistency(F.amplitude_to_DB, spec, amplitude_mult, amin, db_mult, top_db=None) self.assert_batch_consistency(func, inputs=(spec,))
self.assert_batch_consistency(F.amplitude_to_DB, spec, amplitude_mult, amin, db_mult, top_db=40.0)
def test_amplitude_to_DB_itemwise_clamps(self): def test_amplitude_to_DB_itemwise_clamps(self):
"""Ensure that the clamps are separate for each spectrogram in a batch. """Ensure that the clamps are separate for each spectrogram in a batch.
...@@ -115,13 +141,14 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -115,13 +141,14 @@ class TestFunctional(common_utils.TorchaudioTestCase):
spec = torch.rand([2, 2, 100, 100]) * 200 spec = torch.rand([2, 2, 100, 100]) * 200
# Make one item blow out the other # Make one item blow out the other
spec[0] += 50 spec[0] += 50
kwargs = {
batchwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=top_db) "multiplier": amplitude_mult,
itemwise_dbs = torch.stack( "amin": amin,
[F.amplitude_to_DB(item, amplitude_mult, amin, db_mult, top_db=top_db) for item in spec] "db_multiplier": db_mult,
) "top_db": top_db,
}
self.assertEqual(batchwise_dbs, itemwise_dbs) func = partial(F.amplitude_to_DB, **kwargs)
self.assert_batch_consistency(func, inputs=(spec,))
def test_amplitude_to_DB_not_channelwise_clamps(self): def test_amplitude_to_DB_not_channelwise_clamps(self):
"""Check that clamps are applied per-item, not per channel.""" """Check that clamps are applied per-item, not per channel."""
...@@ -148,17 +175,31 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -148,17 +175,31 @@ class TestFunctional(common_utils.TorchaudioTestCase):
def test_contrast(self): def test_contrast(self):
torch.random.manual_seed(0) torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(F.contrast, waveforms, enhancement_amount=80.0) kwargs = {
"enhancement_amount": 80.0,
}
func = partial(F.contrast, **kwargs)
self.assert_batch_consistency(func, inputs=(waveforms,))
def test_dcshift(self): def test_dcshift(self):
torch.random.manual_seed(0) torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(F.dcshift, waveforms, shift=0.5, limiter_gain=0.05) kwargs = {
"shift": 0.5,
"limiter_gain": 0.05,
}
func = partial(F.dcshift, **kwargs)
self.assert_batch_consistency(func, inputs=(waveforms,))
def test_overdrive(self): def test_overdrive(self):
torch.random.manual_seed(0) torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(F.overdrive, waveforms, gain=45, colour=30) kwargs = {
"gain": 45,
"colour": 30,
}
func = partial(F.overdrive, **kwargs)
self.assert_batch_consistency(func, inputs=(waveforms,))
def test_phaser(self): def test_phaser(self):
sample_rate = 44100 sample_rate = 44100
...@@ -167,13 +208,21 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -167,13 +208,21 @@ class TestFunctional(common_utils.TorchaudioTestCase):
sample_rate=sample_rate, n_channels=self.batch_size * n_channels, duration=1 sample_rate=sample_rate, n_channels=self.batch_size * n_channels, duration=1
) )
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1)) batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(F.phaser, batch, sample_rate) kwargs = {
"sample_rate": sample_rate,
}
func = partial(F.phaser, **kwargs)
self.assert_batch_consistency(func, inputs=(batch,))
def test_flanger(self): def test_flanger(self):
torch.random.manual_seed(0) torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
sample_rate = 44100 sample_rate = 44100
self.assert_batch_consistency(F.flanger, waveforms, sample_rate) kwargs = {
"sample_rate": sample_rate,
}
func = partial(F.flanger, **kwargs)
self.assert_batch_consistency(func, inputs=(waveforms,))
@parameterized.expand( @parameterized.expand(
list( list(
...@@ -187,7 +236,12 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -187,7 +236,12 @@ class TestFunctional(common_utils.TorchaudioTestCase):
def test_sliding_window_cmn(self, center, norm_vars): def test_sliding_window_cmn(self, center, norm_vars):
torch.manual_seed(0) torch.manual_seed(0)
spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200 spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200
self.assert_batch_consistency(F.sliding_window_cmn, spectrogram, center=center, norm_vars=norm_vars) kwargs = {
"center": center,
"norm_vars": norm_vars,
}
func = partial(F.sliding_window_cmn, **kwargs)
self.assert_batch_consistency(func, inputs=(spectrogram,))
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) @parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform(self, resampling_method): def test_resample_waveform(self, resampling_method):
...@@ -199,13 +253,16 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -199,13 +253,16 @@ class TestFunctional(common_utils.TorchaudioTestCase):
n_channels=num_channels, n_channels=num_channels,
duration=0.5, duration=0.5,
) )
kwargs = {
"orig_freq": sr,
"new_freq": new_sr,
"resampling_method": resampling_method,
}
func = partial(F.resample, **kwargs)
self.assert_batch_consistency( self.assert_batch_consistency(
F.resample, func,
multi_sound, inputs=(multi_sound,),
orig_freq=sr,
new_freq=new_sr,
resampling_method=resampling_method,
rtol=1e-4, rtol=1e-4,
atol=1e-7, atol=1e-7,
) )
...@@ -216,7 +273,11 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -216,7 +273,11 @@ class TestFunctional(common_utils.TorchaudioTestCase):
n_channels = 2 n_channels = 2
waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=self.batch_size * n_channels) waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=self.batch_size * n_channels)
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1)) batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(F.compute_kaldi_pitch, batch, sample_rate=sample_rate) kwargs = {
"sample_rate": sample_rate,
}
func = partial(F.compute_kaldi_pitch, **kwargs)
self.assert_batch_consistency(func, inputs=(batch,))
def test_lfilter(self): def test_lfilter(self):
signal_length = 2048 signal_length = 2048
...@@ -224,11 +285,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -224,11 +285,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
x = torch.randn(self.batch_size, signal_length) x = torch.randn(self.batch_size, signal_length)
a = torch.rand(self.batch_size, 3) a = torch.rand(self.batch_size, 3)
b = torch.rand(self.batch_size, 3) b = torch.rand(self.batch_size, 3)
self.assert_batch_consistency(F.lfilter, inputs=(x, a, b))
batchwise_output = F.lfilter(x, a, b, batching=True)
itemwise_output = torch.stack([F.lfilter(x[i], a[i], b[i]) for i in range(self.batch_size)])
self.assertEqual(batchwise_output, itemwise_output)
def test_filtfilt(self): def test_filtfilt(self):
signal_length = 2048 signal_length = 2048
...@@ -236,8 +293,4 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -236,8 +293,4 @@ class TestFunctional(common_utils.TorchaudioTestCase):
x = torch.randn(self.batch_size, signal_length) x = torch.randn(self.batch_size, signal_length)
a = torch.rand(self.batch_size, 3) a = torch.rand(self.batch_size, 3)
b = torch.rand(self.batch_size, 3) b = torch.rand(self.batch_size, 3)
self.assert_batch_consistency(F.filtfilt, inputs=(x, a, b))
batchwise_output = F.filtfilt(x, a, b)
itemwise_output = torch.stack([F.filtfilt(x[i], a[i], b[i]) for i in range(self.batch_size)])
self.assertEqual(batchwise_output, itemwise_output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment