"""test_dataloader.py: check that torchaudio SoX effect chains can back a torch DataLoader."""
from __future__ import absolute_import, division, print_function, unicode_literals
import unittest
import common_utils
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchaudio
import math
import os


class TORCHAUDIODS(Dataset):
    """Map-style dataset that decodes the bundled test assets through SoX.

    Each item is one of the asset files, run through a single shared
    ``SoxEffectsChain``. The chain converts the file to the sample rate and
    channel count of ``sinewave.wav`` and keeps only the first 16000 samples.
    """

    # Evaluated once, when the class body executes (i.e. at module import).
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()

    def __init__(self):
        self.asset_dirpath = os.path.join(self.test_dirpath, "assets")
        sound_files = ["sinewave.wav", "steam-train-whistle-daniel_simon.mp3"]
        self.data = [os.path.join(self.asset_dirpath, fn) for fn in sound_files]

        # sinewave.wav is the reference: its rate and channel count become the
        # target format for every item in the dataset.
        self.si, self.ei = torchaudio.info(os.path.join(self.asset_dirpath, "sinewave.wav"))
        # NOTE(review): self.si is not passed to the effects chain below, so
        # this assignment appears to have no effect on the output -- confirm.
        self.si.precision = 16

        self.E = torchaudio.sox_effects.SoxEffectsChain()
        # Resample to the reference file's sample rate (not a fixed 16 kHz).
        self.E.append_effect_to_chain("rate", [self.si.rate])
        # Convert to the reference file's channel count.
        self.E.append_effect_to_chain("channels", [self.si.channels])
        # Keep only the first 16000 samples of audio.
        self.E.append_effect_to_chain("trim", [0, "16000s"])

    def __getitem__(self, index):
        """Return the tensor obtained by running ``self.data[index]`` through the chain."""
        fn = self.data[index]
        # The chain is shared across items; only its input file is swapped.
        self.E.set_input_file(fn)
        x, _sample_rate = self.E.sox_build_flow_effects()
        return x

    def __len__(self):
        """Return the number of asset files in the dataset."""
        return len(self.data)

class Test_DataLoader(unittest.TestCase):
    """Check that ``TORCHAUDIODS`` can be batched by ``torch.utils.data.DataLoader``."""

    @classmethod
    def setUpClass(cls):
        # Initialize SoX once for the whole class; paired with shutdown_sox()
        # in tearDownClass.
        torchaudio.initialize_sox()

    @classmethod
    def tearDownClass(cls):
        torchaudio.shutdown_sox()

    def test_1(self):
        """Every batch produced from ``TORCHAUDIODS`` has size (2, 1, 16000)."""
        # Presumably (batch, channels, samples); 16000 matches the "trim"
        # effect in TORCHAUDIODS -- axis order not verified here.
        expected_size = (2, 1, 16000)
        ds = TORCHAUDIODS()
        dl = DataLoader(ds, batch_size=2)
        num_batches = 0
        for x in dl:
            num_batches += 1
            # assertEqual reports the actual shape on failure; torch.Size is a
            # tuple subclass, converted explicitly for a clean comparison.
            self.assertEqual(tuple(x.size()), expected_size)
        # Without this, an empty loader would make the test pass vacuously.
        self.assertGreater(num_batches, 0, "DataLoader produced no batches")

if "__main__" == __name__:
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()