Unverified Commit b152ee61 authored by moto's avatar moto Committed by GitHub
Browse files

Add encoding attribute to AudioMetaData (#1206)

parent 674a71d1
......@@ -13,6 +13,8 @@ from torchaudio_unittest.common_utils import (
get_wav_data,
save_wav,
)
# TODO refactor and move these to common location
from torchaudio_unittest.sox_io_backend.info_test import get_encoding, get_bits_per_sample
from .common import skipIfFormatNotSupported, parameterize
if _mod_utils.is_module_available("soundfile"):
......@@ -22,11 +24,10 @@ if _mod_utils.is_module_available("soundfile"):
@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterize(
[("float32", 32), ("int32", 32), ("int16", 16), ("uint8", 8)], [8000, 16000], [1, 2],
["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2],
)
def test_wav(self, dtype_and_bit_depth, sample_rate, num_channels):
def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file correctly"""
dtype, bits_per_sample = dtype_and_bit_depth
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(
......@@ -37,25 +38,8 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
@parameterize(
[("float32", 32), ("int32", 32), ("int16", 16), ("uint8", 8)], [8000, 16000], [1, 2],
)
def test_wav_multiple_channels(self, dtype_and_bit_depth, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file with channels more than 2 correctly"""
dtype, bits_per_sample = dtype_and_bit_depth
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(
dtype, num_channels, normalize=False, num_frames=duration * sample_rate
)
save_wav(path, data, sample_rate)
info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.bits_per_sample == get_bits_per_sample("wav", dtype)
assert info.encoding == get_encoding("wav", dtype)
@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("FLAC")
......@@ -72,6 +56,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == 16
assert info.encoding == "FLAC"
@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("OGG")
......@@ -88,6 +73,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0
assert info.encoding == "VORBIS"
@parameterize([8000, 16000], [1, 2], [('PCM_24', 24), ('PCM_32', 32)])
@skipIfFormatNotSupported("NIST")
......@@ -105,6 +91,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "PCM_S"
def test_unknown_subtype_warning(self):
"""soundfile_backend.info issues a warning when the subtype is unknown
......@@ -118,6 +105,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
frames = 356
channels = 2
subtype = 'UNSEEN_SUBTYPE'
format = 'UNKNOWN'
return MockSoundFileInfo()
with patch("soundfile.info", _mock_info_func):
......@@ -147,6 +135,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
# The conditional must be parenthesized. Without the parentheses the statement
# parses as `assert (info.encoding == "FLAC") if ext == 'flac' else "PCM_S"`,
# which asserts the always-truthy string "PCM_S" for every non-FLAC extension
# and therefore never checks the encoding at all.
assert info.encoding == ("FLAC" if ext == 'flac' else "PCM_S")
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
......@@ -179,6 +168,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
# The conditional must be parenthesized. Without the parentheses the statement
# parses as `assert (info.encoding == "FLAC") if ext == 'flac' else "PCM_S"`,
# which asserts the always-truthy string "PCM_S" for every non-FLAC extension
# and therefore never checks the encoding at all.
assert info.encoding == ("FLAC" if ext == 'flac' else "PCM_S")
def test_tarobj_wav(self):
"""Query compressed audio via file-like object works"""
......
import io
import os
import itertools
import tarfile
......@@ -27,6 +28,30 @@ if _mod_utils.is_module_available("requests"):
import requests
def get_encoding(ext, dtype):
    """Return the encoding name the tests expect ``info`` to report.

    Args:
        ext: File extension / format name (e.g. ``'wav'``, ``'mp3'``).
        dtype: Sample dtype name (``'float32'``, ``'int32'``, ``'int16'``
            or ``'uint8'``). Only consulted when ``ext`` is not one of
            ``'mp3'``, ``'flac'`` or ``'vorbis'``.

    Returns:
        The upper-cased extension for ``mp3``/``flac``/``vorbis``; otherwise
        the PCM encoding name associated with ``dtype``.

    Raises:
        KeyError: If the dtype lookup is needed and ``dtype`` is not listed.
    """
    # For these formats the expected encoding is simply the format name.
    if ext in {'mp3', 'flac', 'vorbis'}:
        return ext.upper()

    pcm_by_dtype = {
        'float32': 'PCM_F',
        'int32': 'PCM_S',
        'int16': 'PCM_S',
        'uint8': 'PCM_U',
    }
    return pcm_by_dtype[dtype]
def get_bits_per_sample(ext, dtype):
    """Return the ``bits_per_sample`` value the tests expect for a file.

    ``flac`` maps to 24 and ``mp3``/``vorbis`` map to 0; any other extension
    uses ``sox_utils.get_bit_depth(dtype)``.
    """
    # Computed unconditionally, as it would be when passed as a ``dict.get``
    # default argument.
    dtype_bit_depth = sox_utils.get_bit_depth(dtype)
    fixed_by_ext = {'flac': 24, 'mp3': 0, 'vorbis': 0}
    if ext in fixed_by_ext:
        return fixed_by_ext[ext]
    return dtype_bit_depth
@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
......@@ -46,6 +71,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
......@@ -63,6 +89,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype)
@parameterized.expand(list(itertools.product(
[8000, 16000],
......@@ -83,6 +110,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
# assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "MP3"
@parameterized.expand(list(itertools.product(
[8000, 16000],
......@@ -102,6 +130,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 24 # FLAC standard
assert info.encoding == "FLAC"
@parameterized.expand(list(itertools.product(
[8000, 16000],
......@@ -121,6 +150,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "VORBIS"
@parameterized.expand(list(itertools.product(
[8000, 16000],
......@@ -137,9 +167,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "PCM_S"
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
['int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
......@@ -156,6 +187,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == get_encoding("amb", dtype)
def test_amr_nb(self):
"""`sox_io_backend.info` can check amr-nb file correctly"""
......@@ -171,6 +203,41 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0
assert info.encoding == "AMR_NB"
def test_ulaw(self):
    """`sox_io_backend.info` reports correct metadata for a u-law file"""
    expected_rate, expected_channels, seconds = 8000, 1, 1
    wav_path = self.get_temp_path('data.wav')
    # Generate an 8-bit u-law encoded file to query.
    sox_utils.gen_audio_file(
        wav_path,
        sample_rate=expected_rate,
        num_channels=expected_channels,
        bit_depth=8,
        encoding='u-law',
        duration=seconds,
    )
    meta = sox_io_backend.info(wav_path)
    assert meta.sample_rate == expected_rate
    assert meta.num_frames == expected_rate * seconds
    assert meta.num_channels == expected_channels
    assert meta.bits_per_sample == 8
    assert meta.encoding == "ULAW"
def test_alaw(self):
    """`sox_io_backend.info` can check alaw file correctly"""
    duration = 1
    num_channels = 1
    sample_rate = 8000
    path = self.get_temp_path('data.wav')
    # Generate an 8-bit a-law encoded file to query.
    sox_utils.gen_audio_file(
        path, sample_rate=sample_rate, num_channels=num_channels,
        bit_depth=8, encoding='a-law',
        duration=duration)
    info = sox_io_backend.info(path)
    assert info.sample_rate == sample_rate
    assert info.num_frames == sample_rate * duration
    assert info.num_channels == num_channels
    assert info.bits_per_sample == 8
    assert info.encoding == "ALAW"
@skipIfNoExtension
......@@ -188,6 +255,7 @@ class TestInfoOpus(PytorchTestCase):
assert info.num_frames == 32768
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "OPUS"
@skipIfNoExtension
......@@ -205,144 +273,193 @@ class TestLoadWithoutExtension(PytorchTestCase):
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 81216
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert sinfo.encoding == "MP3"
class FileObjTestBase(TempDirMixin):
    """Shared base for file-object tests; provides audio file generation."""

    def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames):
        """Generate ``test.<ext>`` in the temp directory and return its path.

        The encoding and bit depth passed to ``sox_utils.gen_audio_file`` are
        derived from ``dtype``; the duration is ``num_frames / sample_rate``.
        """
        out_path = self.get_temp_path(f'test.{ext}')
        depth = sox_utils.get_bit_depth(dtype)
        seconds = num_frames / sample_rate
        sox_utils.gen_audio_file(
            out_path,
            sample_rate,
            num_channels=num_channels,
            encoding=sox_utils.get_encoding(dtype),
            bit_depth=depth,
            duration=seconds,
        )
        return out_path
@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
class TestFileObject(FileObjTestBase, PytorchTestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames):
    """Generate an audio file and query it through an open binary file."""
    audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
    # Only mp3 is given an explicit format; every other extension passes None.
    fmt = ext if ext == 'mp3' else None
    with open(audio_path, 'rb') as handle:
        return sox_io_backend.info(handle, fmt)
def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
    """Generate an audio file and query it through an in-memory ``BytesIO``."""
    audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
    # Only mp3 is given an explicit format; every other extension passes None.
    fmt = ext if ext == 'mp3' else None
    with open(audio_path, 'rb') as handle:
        buffer = io.BytesIO(handle.read())
    # The source file is closed by now; only the in-memory copy is queried.
    return sox_io_backend.info(buffer, fmt)
def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
    """Generate an audio file, archive it, and query the archived member."""
    src_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
    member_name = os.path.basename(src_path)
    tar_path = self.get_temp_path('archive.tar.gz')

    # Write the audio file into the archive under its base name.
    with tarfile.TarFile(tar_path, 'w') as archive:
        archive.add(src_path, arcname=member_name)

    # Only mp3 is given an explicit format; every other extension passes None.
    fmt = ext if ext == 'mp3' else None

    # Query while the archive is still open so the member stays readable.
    with tarfile.TarFile(tar_path, 'r') as archive:
        member = archive.extractfile(member_name)
        return sox_io_backend.info(member, fmt)
@parameterized.expand([
('wav', 32),
('mp3', 0),
('flac', 24),
('vorbis', 0),
('amb', 32),
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
def test_fileobj(self, ext, bits_per_sample):
def test_fileobj(self, ext, dtype):
"""Querying audio via file object works"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
duration = 3
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
duration=duration)
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
with open(path, 'rb') as fileobj:
sinfo = sox_io_backend.info(fileobj, format_)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
if ext not in ['mp3', 'vorbis']: # these container formats do not have length info
assert sinfo.num_frames == sample_rate * duration
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
def _test_bytesio(self, ext, bits_per_sample, duration):
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
def test_bytesio(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
duration=duration)
with open(path, 'rb') as file_:
fileobj = io.BytesIO(file_.read())
sinfo = sox_io_backend.info(fileobj, format_)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
if ext not in ['mp3', 'vorbis']: # these container formats do not have length info
assert sinfo.num_frames == sample_rate * duration
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('wav', 32),
('mp3', 0),
('flac', 24),
('vorbis', 0),
('amb', 32),
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
def test_bytesio(self, ext, bits_per_sample):
"""Querying audio via ByteIO object works"""
self._test_bytesio(ext, bits_per_sample, duration=3)
@parameterized.expand([
('wav', 32),
('mp3', 0),
('flac', 24),
('vorbis', 0),
('amb', 32),
])
def test_bytesio_tiny(self, ext, bits_per_sample):
def test_bytesio_tiny(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
self._test_bytesio(ext, bits_per_sample, duration=1 / 1600)
sample_rate = 8000
num_frames = 4
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('wav', 32),
('mp3', 0),
('flac', 24),
('vorbis', 0),
('amb', 32),
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
def test_tarfile(self, ext, bits_per_sample):
def test_tarfile(self, ext, dtype):
"""Querying compressed audio via file-like object works"""
sample_rate = 16000
num_frames = 3.0 * sample_rate
num_channels = 2
duration = 3
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=num_channels, duration=duration)
with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
sinfo = sox_io_backend.info(fileobj, format=format_)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
if ext not in ['mp3', 'vorbis']: # these container formats do not have length info
assert sinfo.num_frames == sample_rate * duration
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@skipIfNoExtension
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
    """Generate an audio file and query it via a streamed HTTP response."""
    src_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
    target_url = self.get_url(os.path.basename(src_path))
    # Only mp3 is given an explicit format; every other extension passes None.
    fmt = ext if ext == 'mp3' else None
    with requests.get(target_url, stream=True) as response:
        return sox_io_backend.info(response.raw, format=fmt)
@parameterized.expand([
('wav', 32),
('mp3', 0),
('flac', 24),
('vorbis', 0),
('amb', 32),
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
def test_requests(self, ext, bits_per_sample):
def test_requests(self, ext, dtype):
"""Querying compressed audio via requests works"""
sample_rate = 16000
num_frames = 3.0 * sample_rate
num_channels = 2
duration = 3
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=num_channels, duration=duration)
sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames)
url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
sinfo = sox_io_backend.info(resp.raw, format=format_)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
if ext not in ['mp3', 'vorbis']: # these container formats do not have length info
assert sinfo.num_frames == sample_rate * duration
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
......@@ -50,6 +50,37 @@ _SUBTYPE_TO_BITS_PER_SAMPLE = {
}
def _get_bit_depth(subtype):
    """Look up the bits-per-sample for a subtype name.

    Returns the value registered in ``_SUBTYPE_TO_BITS_PER_SAMPLE``. For a
    subtype that is not registered, a warning is issued and 0 is returned.
    """
    if subtype in _SUBTYPE_TO_BITS_PER_SAMPLE:
        return _SUBTYPE_TO_BITS_PER_SAMPLE[subtype]

    # Unknown subtype: tell the user and fall back to 0.
    message = (
        f"The {subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample "
        "attribute will be set to 0. If you are seeing this warning, please "
        "report by opening an issue on github (after checking for existing/closed ones). "
        "You may otherwise ignore this warning."
    )
    warnings.warn(message)
    return 0
# Maps a subtype name to the encoding name reported in ``AudioMetaData``.
_SUBTYPE_TO_ENCODING = {
    # Signed integer PCM subtypes
    **dict.fromkeys(('PCM_S8', 'PCM_16', 'PCM_24', 'PCM_32'), 'PCM_S'),
    # Unsigned integer PCM subtype
    'PCM_U8': 'PCM_U',
    # Floating-point PCM subtypes
    **dict.fromkeys(('FLOAT', 'DOUBLE'), 'PCM_F'),
    # Subtypes whose encoding name equals the subtype name
    **{name: name for name in ('ULAW', 'ALAW', 'VORBIS')},
}


def _get_encoding(format: str, subtype: str):
    """Return the encoding name for a format/subtype pair.

    A ``'FLAC'`` format always yields ``'FLAC'`` and the subtype is ignored.
    Otherwise the subtype is looked up in ``_SUBTYPE_TO_ENCODING``, and
    ``'UNKNOWN'`` is returned for a subtype that is not listed.
    """
    return 'FLAC' if format == 'FLAC' else _SUBTYPE_TO_ENCODING.get(subtype, 'UNKNOWN')
@_mod_utils.requires_module("soundfile")
def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.
......@@ -68,15 +99,13 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
AudioMetaData: meta data of the given audio.
"""
sinfo = soundfile.info(filepath)
if sinfo.subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
warnings.warn(
f"The {sinfo.subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
)
bits_per_sample = _SUBTYPE_TO_BITS_PER_SAMPLE.get(sinfo.subtype, 0)
return AudioMetaData(sinfo.samplerate, sinfo.frames, sinfo.channels, bits_per_sample=bits_per_sample)
return AudioMetaData(
sinfo.samplerate,
sinfo.frames,
sinfo.channels,
bits_per_sample=_get_bit_depth(sinfo.subtype),
encoding=_get_encoding(sinfo.format, sinfo.subtype),
)
_SUBTYPE2DTYPE = {
......
......@@ -14,12 +14,21 @@ class AudioMetaData:
:ivar int num_channels: The number of channels
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
or when it cannot be accurately inferred.
:ivar str encoding: Audio encoding.
"""
def __init__(self, sample_rate: int, num_frames: int, num_channels: int, bits_per_sample: int):
def __init__(
self,
sample_rate: int,
num_frames: int,
num_channels: int,
bits_per_sample: int,
encoding: str,
):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
self.bits_per_sample = bits_per_sample
self.encoding = encoding
@_mod_utils.deprecated('Please migrate to `AudioMetaData`.', '0.9.0')
......
......@@ -17,17 +17,15 @@ def _info(
format: Optional[str] = None,
) -> AudioMetaData:
if hasattr(filepath, 'read'):
sinfo = torchaudio._torchaudio.get_info_fileobj(
filepath, format)
sample_rate, num_channels, num_frames, bits_per_sample = sinfo
return AudioMetaData(
sample_rate, num_frames, num_channels, bits_per_sample)
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo)
sinfo = torch.ops.torchaudio.sox_io_get_info(os.fspath(filepath), format)
return AudioMetaData(
sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample(),
sinfo.get_encoding(),
)
......@@ -69,7 +67,8 @@ def info(
sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample())
sinfo.get_bits_per_sample(),
sinfo.get_encoding())
@_mod_utils.requires_module('torchaudio._torchaudio')
......
......@@ -14,11 +14,13 @@ SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_,
const int64_t bits_per_sample_)
const int64_t bits_per_sample_,
const std::string encoding_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_frames(num_frames_),
bits_per_sample(bits_per_sample_){};
bits_per_sample(bits_per_sample_),
encoding(encoding_){};
int64_t SignalInfo::getSampleRate() const {
return sample_rate;
......@@ -36,6 +38,45 @@ int64_t SignalInfo::getBitsPerSample() const {
return bits_per_sample;
}
std::string SignalInfo::getEncoding() const {
return encoding;
}
namespace {

// Translate a sox_encoding_t value into the encoding name string passed on to
// SignalInfo / get_info_fileobj. SOX_ENCODING_UNKNOWN and any value without an
// explicit case below both yield "UNKNOWN".
std::string get_encoding(sox_encoding_t encoding) {
  const char* label = "UNKNOWN";
  switch (encoding) {
    case SOX_ENCODING_SIGN2:
      label = "PCM_S";
      break;
    case SOX_ENCODING_UNSIGNED:
      label = "PCM_U";
      break;
    case SOX_ENCODING_FLOAT:
      label = "PCM_F";
      break;
    case SOX_ENCODING_FLAC:
      label = "FLAC";
      break;
    case SOX_ENCODING_ULAW:
      label = "ULAW";
      break;
    case SOX_ENCODING_ALAW:
      label = "ALAW";
      break;
    case SOX_ENCODING_MP3:
      label = "MP3";
      break;
    case SOX_ENCODING_VORBIS:
      label = "VORBIS";
      break;
    case SOX_ENCODING_AMR_WB:
      label = "AMR_WB";
      break;
    case SOX_ENCODING_AMR_NB:
      label = "AMR_NB";
      break;
    case SOX_ENCODING_OPUS:
      label = "OPUS";
      break;
    default:
      // Includes SOX_ENCODING_UNKNOWN; keep the "UNKNOWN" label.
      break;
  }
  return label;
}

} // namespace
c10::intrusive_ptr<SignalInfo> get_info_file(
const std::string& path,
c10::optional<std::string>& format) {
......@@ -53,7 +94,8 @@ c10::intrusive_ptr<SignalInfo> get_info_file(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels),
static_cast<int64_t>(sf->encoding.bits_per_sample));
static_cast<int64_t>(sf->encoding.bits_per_sample),
get_encoding(sf->encoding.encoding));
}
namespace {
......@@ -157,7 +199,7 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std::tuple<int64_t, int64_t, int64_t, int64_t> get_info_fileobj(
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
py::object fileobj,
c10::optional<std::string>& format) {
// Prepare in-memory file object
......@@ -202,9 +244,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t> get_info_fileobj(
return std::make_tuple(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels),
static_cast<int64_t>(sf->encoding.bits_per_sample));
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->encoding.bits_per_sample),
get_encoding(sf->encoding.encoding));
}
std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
......
......@@ -16,16 +16,19 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t num_channels;
int64_t num_frames;
int64_t bits_per_sample;
std::string encoding;
SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_,
const int64_t bits_per_sample_);
const int64_t bits_per_sample_,
const std::string encoding_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumFrames() const;
int64_t getBitsPerSample() const;
std::string getEncoding() const;
};
c10::intrusive_ptr<SignalInfo> get_info_file(
......@@ -51,7 +54,7 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std::tuple<int64_t, int64_t, int64_t, int64_t> get_info_fileobj(
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
py::object fileobj,
c10::optional<std::string>& format);
......
......@@ -45,7 +45,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames)
.def(
"get_bits_per_sample",
&torchaudio::sox_io::SignalInfo::getBitsPerSample);
&torchaudio::sox_io::SignalInfo::getBitsPerSample)
.def("get_encoding", &torchaudio::sox_io::SignalInfo::getEncoding);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
m.def(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment