Unverified Commit b059f087 authored by dhthompson, committed by GitHub

Refactor functional test (#1463)

- Put functional test logic into one place, `functional_impl.py`
- Tidy imports
parent 7355d9fd
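
The refactor leans on the mixin pattern these tests already use: the test logic is written once in a device/dtype-agnostic mixin, and each concrete TestCase only binds a configuration. A minimal sketch of that pattern, with illustrative names (ConfigMixin, SharedTests) rather than torchaudio's actual classes:

import unittest
import torch

class ConfigMixin:
    # Placeholders that concrete test cases must bind.
    dtype = None
    device = None

class SharedTests(ConfigMixin):
    # Test logic is written once against self.dtype / self.device ...
    def test_zeros_sum(self):
        x = torch.zeros(8, dtype=self.dtype, device=self.device)
        assert x.sum().item() == 0.0

# ... and instantiated per configuration, as the classes below do with
# Functional / FunctionalComplex / FunctionalCPUOnly and PytorchTestCase.
class SharedTestsFloat32CPU(SharedTests, unittest.TestCase):
    dtype = torch.float32
    device = torch.device('cpu')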
functional_cpu_test.py

import torch
import torchaudio.functional as F
import unittest
from parameterized import parameterized

from torchaudio_unittest.common_utils import PytorchTestCase, TorchaudioTestCase, skipIfNoSox
from .functional_impl import Functional, FunctionalComplex, FunctionalCPUOnly

class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cpu')
@@ -26,207 +16,23 @@ class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase):
        super().test_lfilter_9th_order_filter_stability()

class TestFunctionalFloat64(Functional, FunctionalCPUOnly, PytorchTestCase):
    dtype = torch.float64
    device = torch.device('cpu')

class TestFunctionalComplex64(FunctionalComplex, FunctionalCPUOnly, PytorchTestCase):
    complex_dtype = torch.complex64
    real_dtype = torch.float32
    device = torch.device('cpu')

class TestFunctionalComplex128(FunctionalComplex, FunctionalCPUOnly, PytorchTestCase):
    complex_dtype = torch.complex128
    real_dtype = torch.float64
    device = torch.device('cpu')

@skipIfNoSox
class TestApplyCodec(TorchaudioTestCase):
    backend = "sox_io"
...

functional_cuda_test.py

import torch
import unittest

from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .functional_impl import Functional, FunctionalComplex

@skipIfNoCuda
class TestFunctionalFloat32(Functional, PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cuda')
@@ -15,21 +15,21 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
        super().test_lfilter_9th_order_filter_stability()

@skipIfNoCuda
class TestLFilterFloat64(Functional, PytorchTestCase):
    dtype = torch.float64
    device = torch.device('cuda')

@skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
    complex_dtype = torch.complex64
    real_dtype = torch.float32
    device = torch.device('cuda')

@skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
    complex_dtype = torch.complex128
    real_dtype = torch.float64
    device = torch.device('cuda')

"""Test defintion common to CPU and CUDA""" """Test definition common to CPU and CUDA"""
import math
import itertools
import warnings

import numpy as np
import torch
import torchaudio.functional as F
from parameterized import parameterized
from scipy import signal

from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested_params

class Functional(TestBaseMixin):
    def test_lfilter_simple(self):
        """
        Create a very basic signal,
@@ -91,7 +94,7 @@ class Functional(TestBaseMixin):
        assert not x.grad.isnan().sum()

class FunctionalComplex(TestBaseMixin):
    complex_dtype = None
    real_dtype = None
    device = None
@@ -125,3 +128,157 @@ class FunctionalComplex(TestBaseMixin):
        expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
        output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
        assert output_shape == expected_shape

class FunctionalCPUOnly(TestBaseMixin):
    def test_create_fb_matrix_no_warning_high_n_freq(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            F.create_fb_matrix(288, 0, 8000, 128, 16000)
        assert len(w) == 0

    def test_create_fb_matrix_no_warning_low_n_mels(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            F.create_fb_matrix(201, 0, 8000, 89, 16000)
        assert len(w) == 0

    def test_create_fb_matrix_warning(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            F.create_fb_matrix(201, 0, 8000, 128, 16000)
        assert len(w) == 1
    def test_compute_deltas_one_channel(self):
        specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
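        # With win_length=3 the delta reduces to (c[t+1] - c[t-1]) / 2, with
        # replicate padding at the edges, hence 0.5 in the first and last frames.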
        computed = F.compute_deltas(specgram, win_length=3)
        self.assertEqual(computed, expected)

    def test_compute_deltas_two_channels(self):
        specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
                                  [1.0, 2.0, 3.0, 4.0]]])
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        computed = F.compute_deltas(specgram, win_length=3)
        self.assertEqual(computed, expected)
    @parameterized.expand([(100,), (440,)])
    def test_detect_pitch_frequency_pitch(self, frequency):
        sample_rate = 44100
        test_sine_waveform = get_sinusoid(
            frequency=frequency, sample_rate=sample_rate, duration=5
        )
        freq = F.detect_pitch_frequency(test_sine_waveform, sample_rate)

        threshold = 1
        s = ((freq - frequency).abs() > threshold).sum()
        self.assertFalse(s)
    @parameterized.expand([([100, 100],), ([2, 100, 100],), ([2, 2, 100, 100],)])
    def test_amplitude_to_DB_reversible(self, shape):
        """Round trip between amplitude and db should return the original for various shapes

        This implicitly also tests `DB_to_amplitude`.
        """
        amplitude_mult = 20.
        power_mult = 10.
        amin = 1e-10
        ref = 1.0
        db_mult = math.log10(max(amin, ref))

        torch.manual_seed(0)
        spec = torch.rand(*shape) * 200

        # Spectrogram amplitude -> DB -> amplitude
        db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None)
        x2 = F.DB_to_amplitude(db, ref, 0.5)
        self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)

        # Spectrogram power -> DB -> power
        db = F.amplitude_to_DB(spec, power_mult, amin, db_mult, top_db=None)
        x2 = F.DB_to_amplitude(db, ref, 1.)
        self.assertEqual(x2, spec)
    @parameterized.expand([([100, 100],), ([2, 100, 100],), ([2, 2, 100, 100],)])
    def test_amplitude_to_DB_top_db_clamp(self, shape):
        """Ensure values are properly clamped when `top_db` is supplied."""
        amplitude_mult = 20.
        amin = 1e-10
        ref = 1.0
        db_mult = math.log10(max(amin, ref))
        top_db = 40.

        torch.manual_seed(0)
        # A random tensor is used for increased entropy, but the max and min for
        # each spectrogram still need to be predictable. The max determines the
        # decibel cutoff, and the distance from the min must be large enough
        # that it triggers a clamp.
        spec = torch.rand(*shape)
        # Ensure each spectrogram has a min of 0 and a max of 1.
        spec -= spec.amin([-2, -1])[..., None, None]
        spec /= spec.amax([-2, -1])[..., None, None]
        # Expand the range to (0, 200) - wide enough to properly test clamping.
        spec *= 200

        decibels = F.amplitude_to_DB(spec, amplitude_mult, amin,
                                     db_mult, top_db=top_db)
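        # The peak is 20 * log10(200) ≈ 46.0206 dB, so with top_db = 40 the
        # clamp floor sits at 46.0206 - 40 ≈ 6.0206 dB; the two thresholds
        # below bracket that value.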
        # Ensure the clamp was applied
        below_limit = decibels < 6.0205
        assert not below_limit.any(), (
            "{} decibel values were below the expected cutoff:\n{}".format(
                below_limit.sum().item(), decibels
            )
        )
        # Ensure it didn't over-clamp
        close_to_limit = decibels < 6.0207
        assert close_to_limit.any(), (
            f"No values were close to the limit. Did it over-clamp?\n{decibels}"
        )
    @parameterized.expand(
        list(itertools.product([(1, 2, 1025, 400, 2), (1025, 400, 2)], [1, 2, 0.7]))
    )
    def test_complex_norm(self, shape, power):
        torch.random.manual_seed(42)
        complex_tensor = torch.randn(*shape)
        expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
        norm_tensor = F.complex_norm(complex_tensor, power)
        self.assertEqual(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
    @parameterized.expand(
        list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0., 30.], [1, 2]))
    )
    def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
        torch.random.manual_seed(42)
        specgram = torch.randn(*shape)
        mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)

        other_axis = 1 if axis == 2 else 2

        masked_columns = (mask_specgram == mask_value).sum(other_axis)
        num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
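        # mask_along_axis applies a single mask shared across the batch, so the
        # total count of fully-masked columns is normalized by the batch size.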
        num_masked_columns = torch.div(
            num_masked_columns, mask_specgram.size(0), rounding_mode='floor')

        assert mask_specgram.size() == specgram.size()
        assert num_masked_columns < mask_param
    @parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3])))
    def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
        torch.random.manual_seed(42)
        specgrams = torch.randn(4, 2, 1025, 400)

        mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
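        # Unlike mask_along_axis, the iid variant draws an independent mask per
        # example, so the masked-column count is checked for each example.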
        other_axis = 2 if axis == 3 else 3

        masked_columns = (mask_specgrams == mask_value).sum(other_axis)
        num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1)

        assert mask_specgrams.size() == specgrams.size()
        assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()