"src/vscode:/vscode.git/clone" did not exist on "459b8ca81a32799aebca681c0ab97e29ddca318f"
Commit 9cf59e75 authored by Zhaoheng Ni, committed by Facebook GitHub Bot
Browse files

Refactor batch consistency test in functional (#2245)

Summary:
In the batch_consistency tests, the `assert_batch_consistency` method accepts only a single Tensor, which does not work for some functions. For example, `lfilter` and `filtfilt` require three Tensors as arguments, so their tests cannot use `assert_batch_consistency`.
This PR refactors the test to accept a tuple of Tensors, each of which has a `batch` dimension. Other arguments, such as `int` or `str` values, are given as `*args` after the tuple.

Pull Request resolved: https://github.com/pytorch/audio/pull/2245

Reviewed By: mthrok

Differential Revision: D34273035

Pulled By: nateanl

fbshipit-source-id: 0096b4f062fb4e983818e5374bed6efc7b15b056
parent 27a6dccc
"""Test numerical consistency among single input and batched input."""
import itertools
import math
from functools import partial
import torch
import torchaudio.functional as F
......@@ -26,20 +27,20 @@ class TestFunctional(common_utils.TorchaudioTestCase):
backend = "default"
def assert_batch_consistency(self, functional, batch, *args, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
n = batch.size(0)
def assert_batch_consistency(self, functional, inputs, atol=1e-8, rtol=1e-5, seed=42):
n = inputs[0].size(0)
for i in range(1, len(inputs)):
self.assertEqual(inputs[i].size(0), n)
# Compute items separately, then batch the result
torch.random.manual_seed(seed)
items_input = batch.clone()
items_result = torch.stack([functional(items_input[i], *args, **kwargs) for i in range(n)])
items_input = [[ele[i].clone() for ele in inputs] for i in range(n)]
items_result = torch.stack([functional(*items_input[i]) for i in range(n)])
# Batch the input and run
torch.random.manual_seed(seed)
batch_input = batch.clone()
batch_result = functional(batch_input, *args, **kwargs)
batch_input = [ele.clone() for ele in inputs]
batch_result = functional(*batch_input)
self.assertEqual(items_input, batch_input, rtol=rtol, atol=atol)
self.assertEqual(items_result, batch_result, rtol=rtol, atol=atol)
def test_griffinlim(self):
......@@ -53,9 +54,19 @@ class TestFunctional(common_utils.TorchaudioTestCase):
length = 1000
torch.random.manual_seed(0)
batch = torch.rand(self.batch_size, 1, 201, 6)
self.assert_batch_consistency(
F.griffinlim, batch, window, n_fft, hop, ws, power, n_iter, momentum, length, 0, atol=5e-5
)
kwargs = {
"window": window,
"n_fft": n_fft,
"hop_length": hop,
"win_length": ws,
"power": power,
"n_iter": n_iter,
"momentum": momentum,
"length": length,
"rand_init": False,
}
func = partial(F.griffinlim, **kwargs)
self.assert_batch_consistency(func, inputs=(batch,), atol=5e-5)
@parameterized.expand(
list(
......@@ -79,9 +90,19 @@ class TestFunctional(common_utils.TorchaudioTestCase):
for frequency in frequencies
]
)
self.assert_batch_consistency(F.detect_pitch_frequency, waveforms, sample_rate)
kwargs = {
"sample_rate": sample_rate,
}
func = partial(F.detect_pitch_frequency, **kwargs)
self.assert_batch_consistency(func, inputs=(waveforms,))
def test_amplitude_to_DB(self):
@parameterized.expand(
[
(None,),
(40.0,),
]
)
def test_amplitude_to_DB(self, top_db):
torch.manual_seed(0)
spec = torch.rand(self.batch_size, 2, 100, 100) * 200
......@@ -89,10 +110,15 @@ class TestFunctional(common_utils.TorchaudioTestCase):
amin = 1e-10
ref = 1.0
db_mult = math.log10(max(amin, ref))
kwargs = {
"multiplier": amplitude_mult,
"amin": amin,
"db_multiplier": db_mult,
"top_db": top_db,
}
func = partial(F.amplitude_to_DB, **kwargs)
# Test with & without a `top_db` clamp
self.assert_batch_consistency(F.amplitude_to_DB, spec, amplitude_mult, amin, db_mult, top_db=None)
self.assert_batch_consistency(F.amplitude_to_DB, spec, amplitude_mult, amin, db_mult, top_db=40.0)
self.assert_batch_consistency(func, inputs=(spec,))
def test_amplitude_to_DB_itemwise_clamps(self):
"""Ensure that the clamps are separate for each spectrogram in a batch.
......@@ -115,13 +141,14 @@ class TestFunctional(common_utils.TorchaudioTestCase):
spec = torch.rand([2, 2, 100, 100]) * 200
# Make one item blow out the other
spec[0] += 50
batchwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=top_db)
itemwise_dbs = torch.stack(
[F.amplitude_to_DB(item, amplitude_mult, amin, db_mult, top_db=top_db) for item in spec]
)
self.assertEqual(batchwise_dbs, itemwise_dbs)
kwargs = {
"multiplier": amplitude_mult,
"amin": amin,
"db_multiplier": db_mult,
"top_db": top_db,
}
func = partial(F.amplitude_to_DB, **kwargs)
self.assert_batch_consistency(func, inputs=(spec,))
def test_amplitude_to_DB_not_channelwise_clamps(self):
"""Check that clamps are applied per-item, not per channel."""
......@@ -148,17 +175,31 @@ class TestFunctional(common_utils.TorchaudioTestCase):
def test_contrast(self):
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(F.contrast, waveforms, enhancement_amount=80.0)
kwargs = {
"enhancement_amount": 80.0,
}
func = partial(F.contrast, **kwargs)
self.assert_batch_consistency(func, inputs=(waveforms,))
def test_dcshift(self):
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(F.dcshift, waveforms, shift=0.5, limiter_gain=0.05)
kwargs = {
"shift": 0.5,
"limiter_gain": 0.05,
}
func = partial(F.dcshift, **kwargs)
self.assert_batch_consistency(func, inputs=(waveforms,))
def test_overdrive(self):
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(F.overdrive, waveforms, gain=45, colour=30)
kwargs = {
"gain": 45,
"colour": 30,
}
func = partial(F.overdrive, **kwargs)
self.assert_batch_consistency(func, inputs=(waveforms,))
def test_phaser(self):
sample_rate = 44100
......@@ -167,13 +208,21 @@ class TestFunctional(common_utils.TorchaudioTestCase):
sample_rate=sample_rate, n_channels=self.batch_size * n_channels, duration=1
)
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(F.phaser, batch, sample_rate)
kwargs = {
"sample_rate": sample_rate,
}
func = partial(F.phaser, **kwargs)
self.assert_batch_consistency(func, inputs=(batch,))
def test_flanger(self):
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
sample_rate = 44100
self.assert_batch_consistency(F.flanger, waveforms, sample_rate)
kwargs = {
"sample_rate": sample_rate,
}
func = partial(F.flanger, **kwargs)
self.assert_batch_consistency(func, inputs=(waveforms,))
@parameterized.expand(
list(
......@@ -187,7 +236,12 @@ class TestFunctional(common_utils.TorchaudioTestCase):
def test_sliding_window_cmn(self, center, norm_vars):
torch.manual_seed(0)
spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200
self.assert_batch_consistency(F.sliding_window_cmn, spectrogram, center=center, norm_vars=norm_vars)
kwargs = {
"center": center,
"norm_vars": norm_vars,
}
func = partial(F.sliding_window_cmn, **kwargs)
self.assert_batch_consistency(func, inputs=(spectrogram,))
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform(self, resampling_method):
......@@ -199,13 +253,16 @@ class TestFunctional(common_utils.TorchaudioTestCase):
n_channels=num_channels,
duration=0.5,
)
kwargs = {
"orig_freq": sr,
"new_freq": new_sr,
"resampling_method": resampling_method,
}
func = partial(F.resample, **kwargs)
self.assert_batch_consistency(
F.resample,
multi_sound,
orig_freq=sr,
new_freq=new_sr,
resampling_method=resampling_method,
func,
inputs=(multi_sound,),
rtol=1e-4,
atol=1e-7,
)
......@@ -216,7 +273,11 @@ class TestFunctional(common_utils.TorchaudioTestCase):
n_channels = 2
waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=self.batch_size * n_channels)
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(F.compute_kaldi_pitch, batch, sample_rate=sample_rate)
kwargs = {
"sample_rate": sample_rate,
}
func = partial(F.compute_kaldi_pitch, **kwargs)
self.assert_batch_consistency(func, inputs=(batch,))
def test_lfilter(self):
signal_length = 2048
......@@ -224,11 +285,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
x = torch.randn(self.batch_size, signal_length)
a = torch.rand(self.batch_size, 3)
b = torch.rand(self.batch_size, 3)
batchwise_output = F.lfilter(x, a, b, batching=True)
itemwise_output = torch.stack([F.lfilter(x[i], a[i], b[i]) for i in range(self.batch_size)])
self.assertEqual(batchwise_output, itemwise_output)
self.assert_batch_consistency(F.lfilter, inputs=(x, a, b))
def test_filtfilt(self):
signal_length = 2048
......@@ -236,8 +293,4 @@ class TestFunctional(common_utils.TorchaudioTestCase):
x = torch.randn(self.batch_size, signal_length)
a = torch.rand(self.batch_size, 3)
b = torch.rand(self.batch_size, 3)
batchwise_output = F.filtfilt(x, a, b)
itemwise_output = torch.stack([F.filtfilt(x[i], a[i], b[i]) for i in range(self.batch_size)])
self.assertEqual(batchwise_output, itemwise_output)
self.assert_batch_consistency(F.filtfilt, inputs=(x, a, b))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment