# test_info.py
import itertools
from parameterized import parameterized

from torchaudio.backend import sox_io_backend

from ..common_utils import (
    TempDirMixin,
    PytorchTestCase,
    skipIfNoExec,
    skipIfNoExtension,
)
from .common import (
    get_test_name
)
from . import sox_utils


@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
    """Tests for `sox_io_backend.info` on audio files generated with `sox_utils.gen_audio_file`."""

    def _assert_info(self, path, sample_rate, num_channels, num_frames=None):
        """Assert that `sox_io_backend.info(path)` reports the expected metadata.

        Args:
            path: Path of the audio file to inspect.
            sample_rate: Expected value of `info.get_sample_rate()`.
            num_channels: Expected value of `info.get_num_channels()`.
            num_frames: Expected value of `info.get_num_frames()`. When ``None``
                the frame count is not checked (used for formats whose sample
                count is not preserved, such as mp3).
        """
        info = sox_io_backend.info(path)
        assert info.get_sample_rate() == sample_rate
        if num_frames is not None:
            assert info.get_num_frames() == num_frames
        assert info.get_num_channels() == num_channels

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=get_test_name)
    def test_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav file correctly"""
        duration = 1
        path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype),
            encoding=sox_utils.get_encoding(dtype),
            duration=duration,
        )
        self._assert_info(
            path, sample_rate, num_channels, num_frames=sample_rate * duration)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [4, 8, 16, 32],
    )), name_func=get_test_name)
    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav file with more than 2 channels correctly"""
        duration = 1
        path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype),
            encoding=sox_utils.get_encoding(dtype),
            duration=duration,
        )
        self._assert_info(
            path, sample_rate, num_channels, num_frames=sample_rate * duration)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=get_test_name)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.info` can check mp3 file correctly"""
        duration = 1
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=bit_rate, duration=duration,
        )
        # mp3 does not preserve the number of samples, so the frame count is
        # intentionally not checked (num_frames is left as None).
        self._assert_info(path, sample_rate, num_channels)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=get_test_name)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.info` can check flac file correctly"""
        duration = 1
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=compression_level, duration=duration,
        )
        self._assert_info(
            path, sample_rate, num_channels, num_frames=sample_rate * duration)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=get_test_name)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.info` can check vorbis file correctly"""
        duration = 1
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=quality_level, duration=duration,
        )
        self._assert_info(
            path, sample_rate, num_channels, num_frames=sample_rate * duration)