Unverified commit 3047dc9b, authored by Jcaw, committed by GitHub
Browse files

Apply functional batch consistency tests to batches of different items (#1315)

* Test with batches of differing items

Issues that occur when different items in a batch influence one another
will not show up when a batch is composed of identical items. To catch
these issues when checking the consistency of batched behavior, the
items in the batch should differ from one another.

Thus, use different items for the `functional` batch consistency tests
wherever possible.

* Generate different white noise in each channel

Don't duplicate a single channel multiple times. Since `get_whitenoise`
is used for testing, generate different noise in each channel so data
leakage between channels can be detected.

* Parameterize batch size in batch consistency tests

Rather than creating a batch of 3 items in each test and slicing it to
test two different batch sizes at once, parameterize the batch size on
the TestFunctional class itself. This generates a separate set of
tests for each batch size (better isolating failures) and removes a
leaky abstraction in which the test calling `assert_batch_consistencies`
had to know to give it a batch size greater than 1.

* Check inputs too, to catch in-place operations

Check the inputs to the batch consistency operations too, to ensure that
any in-place operations modify individual items the same way they modify
batches - not just that they output the same result.

* Use much shorter sample for phaser test

Using a 5-second signal for the phaser test takes a long time on CPU,
much longer than the other batch consistency tests. Use a 1-second
signal instead.

* Load dual-channel wav for VAD test

The stereo wav has two channels, slightly offset, so they'll count as
different items.

* Load wav using common_utils, not torchaudio.load

* Test pitch frequency with different freqs per item

The pitch frequency batch test was using the same frequency for each
item, which may not catch data leakage between items within a batch. Use
different frequencies so these kinds of issues would be triggered, just
like the other batch consistency tests.

* Explain justification for single-item batch
parent 765fde08
...@@ -68,11 +68,11 @@ def get_whitenoise( ...@@ -68,11 +68,11 @@ def get_whitenoise(
# so we only fork on CPU, generate values and move the data to the given device # so we only fork on CPU, generate values and move the data to the given device
with torch.random.fork_rng([]): with torch.random.fork_rng([]):
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
tensor = torch.randn([int(sample_rate * duration)], dtype=torch.float32, device='cpu') tensor = torch.randn([n_channels, int(sample_rate * duration)],
dtype=torch.float32, device='cpu')
tensor /= 2.0 tensor /= 2.0
tensor *= scale_factor tensor *= scale_factor
tensor.clamp_(-1.0, 1.0) tensor.clamp_(-1.0, 1.0)
tensor = tensor.repeat([n_channels, 1])
if not channels_first: if not channels_first:
tensor = tensor.t() tensor = tensor.t()
return convert_tensor_encoding(tensor, dtype) return convert_tensor_encoding(tensor, dtype)
......
...@@ -2,41 +2,42 @@ ...@@ -2,41 +2,42 @@
import itertools import itertools
import math import math
from parameterized import parameterized from parameterized import parameterized, parameterized_class
import torch import torch
import torchaudio
import torchaudio.functional as F import torchaudio.functional as F
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
@parameterized_class([
# Single-item batch isolates problems that come purely from adding a
# dimension (rather than processing multiple items)
{"batch_size": 1},
{"batch_size": 3},
])
class TestFunctional(common_utils.TorchaudioTestCase): class TestFunctional(common_utils.TorchaudioTestCase):
backend = 'default'
"""Test functions defined in `functional` module""" """Test functions defined in `functional` module"""
backend = 'default'
def assert_batch_consistency( def assert_batch_consistency(
self, functional, tensor, *args, batch_size=1, atol=1e-8, self, functional, batch, *args, atol=1e-8, rtol=1e-5, seed=42,
rtol=1e-5, seed=42, **kwargs): **kwargs):
# run then batch the result n = batch.size(0)
# Compute items separately, then batch the result
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
expected = functional(tensor.clone(), *args, **kwargs) items_input = batch.clone()
expected = expected.repeat([batch_size] + [1] * expected.dim()) items_result = torch.stack([
functional(items_input[i], *args, **kwargs) for i in range(n)
])
# batch the input and run # Batch the input and run
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
pattern = [batch_size] + [1] * tensor.dim() batch_input = batch.clone()
computed = functional(tensor.repeat(pattern), *args, **kwargs) batch_result = functional(batch_input, *args, **kwargs)
self.assertEqual(computed, expected, rtol=rtol, atol=atol) self.assertEqual(items_input, batch_input, rtol=rtol, atol=atol)
self.assertEqual(items_result, batch_result, rtol=rtol, atol=atol)
def assert_batch_consistencies(
self, functional, tensor, *args, atol=1e-8, rtol=1e-5,
seed=42, **kwargs):
self.assert_batch_consistency(
functional, tensor, *args, batch_size=1, atol=atol,
rtol=rtol, seed=seed, **kwargs)
self.assert_batch_consistency(
functional, tensor, *args, batch_size=3, atol=atol,
rtol=rtol, seed=seed, **kwargs)
def test_griffinlim(self): def test_griffinlim(self):
n_fft = 400 n_fft = 400
...@@ -48,26 +49,33 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -48,26 +49,33 @@ class TestFunctional(common_utils.TorchaudioTestCase):
momentum = 0.99 momentum = 0.99
n_iter = 32 n_iter = 32
length = 1000 length = 1000
tensor = torch.rand((1, 201, 6)) torch.random.manual_seed(0)
self.assert_batch_consistencies( batch = torch.rand(self.batch_size, 1, 201, 6)
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, self.assert_batch_consistency(
F.griffinlim, batch, window, n_fft, hop, ws, power, normalize,
n_iter, momentum, length, 0, atol=5e-5) n_iter, momentum, length, 0, atol=5e-5)
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[100, 440],
[8000, 16000, 44100], [8000, 16000, 44100],
[1, 2], [1, 2],
)), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}') )), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}')
def test_detect_pitch_frequency(self, frequency, sample_rate, n_channels): def test_detect_pitch_frequency(self, sample_rate, n_channels):
waveform = common_utils.get_sinusoid( # Use different frequencies to ensure each item in the batch returns a
# different answer.
torch.manual_seed(0)
frequencies = torch.randint(100, 1000, [self.batch_size])
waveforms = torch.stack([
common_utils.get_sinusoid(
frequency=frequency, sample_rate=sample_rate, frequency=frequency, sample_rate=sample_rate,
n_channels=n_channels, duration=5) n_channels=n_channels, duration=5)
self.assert_batch_consistencies( for frequency in frequencies
F.detect_pitch_frequency, waveform, sample_rate) ])
self.assert_batch_consistency(
F.detect_pitch_frequency, waveforms, sample_rate)
def test_amplitude_to_DB(self): def test_amplitude_to_DB(self):
torch.manual_seed(0) torch.manual_seed(0)
spec = torch.rand(2, 100, 100) * 200 spec = torch.rand(self.batch_size, 2, 100, 100) * 200
amplitude_mult = 20. amplitude_mult = 20.
amin = 1e-10 amin = 1e-10
...@@ -75,10 +83,10 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -75,10 +83,10 @@ class TestFunctional(common_utils.TorchaudioTestCase):
db_mult = math.log10(max(amin, ref)) db_mult = math.log10(max(amin, ref))
# Test with & without a `top_db` clamp # Test with & without a `top_db` clamp
self.assert_batch_consistencies( self.assert_batch_consistency(
F.amplitude_to_DB, spec, amplitude_mult, F.amplitude_to_DB, spec, amplitude_mult,
amin, db_mult, top_db=None) amin, db_mult, top_db=None)
self.assert_batch_consistencies( self.assert_batch_consistency(
F.amplitude_to_DB, spec, amplitude_mult, F.amplitude_to_DB, spec, amplitude_mult,
amin, db_mult, top_db=40.) amin, db_mult, top_db=40.)
...@@ -140,53 +148,70 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -140,53 +148,70 @@ class TestFunctional(common_utils.TorchaudioTestCase):
assert (difference >= 1e-5).any() assert (difference >= 1e-5).any()
def test_contrast(self): def test_contrast(self):
waveform = torch.rand(2, 100) - 0.5 torch.random.manual_seed(0)
self.assert_batch_consistencies( waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
F.contrast, waveform, enhancement_amount=80.) self.assert_batch_consistency(
F.contrast, waveforms, enhancement_amount=80.)
def test_dcshift(self): def test_dcshift(self):
waveform = torch.rand(2, 100) - 0.5 torch.random.manual_seed(0)
self.assert_batch_consistencies( waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
F.dcshift, waveform, shift=0.5, limiter_gain=0.05) self.assert_batch_consistency(
F.dcshift, waveforms, shift=0.5, limiter_gain=0.05)
def test_overdrive(self): def test_overdrive(self):
waveform = torch.rand(2, 100) - 0.5 torch.random.manual_seed(0)
self.assert_batch_consistencies( waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
F.overdrive, waveform, gain=45, colour=30) self.assert_batch_consistency(
F.overdrive, waveforms, gain=45, colour=30)
def test_phaser(self): def test_phaser(self):
sample_rate = 44100 sample_rate = 44100
n_channels = 2
waveform = common_utils.get_whitenoise( waveform = common_utils.get_whitenoise(
sample_rate=sample_rate, duration=5, sample_rate=sample_rate, n_channels=self.batch_size * n_channels,
) duration=1)
self.assert_batch_consistencies(F.phaser, waveform, sample_rate) batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(F.phaser, batch, sample_rate)
def test_flanger(self): def test_flanger(self):
torch.random.manual_seed(40) torch.random.manual_seed(0)
waveform = torch.rand(2, 100) - 0.5 waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
sample_rate = 44100 sample_rate = 44100
self.assert_batch_consistencies(F.flanger, waveform, sample_rate) self.assert_batch_consistency(F.flanger, waveforms, sample_rate)
def test_sliding_window_cmn(self): def test_sliding_window_cmn(self):
waveform = torch.randn(2, 1024) - 0.5 waveforms = torch.randn(self.batch_size, 2, 1024) - 0.5
self.assert_batch_consistencies( self.assert_batch_consistency(
F.sliding_window_cmn, waveform, center=True, norm_vars=True) F.sliding_window_cmn, waveforms, center=True, norm_vars=True)
self.assert_batch_consistencies( self.assert_batch_consistency(
F.sliding_window_cmn, waveform, center=True, norm_vars=False) F.sliding_window_cmn, waveforms, center=True, norm_vars=False)
self.assert_batch_consistencies( self.assert_batch_consistency(
F.sliding_window_cmn, waveform, center=False, norm_vars=True) F.sliding_window_cmn, waveforms, center=False, norm_vars=True)
self.assert_batch_consistencies( self.assert_batch_consistency(
F.sliding_window_cmn, waveform, center=False, norm_vars=False) F.sliding_window_cmn, waveforms, center=False, norm_vars=False)
def test_vad(self): def test_vad_from_file(self):
common_utils.set_audio_backend('default') filepath = common_utils.get_asset_path("vad-go-stereo-44100.wav")
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav") waveform, sample_rate = common_utils.load_wav(filepath)
waveform, sample_rate = torchaudio.load(filepath) # Each channel is slightly offset - we can use this to create a batch
self.assert_batch_consistencies( # with different items.
F.vad, waveform, sample_rate=sample_rate) batch = waveform.view(2, 1, -1)
self.assert_batch_consistency(F.vad, batch, sample_rate=sample_rate)
def test_vad_different_items(self):
"""Separate test to ensure VAD consistency with differing items."""
sample_rate = 44100
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(
F.vad, waveforms, sample_rate=sample_rate)
@common_utils.skipIfNoExtension @common_utils.skipIfNoExtension
def test_compute_kaldi_pitch(self): def test_compute_kaldi_pitch(self):
sample_rate = 44100 sample_rate = 44100
waveform = common_utils.get_whitenoise(sample_rate=sample_rate) n_channels = 2
self.assert_batch_consistencies(F.compute_kaldi_pitch, waveform, sample_rate=sample_rate) waveform = common_utils.get_whitenoise(
sample_rate=sample_rate, n_channels=self.batch_size * n_channels)
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(
F.compute_kaldi_pitch, batch, sample_rate=sample_rate)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment