Unverified Commit c539ad7d authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Refactored tests for backend (#1239)

parent f5377999
from torchaudio_unittest.common_utils import sox_utils
def get_encoding(ext, dtype):
exts = {
'mp3',
'flac',
'vorbis',
}
encodings = {
'float32': 'PCM_F',
'int32': 'PCM_S',
'int16': 'PCM_S',
'uint8': 'PCM_U',
}
return ext.upper() if ext in exts else encodings[dtype]
def get_bits_per_sample(ext, dtype):
bits_per_samples = {
'flac': 24,
'mp3': 0,
'vorbis': 0,
}
return bits_per_samples.get(ext, sox_utils.get_bit_depth(dtype))
......@@ -13,8 +13,10 @@ from torchaudio_unittest.common_utils import (
get_wav_data,
save_wav,
)
# TODO refactor and move these to common location
from torchaudio_unittest.sox_io_backend.info_test import get_encoding, get_bits_per_sample
from torchaudio_unittest.backend.common import (
get_bits_per_sample,
get_encoding,
)
from .common import skipIfFormatNotSupported, parameterize
if _mod_utils.is_module_available("soundfile"):
......
......@@ -7,6 +7,10 @@ from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.backend.common import (
get_bits_per_sample,
get_encoding,
)
from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
......@@ -28,30 +32,6 @@ if _mod_utils.is_module_available("requests"):
import requests
def get_encoding(ext, dtype):
exts = {
'mp3',
'flac',
'vorbis',
}
encodings = {
'float32': 'PCM_F',
'int32': 'PCM_S',
'int16': 'PCM_S',
'uint8': 'PCM_U',
}
return ext.upper() if ext in exts else encodings[dtype]
def get_bits_per_sample(ext, dtype):
bits_per_samples = {
'flac': 24,
'mp3': 0,
'vorbis': 0,
}
return bits_per_samples.get(ext, sox_utils.get_bit_depth(dtype))
@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
......@@ -161,7 +141,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
"""`sox_io_backend.info` can check sph file correctly"""
duration = 1
path = self.get_temp_path('data.sph')
sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration, bit_depth=bits_per_sample)
sox_utils.gen_audio_file(
path, sample_rate, num_channels, duration=duration,
bit_depth=bits_per_sample)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment