Unverified Commit bc6b7f97 authored by moto's avatar moto Committed by GitHub
Browse files

Adopt PyTorch's test util to transforms test (#652)

parent ac7c052f
...@@ -2,6 +2,7 @@ import math ...@@ -2,6 +2,7 @@ import math
import unittest import unittest
import torch import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio import torchaudio
import torchaudio.transforms as transforms import torchaudio.transforms as transforms
import torchaudio.functional as F import torchaudio.functional as F
...@@ -9,7 +10,7 @@ import torchaudio.functional as F ...@@ -9,7 +10,7 @@ import torchaudio.functional as F
import common_utils import common_utils
class Tester(unittest.TestCase): class Tester(TestCase):
# create a sinewave signal for testing # create a sinewave signal for testing
sample_rate = 16000 sample_rate = 16000
...@@ -49,7 +50,7 @@ class Tester(unittest.TestCase): ...@@ -49,7 +50,7 @@ class Tester(unittest.TestCase):
mag_to_db_torch = mag_to_db_transform(torch.abs(waveform)) mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2)) power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))
torch.testing.assert_allclose(mag_to_db_torch, power_to_db_torch) self.assertEqual(mag_to_db_torch, power_to_db_torch)
def test_melscale_load_save(self): def test_melscale_load_save(self):
specgram = torch.ones(1, 1000, 100) specgram = torch.ones(1, 1000, 100)
...@@ -63,7 +64,7 @@ class Tester(unittest.TestCase): ...@@ -63,7 +64,7 @@ class Tester(unittest.TestCase):
fb_copy = melscale_transform_copy.fb fb_copy = melscale_transform_copy.fb
self.assertEqual(fb_copy.size(), (1000, 128)) self.assertEqual(fb_copy.size(), (1000, 128))
torch.testing.assert_allclose(fb, fb_copy) self.assertEqual(fb, fb_copy)
def test_melspectrogram_load_save(self): def test_melspectrogram_load_save(self):
waveform = self.waveform.float() waveform = self.waveform.float()
...@@ -79,10 +80,10 @@ class Tester(unittest.TestCase): ...@@ -79,10 +80,10 @@ class Tester(unittest.TestCase):
fb = mel_spectrogram_transform.mel_scale.fb fb = mel_spectrogram_transform.mel_scale.fb
fb_copy = mel_spectrogram_transform_copy.mel_scale.fb fb_copy = mel_spectrogram_transform_copy.mel_scale.fb
torch.testing.assert_allclose(window, window_copy) self.assertEqual(window, window_copy)
# the default for n_fft = 400 and n_mels = 128 # the default for n_fft = 400 and n_mels = 128
self.assertEqual(fb_copy.size(), (201, 128)) self.assertEqual(fb_copy.size(), (201, 128))
torch.testing.assert_allclose(fb, fb_copy) self.assertEqual(fb, fb_copy)
def test_mel2(self): def test_mel2(self):
top_db = 80. top_db = 80.
...@@ -205,7 +206,7 @@ class Tester(unittest.TestCase): ...@@ -205,7 +206,7 @@ class Tester(unittest.TestCase):
computed_transform = transform(specgram) computed_transform = transform(specgram)
computed_functional = F.compute_deltas(specgram, win_length=win_length) computed_functional = F.compute_deltas(specgram, win_length=win_length)
torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol) self.assertEqual(computed_functional, computed_transform, atol=atol, rtol=rtol)
def test_compute_deltas_twochannel(self): def test_compute_deltas_twochannel(self):
specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1) specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
...@@ -214,7 +215,7 @@ class Tester(unittest.TestCase): ...@@ -214,7 +215,7 @@ class Tester(unittest.TestCase):
transform = transforms.ComputeDeltas(win_length=3) transform = transforms.ComputeDeltas(win_length=3)
computed = transform(specgram) computed = transform(specgram)
assert computed.shape == expected.shape, (computed.shape, expected.shape) assert computed.shape == expected.shape, (computed.shape, expected.shape)
torch.testing.assert_allclose(computed, expected, atol=1e-6, rtol=1e-8) self.assertEqual(computed, expected, atol=1e-6, rtol=1e-8)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment