Commit 92eaca7c authored by jamarshon's avatar jamarshon Committed by cpuhrsch
Browse files

Fix vctk.read_audio #143 (#145)

parent 616663ff
word1 word2
word3 word4
word5 word6
word7 word8
import os
import torch
import torchaudio
import unittest
import test.common_utils
import torchaudio.datasets.vctk as vctk
class TestVCTK(unittest.TestCase):
def setUp(self):
self.test_dirpath, self.test_dir = test.common_utils.create_temp_assets_dir()
def get_full_path(self, file):
return os.path.join(self.test_dirpath, 'assets', file)
def test_is_audio_file(self):
self.assertTrue(vctk.is_audio_file('foo.wav'))
self.assertTrue(vctk.is_audio_file('foo.WAV'))
self.assertFalse(vctk.is_audio_file('foo.bar'))
def test_make_manifest(self):
audios = vctk.make_manifest(self.test_dirpath)
files = ['kaldi_file.wav', 'kaldi_file_8000.wav',
'sinewave.wav', 'steam-train-whistle-daniel_simon.mp3']
files = [self.get_full_path(file) for file in files]
audios.sort()
self.assertEqual(files, audios, msg='files %s did not match audios %s' % (files, audios))
def test_read_audio_downsample_false(self):
file = self.get_full_path('kaldi_file.wav')
s, sr = vctk.read_audio(file, downsample=False)
self.assertEqual(sr, 16000, msg='incorrect sample rate %d' % (sr))
self.assertEqual(s.shape, (1, 20), msg='incorrect shape %s' % (str(s.shape)))
def test_read_audio_downsample_true(self):
file = self.get_full_path('kaldi_file.wav')
s, sr = vctk.read_audio(file, downsample=True)
self.assertEqual(sr, 16000, msg='incorrect sample rate %d' % (sr))
self.assertEqual(s.shape, (1, 20), msg='incorrect shape %s' % (str(s.shape)))
def test_load_txts(self):
utterences = vctk.load_txts(self.test_dirpath)
expected_utterances = {'file2': 'word5 word6\n', 'file1': 'word1 word2\n'}
self.assertEqual(utterences, expected_utterances,
msg='%s did not match %s' % (utterences, expected_utterances))
def test_vctk(self):
# TODO somehow test download=True, the dataset is too big download ~10 GB for
# each test so need a way to mock it
self.assertRaises(RuntimeError, vctk.VCTK, self.test_dirpath, download=False)
if __name__ == '__main__':
unittest.main()
...@@ -36,13 +36,13 @@ def make_manifest(dir): ...@@ -36,13 +36,13 @@ def make_manifest(dir):
def read_audio(fp, downsample=True): def read_audio(fp, downsample=True):
if downsample: if downsample:
E = torchaudio.sox_effects.SoxEffects() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(fp) E.set_input_file(fp)
E.sox_append_effect_to_chain("gain", ["-h"]) E.append_effect_to_chain("gain", ["-h"])
E.sox_append_effect_to_chain("channels", [1]) E.append_effect_to_chain("channels", [1])
E.sox_append_effect_to_chain("rate", [16000]) E.append_effect_to_chain("rate", [16000])
E.sox_append_effect_to_chain("gain", ["-rh"]) E.append_effect_to_chain("gain", ["-rh"])
E.sox_append_effect_to_chain("dither", ["-s"]) E.append_effect_to_chain("dither", ["-s"])
sig, sr = E.sox_build_flow_effects() sig, sr = E.sox_build_flow_effects()
else: else:
sig, sr = torchaudio.load(fp) sig, sr = torchaudio.load(fp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment