test_info.py 4.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
import itertools
from parameterized import parameterized

from torchaudio.backend import sox_io_backend

from ..common_utils import (
    TempDirMixin,
    PytorchTestCase,
    skipIfNoExec,
    skipIfNoExtension,
moto's avatar
moto committed
11
    get_asset_path,
moto's avatar
moto committed
12
13
    get_wav_data,
    save_wav,
moto's avatar
moto committed
14
    sox_utils,
15
)
moto's avatar
moto committed
16
17
18
from .common import (
    name_func,
)
19
20
21
22
23
24
25
26
27


@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
    """Check the metadata returned by `sox_io_backend.info` for generated files.

    Every test creates an audio file in a temporary path (wav files through
    `get_wav_data` / `save_wav`, the other formats through
    `sox_utils.gen_audio_file`), queries it with `sox_io_backend.info` and
    compares the reported values with the ones used to create the file.
    """

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav file correctly"""
        duration = 1
        # The file is written with exactly this many frames, so `info` must
        # report the same count back.
        num_frames = duration * sample_rate
        path = self.get_temp_path('data.wav')
        data = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
        save_wav(path, data, sample_rate)

        info = sox_io_backend.info(path)

        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [4, 8, 16, 32],
    )), name_func=name_func)
    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
        duration = 1
        # Same procedure as `test_wav`, run over the larger channel counts.
        num_frames = duration * sample_rate
        path = self.get_temp_path('data.wav')
        data = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
        save_wav(path, data, sample_rate)

        info = sox_io_backend.info(path)

        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=name_func)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.info` can check mp3 file correctly"""
        duration = 1
        path = self.get_temp_path('data.mp3')
        # The bit rate is handed to the generator as its `compression` argument.
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=bit_rate, duration=duration,
        )

        info = sox_io_backend.info(path)

        assert info.sample_rate == sample_rate
        # mp3 does not preserve the number of samples, so `info.num_frames`
        # is deliberately not compared with `sample_rate * duration` here.
        assert info.num_channels == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.info` can check flac file correctly"""
        duration = 1
        path = self.get_temp_path('data.flac')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=compression_level, duration=duration,
        )

        info = sox_io_backend.info(path)

        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=name_func)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.info` can check vorbis file correctly"""
        duration = 1
        path = self.get_temp_path('data.vorbis')
        # The quality level is handed to the generator as its `compression`
        # argument; note that the grid includes a non-integer value (3.6).
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=quality_level, duration=duration,
        )

        info = sox_io_backend.info(path)

        assert info.sample_rate == sample_rate
        assert info.num_frames == sample_rate * duration
        assert info.num_channels == num_channels
moto's avatar
moto committed
110
111
112
113
114
115
116
117
118
119
120
121
122


@skipIfNoExtension
class TestInfoOpus(PytorchTestCase):
    """Check the metadata returned by `sox_io_backend.info` for opus assets.

    Unlike the tests that generate their input, these read files resolved
    through `get_asset_path`, so no temporary directory is involved.
    """

    @parameterized.expand(list(itertools.product(
        ['96k'],
        [1, 2],
        [0, 5, 10],
    )), name_func=name_func)
    def test_opus(self, bitrate, num_channels, compression_level):
        """`sox_io_backend.info` can check opus file correctly"""
        # Asset file names encode the parameters as
        # `<bitrate>_<compression_level>_<num_channels>ch.opus`.
        path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')

        info = sox_io_backend.info(path)

        # Sample rate and frame count are the same fixed values for every
        # parameter combination; only the channel count varies.
        assert info.sample_rate == 48000
        assert info.num_frames == 32768
        assert info.num_channels == num_channels