import itertools

from parameterized import parameterized
from torchaudio.backend import sox_io_backend

from ..common_utils import (
    TempDirMixin,
    PytorchTestCase,
    skipIfNoExec,
    skipIfNoExtension,
    sox_utils,
    get_wav_data,
    save_wav,
)
from .common import (
    name_func,
)


def _assert_metadata(path, sample_rate, num_channels, num_frames=None):
    """Query `sox_io_backend.info` for `path` and compare it with expectations.

    Args:
        path: Location of the audio file to inspect.
        sample_rate: Sample rate the reported metadata must equal.
        num_channels: Channel count the reported metadata must equal.
        num_frames: Frame count the reported metadata must equal. When
            ``None`` the frame count is not checked.
    """
    metadata = sox_io_backend.info(path)
    assert metadata.get_sample_rate() == sample_rate
    if num_frames is not None:
        assert metadata.get_num_frames() == num_frames
    assert metadata.get_num_channels() == num_channels


@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
    """Checks the metadata returned by `sox_io_backend.info` for several formats."""

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` reports the properties of a mono/stereo wav file"""
        seconds = 1
        total_frames = seconds * sample_rate
        wav_path = self.get_temp_path('data.wav')
        waveform = get_wav_data(
            dtype, num_channels, normalize=False, num_frames=total_frames)
        save_wav(wav_path, waveform, sample_rate)
        _assert_metadata(
            wav_path, sample_rate, num_channels, num_frames=total_frames)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [4, 8, 16, 32],
    )), name_func=name_func)
    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` reports the properties of a wav file with more than 2 channels"""
        seconds = 1
        total_frames = seconds * sample_rate
        wav_path = self.get_temp_path('data.wav')
        waveform = get_wav_data(
            dtype, num_channels, normalize=False, num_frames=total_frames)
        save_wav(wav_path, waveform, sample_rate)
        _assert_metadata(
            wav_path, sample_rate, num_channels, num_frames=total_frames)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=name_func)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.info` reports the properties of an mp3 file"""
        seconds = 1
        mp3_path = self.get_temp_path('data.mp3')
        sox_utils.gen_audio_file(
            mp3_path,
            sample_rate,
            num_channels,
            compression=bit_rate,
            duration=seconds,
        )
        # mp3 does not preserve the number of samples, so the frame count
        # is deliberately left unchecked here.
        _assert_metadata(mp3_path, sample_rate, num_channels)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.info` reports the properties of a flac file"""
        seconds = 1
        flac_path = self.get_temp_path('data.flac')
        sox_utils.gen_audio_file(
            flac_path,
            sample_rate,
            num_channels,
            compression=compression_level,
            duration=seconds,
        )
        _assert_metadata(
            flac_path, sample_rate, num_channels,
            num_frames=sample_rate * seconds)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=name_func)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.info` reports the properties of a vorbis file"""
        seconds = 1
        vorbis_path = self.get_temp_path('data.vorbis')
        sox_utils.gen_audio_file(
            vorbis_path,
            sample_rate,
            num_channels,
            compression=quality_level,
            duration=seconds,
        )
        _assert_metadata(
            vorbis_path, sample_rate, num_channels,
            num_frames=sample_rate * seconds)