# test_dataloader.py
import unittest
import torchaudio
from torch.utils.data import Dataset, DataLoader

import common_utils
from common_utils import AudioBackendScope, BACKENDS


class TORCHAUDIODS(Dataset):
    """Dataset that decodes bundled audio assets through one sox effects chain.

    Each item is produced by pointing a shared ``SoxEffectsChain`` at one of
    the asset files and running its ``rate`` -> ``channels`` -> ``trim``
    effects.  The target rate and channel count are read from
    ``sinewave.wav``, so every item is converted to that file's format.

    Attributes are kept under their original names (``data``, ``si``, ``ei``,
    ``E``) because code outside this class may read them.
    """

    def __init__(self):
        sound_files = ["sinewave.wav", "steam-train-whistle-daniel_simon.mp3"]
        # Resolved paths of the assets served by this dataset.
        self.data = [common_utils.get_asset_path(fn) for fn in sound_files]
        # Signal info / encoding info of the reference file.
        self.si, self.ei = torchaudio.info(common_utils.get_asset_path("sinewave.wav"))
        # NOTE(review): presumably bits per sample; `si` is not handed to the
        # chain in this block, only its rate/channels are read below.
        self.si.precision = 16
        self.E = torchaudio.sox_effects.SoxEffectsChain()
        # Resample to the reference rate (16000 Hz per the original comment).
        self.E.append_effect_to_chain("rate", [self.si.rate])
        # Match the reference channel count (mono per the original comment).
        self.E.append_effect_to_chain("channels", [self.si.channels])
        # Keep only the first 16000 samples of audio.
        self.E.append_effect_to_chain("trim", [0, "16000s"])

    def __getitem__(self, index):
        """Return the processed audio for ``self.data[index]``.

        The shared chain is re-targeted at the requested file before it is
        run; the sample rate returned by the chain is discarded.
        """
        fn = self.data[index]
        self.E.set_input_file(fn)
        x, _ = self.E.sox_build_flow_effects()
        return x

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.data)

class Test_DataLoader(unittest.TestCase):
    """Checks that ``TORCHAUDIODS`` items can be batched by a ``DataLoader``."""

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_1(self):
        """Both assets collate into one batch shaped (batch, channels, samples)."""
        expected_size = (2, 1, 16000)
        ds = TORCHAUDIODS()
        dl = DataLoader(ds, batch_size=2)
        batch_count = 0
        for x in dl:
            # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
            self.assertEqual(tuple(x.size()), expected_size)
            batch_count += 1
        # Two files with batch_size=2 give exactly one batch; checking the
        # count keeps the test from passing vacuously on an empty loader.
        self.assertEqual(batch_count, 1)


# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()