# test_datasets_vctk.py -- unit tests for torchaudio.datasets.vctk
from __future__ import absolute_import, division, print_function, unicode_literals

# Standard-library and third-party dependencies.
import os

import torch
import torchaudio
import unittest

# Shared helpers for this test suite.
import common_utils

# Module under test.
import torchaudio.datasets.vctk as vctk


class TestVCTK(unittest.TestCase):
    """Unit tests for the helper functions and dataset class in ``vctk``."""

    def setUp(self):
        """Prepare a temporary assets directory before every test.

        ``common_utils.create_temp_assets_dir`` returns a pair: the first
        element is used by the tests as the root path, the second is stored
        on the instance.
        """
        # NOTE(review): presumably ``test_dir`` is a handle that keeps the
        # temporary directory alive while the test runs -- confirm against
        # ``common_utils``.
        self.test_dirpath, self.test_dir = common_utils.create_temp_assets_dir()

    def get_full_path(self, file):
        """Return the path of ``file`` inside the temporary ``assets`` folder."""
        return os.path.join(self.test_dirpath, 'assets', file)

    def _check_read_kaldi_file(self, downsample):
        """Read ``kaldi_file.wav`` and verify its sample rate and shape.

        Both ``downsample`` settings are expected to give the same result
        for this asset, so the two public tests share this helper.
        """
        path = self.get_full_path('kaldi_file.wav')
        s, sr = vctk.read_audio(path, downsample=downsample)
        self.assertEqual(sr, 16000, msg='incorrect sample rate %d' % (sr))
        self.assertEqual(s.shape, (1, 20), msg='incorrect shape %s' % (str(s.shape)))

    def test_is_audio_file(self):
        """``is_audio_file`` accepts ``.wav`` in either case, rejects others."""
        self.assertTrue(vctk.is_audio_file('foo.wav'))
        self.assertTrue(vctk.is_audio_file('foo.WAV'))
        self.assertFalse(vctk.is_audio_file('foo.bar'))

    def test_make_manifest(self):
        """``make_manifest`` lists every audio asset under the test root."""
        audios = vctk.make_manifest(self.test_dirpath)
        names = [
            'kaldi_file.wav',
            'kaldi_file_8000.wav',
            'sinewave.wav',
            'steam-train-whistle-daniel_simon.mp3',
            'dtmf_30s_stereo.mp3',
            'whitenoise_1min.mp3',
            'whitenoise.mp3',
        ]
        # Compare in sorted order so the test does not depend on the order
        # in which the manifest is produced.
        files = sorted(self.get_full_path(name) for name in names)
        audios = sorted(audios)

        self.assertEqual(files, audios, msg='files %s did not match audios %s' % (files, audios))

    def test_read_audio_downsample_false(self):
        """``read_audio`` with ``downsample=False`` yields 16 kHz, shape (1, 20)."""
        self._check_read_kaldi_file(downsample=False)

    def test_read_audio_downsample_true(self):
        """``read_audio`` with ``downsample=True`` yields 16 kHz, shape (1, 20)."""
        self._check_read_kaldi_file(downsample=True)

    def test_load_txts(self):
        """``load_txts`` maps each text file's stem to its contents."""
        utterances = vctk.load_txts(self.test_dirpath)
        expected_utterances = {'file2': 'word5 word6\n', 'file1': 'word1 word2\n'}
        self.assertEqual(utterances, expected_utterances,
                         msg='%s did not match %s' % (utterances, expected_utterances))

    def test_vctk(self):
        """Constructing ``VCTK`` with ``download=False`` here raises ``RuntimeError``."""
        # TODO somehow test download=True, the dataset is too big download ~10 GB for
        # each test so need a way to mock it
        with self.assertRaises(RuntimeError):
            vctk.VCTK(self.test_dirpath, download=False)

# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()