Unverified commit d7c79b39, authored by Tomás Osório, committed by GitHub
Browse files

Add test transform and functional AmplitudeToDB (#463)

parent 933f6037
...@@ -573,7 +573,6 @@ class TestFunctional(unittest.TestCase): ...@@ -573,7 +573,6 @@ class TestFunctional(unittest.TestCase):
_test_torchscript_functional(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate) _test_torchscript_functional(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate)
def test_torchscript_amplitude_to_DB(self): def test_torchscript_amplitude_to_DB(self):
spec = torch.rand((6, 201)) spec = torch.rand((6, 201))
multiplier = 10.0 multiplier = 10.0
amin = 1e-10 amin = 1e-10
...@@ -582,6 +581,32 @@ class TestFunctional(unittest.TestCase): ...@@ -582,6 +581,32 @@ class TestFunctional(unittest.TestCase):
_test_torchscript_functional(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db) _test_torchscript_functional(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db)
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa not available')
def test_amplitude_to_DB(self):
    """Compare F.amplitude_to_DB with librosa for power and amplitude inputs."""
    spectrogram = torch.rand((6, 201))
    floor = 1e-10
    db_offset = 0.0
    dynamic_range = 80.0

    # Each case pairs the multiplier with the matching librosa reference:
    # 10.0 is the power-to-dB case, 20.0 the amplitude-to-dB case.
    cases = (
        (10.0, librosa.core.power_to_db),
        (20.0, librosa.core.amplitude_to_db),
    )
    for multiplier, reference_fn in cases:
        actual = F.amplitude_to_DB(spectrogram, multiplier, floor, db_offset, dynamic_range)
        # The librosa result gets a leading singleton dim before comparison.
        reference = torch.from_numpy(reference_fn(spectrogram.numpy())).unsqueeze(0)
        self.assertTrue(torch.allclose(actual, reference, atol=5e-5))
def test_torchscript_create_dct(self): def test_torchscript_create_dct(self):
n_mfcc = 40 n_mfcc = 40
......
...@@ -89,6 +89,29 @@ class Tester(unittest.TestCase): ...@@ -89,6 +89,29 @@ class Tester(unittest.TestCase):
spec = torch.rand((6, 201)) spec = torch.rand((6, 201))
_test_script_module(transforms.AmplitudeToDB, spec) _test_script_module(transforms.AmplitudeToDB, spec)
def test_batch_AmplitudeToDB(self):
    """Transforming a repeated batch must equal repeating the transformed input."""
    single = torch.rand((6, 201))
    batch_dims = (3, 1, 1)

    # Single then transform then batch
    reference = transforms.AmplitudeToDB()(single).repeat(*batch_dims)

    # Batch then transform
    result = transforms.AmplitudeToDB()(single.repeat(*batch_dims))

    # Report both shapes in the failure message if they disagree.
    shapes = (result.shape, reference.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(result, reference))
def test_AmplitudeToDB(self):
    """The 'magnitude' transform of |x| should match the 'power' transform of x**2."""
    waveform, _sample_rate = torchaudio.load(self.test_filepath)

    to_db_from_magnitude = transforms.AmplitudeToDB('magnitude', 80.)
    to_db_from_power = transforms.AmplitudeToDB('power', 80.)

    magnitude_db = to_db_from_magnitude(torch.abs(waveform))
    power_db = to_db_from_power(torch.pow(waveform, 2))

    self.assertTrue(torch.allclose(magnitude_db, power_db))
def test_scriptmodule_MelScale(self): def test_scriptmodule_MelScale(self):
spec_f = torch.rand((1, 6, 201)) spec_f = torch.rand((1, 6, 201))
_test_script_module(transforms.MelScale, spec_f) _test_script_module(transforms.MelScale, spec_f)
...@@ -239,16 +262,24 @@ class Tester(unittest.TestCase): ...@@ -239,16 +262,24 @@ class Tester(unittest.TestCase):
self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)) self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))
# test s2db # test s2db
db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.) power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
db_torch = db_transform(spect_transform(sound)).squeeze().cpu() power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(out_librosa) power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3)) self.assertTrue(torch.allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3))
db_torch = db_transform(melspect_transform(sound)).squeeze().cpu() mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
self.assertTrue(
torch.allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3)
)
power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel) db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa) db_librosa_tensor = torch.from_numpy(db_librosa)
self.assertTrue(
self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)) torch.allclose(power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)
)
# test MFCC # test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft} melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment