Commit cba11009 authored by Peter Goldsborough's avatar Peter Goldsborough Committed by Soumith Chintala
Browse files

Added test to make sure loading and saving gives the same file

parent 8a41ecdc
......@@ -7,8 +7,8 @@ import os
class Test_LoadSave(unittest.TestCase):
test_dirpath = os.path.dirname(os.path.realpath(__file__))
test_filepath = os.path.join(
test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3")
test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3")
def test_load(self):
# check normal loading
......@@ -16,7 +16,6 @@ class Test_LoadSave(unittest.TestCase):
self.assertEqual(sr, 44100)
self.assertEqual(x.size(), (278756, 2))
self.assertGreater(x.sum(), 0)
print
# check normalizing
x, sr = torchaudio.load(self.test_filepath, normalization=True)
......@@ -28,8 +27,8 @@ class Test_LoadSave(unittest.TestCase):
torchaudio.load("file-does-not-exist.mp3")
with self.assertRaises(OSError):
tdir = os.path.join(os.path.dirname(
self.test_dirpath), "torchaudio")
tdir = os.path.join(
os.path.dirname(self.test_dirpath), "torchaudio")
torchaudio.load(tdir)
def test_save(self):
......@@ -80,24 +79,35 @@ class Test_LoadSave(unittest.TestCase):
# don't save to folders that don't exist
with self.assertRaises(OSError):
new_filepath = os.path.join(
self.test_dirpath, "no-path", "test.wav")
new_filepath = os.path.join(self.test_dirpath, "no-path",
"test.wav")
torchaudio.save(new_filepath, x, sr)
# save created file
sinewave_filepath = os.path.join(
self.test_dirpath, "assets", "sinewave.wav")
sinewave_filepath = os.path.join(self.test_dirpath, "assets",
"sinewave.wav")
sr = 16000
freq = 440
volume = 0.3
y = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
y = (torch.cos(
2 * math.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
y.unsqueeze_(1)
# y is between -1 and 1, so must scale
y = (y * volume * 2**31).long()
torchaudio.save(sinewave_filepath, y, sr)
self.assertTrue(os.path.isfile(sinewave_filepath))
def test_load_and_save_is_identity(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor, sample_rate = torchaudio.load(input_path)
output_path = os.path.join(self.test_dirpath, 'test.wav')
torchaudio.save(output_path, tensor, sample_rate)
tensor2, sample_rate2 = torchaudio.load(output_path)
self.assertTrue(tensor.allclose(tensor2))
self.assertEqual(sample_rate, sample_rate2)
os.unlink(output_path)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment