"tests/vscode:/vscode.git/clone" did not exist on "f1689ad0e12c2d6f4b00b7564b9b81dcc1301a39"
Commit cba11009 authored by Peter Goldsborough's avatar Peter Goldsborough Committed by Soumith Chintala
Browse files

Added test to make sure loading and saving gives the same file

parent 8a41ecdc
......@@ -7,8 +7,8 @@ import os
class Test_LoadSave(unittest.TestCase):
test_dirpath = os.path.dirname(os.path.realpath(__file__))
test_filepath = os.path.join(
test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3")
test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3")
def test_load(self):
# check normal loading
......@@ -16,7 +16,6 @@ class Test_LoadSave(unittest.TestCase):
self.assertEqual(sr, 44100)
self.assertEqual(x.size(), (278756, 2))
self.assertGreater(x.sum(), 0)
print
# check normalizing
x, sr = torchaudio.load(self.test_filepath, normalization=True)
......@@ -28,8 +27,8 @@ class Test_LoadSave(unittest.TestCase):
torchaudio.load("file-does-not-exist.mp3")
with self.assertRaises(OSError):
tdir = os.path.join(os.path.dirname(
self.test_dirpath), "torchaudio")
tdir = os.path.join(
os.path.dirname(self.test_dirpath), "torchaudio")
torchaudio.load(tdir)
def test_save(self):
......@@ -80,24 +79,35 @@ class Test_LoadSave(unittest.TestCase):
# don't save to folders that don't exist
with self.assertRaises(OSError):
new_filepath = os.path.join(
self.test_dirpath, "no-path", "test.wav")
new_filepath = os.path.join(self.test_dirpath, "no-path",
"test.wav")
torchaudio.save(new_filepath, x, sr)
# save created file
sinewave_filepath = os.path.join(
self.test_dirpath, "assets", "sinewave.wav")
sinewave_filepath = os.path.join(self.test_dirpath, "assets",
"sinewave.wav")
sr = 16000
freq = 440
volume = 0.3
y = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
y = (torch.cos(
2 * math.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
y.unsqueeze_(1)
# y is between -1 and 1, so must scale
y = (y * volume * 2**31).long()
torchaudio.save(sinewave_filepath, y, sr)
self.assertTrue(os.path.isfile(sinewave_filepath))
def test_load_and_save_is_identity(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor, sample_rate = torchaudio.load(input_path)
output_path = os.path.join(self.test_dirpath, 'test.wav')
torchaudio.save(output_path, tensor, sample_rate)
tensor2, sample_rate2 = torchaudio.load(output_path)
self.assertTrue(tensor.allclose(tensor2))
self.assertEqual(sample_rate, sample_rate2)
os.unlink(output_path)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment