Unverified Commit c539ad7d authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Refactored tests for backend (#1239)

parent f5377999
from torchaudio_unittest.common_utils import sox_utils
def get_encoding(ext, dtype):
exts = {
'mp3',
'flac',
'vorbis',
}
encodings = {
'float32': 'PCM_F',
'int32': 'PCM_S',
'int16': 'PCM_S',
'uint8': 'PCM_U',
}
return ext.upper() if ext in exts else encodings[dtype]
def get_bits_per_sample(ext, dtype):
bits_per_samples = {
'flac': 24,
'mp3': 0,
'vorbis': 0,
}
return bits_per_samples.get(ext, sox_utils.get_bit_depth(dtype))
...@@ -13,8 +13,10 @@ from torchaudio_unittest.common_utils import ( ...@@ -13,8 +13,10 @@ from torchaudio_unittest.common_utils import (
get_wav_data, get_wav_data,
save_wav, save_wav,
) )
# TODO refactor and move these to common location from torchaudio_unittest.backend.common import (
from torchaudio_unittest.sox_io_backend.info_test import get_encoding, get_bits_per_sample get_bits_per_sample,
get_encoding,
)
from .common import skipIfFormatNotSupported, parameterize from .common import skipIfFormatNotSupported, parameterize
if _mod_utils.is_module_available("soundfile"): if _mod_utils.is_module_available("soundfile"):
......
...@@ -7,6 +7,10 @@ from parameterized import parameterized ...@@ -7,6 +7,10 @@ from parameterized import parameterized
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.backend.common import (
get_bits_per_sample,
get_encoding,
)
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
HttpServerMixin, HttpServerMixin,
...@@ -28,30 +32,6 @@ if _mod_utils.is_module_available("requests"): ...@@ -28,30 +32,6 @@ if _mod_utils.is_module_available("requests"):
import requests import requests
def get_encoding(ext, dtype):
exts = {
'mp3',
'flac',
'vorbis',
}
encodings = {
'float32': 'PCM_F',
'int32': 'PCM_S',
'int16': 'PCM_S',
'uint8': 'PCM_U',
}
return ext.upper() if ext in exts else encodings[dtype]
def get_bits_per_sample(ext, dtype):
bits_per_samples = {
'flac': 24,
'mp3': 0,
'vorbis': 0,
}
return bits_per_samples.get(ext, sox_utils.get_bit_depth(dtype))
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase): class TestInfo(TempDirMixin, PytorchTestCase):
...@@ -161,7 +141,9 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -161,7 +141,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
"""`sox_io_backend.info` can check sph file correctly""" """`sox_io_backend.info` can check sph file correctly"""
duration = 1 duration = 1
path = self.get_temp_path('data.sph') path = self.get_temp_path('data.sph')
sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration, bit_depth=bits_per_sample) sox_utils.gen_audio_file(
path, sample_rate, num_channels, duration=duration,
bit_depth=bits_per_sample)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration assert info.num_frames == sample_rate * duration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment