import os
from pathlib import Path

from torchaudio.datasets import utils as dataset_utils
from torchaudio.datasets.commonvoice import COMMONVOICE

from ..common_utils import (
    TempDirMixin,
    TorchaudioTestCase,
    get_asset_path,
)


class TestWalkFiles(TempDirMixin, TorchaudioTestCase):
    """Check that ``dataset_utils.walk_files`` visits a small, known tree in order."""

    # Both attributes are (re)initialised per test in ``setUp``.
    root = None      # top of the temporary tree that ``walk_files`` scans
    expected = None  # paths of the created files, in creation order

    def _add_file(self, *parts):
        """Create an empty file at ``parts`` under the temp dir and record its path."""
        path = self.get_temp_path(*parts)
        self.expected.append(path)
        Path(path).touch()

    def setUp(self):
        # Run the base classes' per-test setup first; overriding setUp without
        # this call silently skips it.
        super().setUp()
        self.root = self.get_temp_path()
        self.expected = []

        # Files are created in the order a sorted, top-down walk should visit
        # them: a directory's own files before any of its sub-directories.

        # level 1
        for filename in ['a.txt', 'b.txt', 'c.txt']:
            self._add_file(filename)

        # level 2
        for dir1 in ['d1', 'd2', 'd3']:
            for filename in ['d.txt', 'e.txt', 'f.txt']:
                self._add_file(dir1, filename)
            # level 3
            for dir2 in ['d1', 'd2', 'd3']:
                for filename in ['g.txt', 'h.txt', 'i.txt']:
                    self._add_file(dir1, dir2, filename)

    def test_walk_files(self):
        """walk_files should traverse files in alphabetical order"""
        # ``prefix=True`` makes walk_files yield paths that already start with
        # ``self.root``. Compare them as-is: joining the root on again would
        # double the prefix whenever the temp root is a relative path.
        found = list(dataset_utils.walk_files(self.root, '.txt', prefix=True))
        # A single list comparison reports wrong order, missing files and
        # unexpected extra files alike.
        self.assertEqual(found, self.expected)


class TestIterator(TempDirMixin, TorchaudioTestCase):
    """Tests for the dataset wrappers in ``torchaudio.datasets.utils``.

    Both tests read the Common Voice ("tatar") fixture located under the test
    asset directory.
    """

    backend = 'default'
    path = get_asset_path()

    # NOTE: "disckcache" is a misspelling of "diskcache"; the method name is
    # kept unchanged so existing test IDs / -k selections keep working.
    def test_disckcache_iterator(self):
        """An item re-read through ``diskcache_iterator`` equals the first read."""
        data = COMMONVOICE(self.path, url="tatar")
        # Keep the on-disk cache inside this test's temp dir. Without an
        # explicit location, torchaudio's diskcache_iterator uses its default
        # cache directory relative to the current working directory, and
        # nothing removes it after the run.
        data = dataset_utils.diskcache_iterator(
            data, location=self.get_temp_path('diskcache'))
        # Save: the first access produces the item and caches it.
        saved = data[0]
        # Load: the second access should be served from the cache.
        loaded = data[0]
        # NOTE(review): assumes the base TestCase.assertEqual can compare the
        # item's tensor / int / dict fields (PyTorch's TestCase does) - confirm.
        self.assertEqual(loaded, saved)

    def test_bg_iterator(self):
        """``bg_iterator`` yields every item of the wrapped dataset."""
        dataset = COMMONVOICE(self.path, url="tatar")
        data = dataset_utils.bg_iterator(dataset, 5)
        # Drain the background iterator, counting what it produced.
        n_items = sum(1 for _ in data)
        self.assertEqual(n_items, len(dataset))