Unverified Commit 3047dc9b authored by Jcaw's avatar Jcaw Committed by GitHub
Browse files

Apply functional batch consistency tests to batches of different items (#1315)

* Test with batches of differing items

Issues that occur when different items in a batch influence one another
will not present when a batch is composed of identical items. When
checking the consistency of batched behavior, in order to catch these
issues items should be different.

Thus, use different items for the `functional` batch consistency tests
wherever possible.

* Generate different white noise in each channel

Don't duplicate a single channel multiple times. Since this is used for
testing, generate different noise in each channel so data leakage
between channels can be detected.

* Parameterize batch size in batch consistency tests

Rather than creating a batch of 3 items in each test and slicing it to
test two different batch sizes at once, parameterize the batch size on
the TestFunctional class itself. This will generate a separate set of
tests for each batch size (better isolating failures) and removes a
leaky abstraction where the test calling `assert_batch_consistencies`
had to know to give it a batch size greater than 1.

* Check inputs too, to catch in-place operations

Check inputs to the batch consistency operations too, to ensure any
in-place operations operate the same on items as batches - not just that
they output the same result.

* Use much shorter sample for phaser test

Using a 5-second signal for the phaser test takes a long time on CPU,
much longer than the other batch consistency tests. Use a shorter signal
instead.

* Load dual-channel wav for VAD test

The stereo wav has two channels, slightly offset, so they'll count as
different items.

* Load wav using common_utils, not torchaudio.load

* Test pitch frequency with different freqs per item

The pitch frequency batch test was using the same frequency for each
item, which may not catch data leakage between items within a batch. Use
different frequencies so these kinds of issues would be triggered, just
like the other batch consistency tests.

* Explain justification for single-item batch
parent 765fde08
......@@ -68,11 +68,11 @@ def get_whitenoise(
# so we only fork on CPU, generate values and move the data to the given device
with torch.random.fork_rng([]):
torch.random.manual_seed(seed)
tensor = torch.randn([int(sample_rate * duration)], dtype=torch.float32, device='cpu')
tensor = torch.randn([n_channels, int(sample_rate * duration)],
dtype=torch.float32, device='cpu')
tensor /= 2.0
tensor *= scale_factor
tensor.clamp_(-1.0, 1.0)
tensor = tensor.repeat([n_channels, 1])
if not channels_first:
tensor = tensor.t()
return convert_tensor_encoding(tensor, dtype)
......
......@@ -2,41 +2,42 @@
import itertools
import math
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
import torch
import torchaudio
import torchaudio.functional as F
from torchaudio_unittest import common_utils
@parameterized_class([
# Single-item batch isolates problems that come purely from adding a
# dimension (rather than processing multiple items)
{"batch_size": 1},
{"batch_size": 3},
])
class TestFunctional(common_utils.TorchaudioTestCase):
backend = 'default'
"""Test functions defined in `functional` module"""
backend = 'default'
def assert_batch_consistency(
self, functional, tensor, *args, batch_size=1, atol=1e-8,
rtol=1e-5, seed=42, **kwargs):
# run then batch the result
self, functional, batch, *args, atol=1e-8, rtol=1e-5, seed=42,
**kwargs):
n = batch.size(0)
# Compute items separately, then batch the result
torch.random.manual_seed(seed)
expected = functional(tensor.clone(), *args, **kwargs)
expected = expected.repeat([batch_size] + [1] * expected.dim())
items_input = batch.clone()
items_result = torch.stack([
functional(items_input[i], *args, **kwargs) for i in range(n)
])
# batch the input and run
# Batch the input and run
torch.random.manual_seed(seed)
pattern = [batch_size] + [1] * tensor.dim()
computed = functional(tensor.repeat(pattern), *args, **kwargs)
batch_input = batch.clone()
batch_result = functional(batch_input, *args, **kwargs)
self.assertEqual(computed, expected, rtol=rtol, atol=atol)
def assert_batch_consistencies(
self, functional, tensor, *args, atol=1e-8, rtol=1e-5,
seed=42, **kwargs):
self.assert_batch_consistency(
functional, tensor, *args, batch_size=1, atol=atol,
rtol=rtol, seed=seed, **kwargs)
self.assert_batch_consistency(
functional, tensor, *args, batch_size=3, atol=atol,
rtol=rtol, seed=seed, **kwargs)
self.assertEqual(items_input, batch_input, rtol=rtol, atol=atol)
self.assertEqual(items_result, batch_result, rtol=rtol, atol=atol)
def test_griffinlim(self):
n_fft = 400
......@@ -48,26 +49,33 @@ class TestFunctional(common_utils.TorchaudioTestCase):
momentum = 0.99
n_iter = 32
length = 1000
tensor = torch.rand((1, 201, 6))
self.assert_batch_consistencies(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize,
torch.random.manual_seed(0)
batch = torch.rand(self.batch_size, 1, 201, 6)
self.assert_batch_consistency(
F.griffinlim, batch, window, n_fft, hop, ws, power, normalize,
n_iter, momentum, length, 0, atol=5e-5)
@parameterized.expand(list(itertools.product(
[100, 440],
[8000, 16000, 44100],
[1, 2],
)), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}')
def test_detect_pitch_frequency(self, frequency, sample_rate, n_channels):
waveform = common_utils.get_sinusoid(
frequency=frequency, sample_rate=sample_rate,
n_channels=n_channels, duration=5)
self.assert_batch_consistencies(
F.detect_pitch_frequency, waveform, sample_rate)
def test_detect_pitch_frequency(self, sample_rate, n_channels):
# Use different frequencies to ensure each item in the batch returns a
# different answer.
torch.manual_seed(0)
frequencies = torch.randint(100, 1000, [self.batch_size])
waveforms = torch.stack([
common_utils.get_sinusoid(
frequency=frequency, sample_rate=sample_rate,
n_channels=n_channels, duration=5)
for frequency in frequencies
])
self.assert_batch_consistency(
F.detect_pitch_frequency, waveforms, sample_rate)
def test_amplitude_to_DB(self):
torch.manual_seed(0)
spec = torch.rand(2, 100, 100) * 200
spec = torch.rand(self.batch_size, 2, 100, 100) * 200
amplitude_mult = 20.
amin = 1e-10
......@@ -75,10 +83,10 @@ class TestFunctional(common_utils.TorchaudioTestCase):
db_mult = math.log10(max(amin, ref))
# Test with & without a `top_db` clamp
self.assert_batch_consistencies(
self.assert_batch_consistency(
F.amplitude_to_DB, spec, amplitude_mult,
amin, db_mult, top_db=None)
self.assert_batch_consistencies(
self.assert_batch_consistency(
F.amplitude_to_DB, spec, amplitude_mult,
amin, db_mult, top_db=40.)
......@@ -140,53 +148,70 @@ class TestFunctional(common_utils.TorchaudioTestCase):
assert (difference >= 1e-5).any()
def test_contrast(self):
waveform = torch.rand(2, 100) - 0.5
self.assert_batch_consistencies(
F.contrast, waveform, enhancement_amount=80.)
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(
F.contrast, waveforms, enhancement_amount=80.)
def test_dcshift(self):
waveform = torch.rand(2, 100) - 0.5
self.assert_batch_consistencies(
F.dcshift, waveform, shift=0.5, limiter_gain=0.05)
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(
F.dcshift, waveforms, shift=0.5, limiter_gain=0.05)
def test_overdrive(self):
waveform = torch.rand(2, 100) - 0.5
self.assert_batch_consistencies(
F.overdrive, waveform, gain=45, colour=30)
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(
F.overdrive, waveforms, gain=45, colour=30)
def test_phaser(self):
sample_rate = 44100
n_channels = 2
waveform = common_utils.get_whitenoise(
sample_rate=sample_rate, duration=5,
)
self.assert_batch_consistencies(F.phaser, waveform, sample_rate)
sample_rate=sample_rate, n_channels=self.batch_size * n_channels,
duration=1)
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(F.phaser, batch, sample_rate)
def test_flanger(self):
torch.random.manual_seed(40)
waveform = torch.rand(2, 100) - 0.5
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
sample_rate = 44100
self.assert_batch_consistencies(F.flanger, waveform, sample_rate)
self.assert_batch_consistency(F.flanger, waveforms, sample_rate)
def test_sliding_window_cmn(self):
waveform = torch.randn(2, 1024) - 0.5
self.assert_batch_consistencies(
F.sliding_window_cmn, waveform, center=True, norm_vars=True)
self.assert_batch_consistencies(
F.sliding_window_cmn, waveform, center=True, norm_vars=False)
self.assert_batch_consistencies(
F.sliding_window_cmn, waveform, center=False, norm_vars=True)
self.assert_batch_consistencies(
F.sliding_window_cmn, waveform, center=False, norm_vars=False)
def test_vad(self):
common_utils.set_audio_backend('default')
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
waveform, sample_rate = torchaudio.load(filepath)
self.assert_batch_consistencies(
F.vad, waveform, sample_rate=sample_rate)
waveforms = torch.randn(self.batch_size, 2, 1024) - 0.5
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=True, norm_vars=True)
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=True, norm_vars=False)
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=False, norm_vars=True)
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=False, norm_vars=False)
def test_vad_from_file(self):
filepath = common_utils.get_asset_path("vad-go-stereo-44100.wav")
waveform, sample_rate = common_utils.load_wav(filepath)
# Each channel is slightly offset - we can use this to create a batch
# with different items.
batch = waveform.view(2, 1, -1)
self.assert_batch_consistency(F.vad, batch, sample_rate=sample_rate)
def test_vad_different_items(self):
"""Separate test to ensure VAD consistency with differing items."""
sample_rate = 44100
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(
F.vad, waveforms, sample_rate=sample_rate)
@common_utils.skipIfNoExtension
def test_compute_kaldi_pitch(self):
sample_rate = 44100
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self.assert_batch_consistencies(F.compute_kaldi_pitch, waveform, sample_rate=sample_rate)
n_channels = 2
waveform = common_utils.get_whitenoise(
sample_rate=sample_rate, n_channels=self.batch_size * n_channels)
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(
F.compute_kaldi_pitch, batch, sample_rate=sample_rate)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment