Unverified Commit 41c76a17 authored by moto's avatar moto Committed by GitHub
Browse files

Support file-like object in info (#1108)

parent 22e7e877
from unittest.mock import patch
import warnings
import tarfile
import torch
from torchaudio.backend import _soundfile_backend as soundfile_backend
......@@ -125,3 +126,65 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert len(w) == 1
assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
assert info.bits_per_sample == 0
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Verify that ``soundfile_backend.info`` accepts file-like objects."""

    # Properties of the random audio fixture written for every test case.
    _SAMPLE_RATE = 16000
    _NUM_CHANNELS = 2
    _DURATION = 2  # seconds
    _NUM_FRAMES = _SAMPLE_RATE * _DURATION

    def _write_audio(self, path, subtype):
        """Write the random audio fixture to ``path`` using the given subtype."""
        data = torch.randn(self._NUM_FRAMES, self._NUM_CHANNELS).numpy()
        soundfile.write(path, data, self._SAMPLE_RATE, subtype=subtype)

    def _assert_info(self, info, bits_per_sample):
        """Check the queried metadata against the fixture properties."""
        assert info.sample_rate == self._SAMPLE_RATE
        assert info.num_frames == self._NUM_FRAMES
        assert info.num_channels == self._NUM_CHANNELS
        assert info.bits_per_sample == bits_per_sample

    def _test_fileobj(self, ext, subtype, bits_per_sample):
        """Query audio via file-like object works"""
        path = self.get_temp_path(f'test.{ext}')
        self._write_audio(path, subtype)

        with open(path, 'rb') as fileobj:
            info = soundfile_backend.info(fileobj)
        self._assert_info(info, bits_per_sample)

    def test_fileobj_wav(self):
        """Querying WAV audio via file-like object works"""
        self._test_fileobj('wav', 'PCM_16', 16)

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Querying FLAC audio via file-like object works"""
        self._test_fileobj('flac', 'PCM_16', 16)

    def _test_tarobj(self, ext, subtype, bits_per_sample):
        """Query audio stored in a tar archive via file-like object works"""
        audio_file = f'test.{ext}'
        audio_path = self.get_temp_path(audio_file)
        # NOTE: despite the ".tar.gz" suffix, ``tarfile.TarFile`` in plain 'w'
        # mode writes an uncompressed archive.
        archive_path = self.get_temp_path('archive.tar.gz')
        self._write_audio(audio_path, subtype)

        with tarfile.TarFile(archive_path, 'w') as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, 'r') as tarobj:
            # The extracted member is only readable while the archive is open.
            fileobj = tarobj.extractfile(audio_file)
            info = soundfile_backend.info(fileobj)
        self._assert_info(info, bits_per_sample)

    def test_tarobj_wav(self):
        """Querying WAV audio in a tar archive via file-like object works"""
        self._test_tarobj('wav', 'PCM_16', 16)

    @skipIfFormatNotSupported("FLAC")
    def test_tarobj_flac(self):
        """Querying FLAC audio in a tar archive via file-like object works"""
        self._test_tarobj('flac', 'PCM_16', 16)
import io
import itertools
from parameterized import parameterized
import tarfile
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoModule,
get_asset_path,
get_wav_data,
save_wav,
......@@ -18,6 +23,10 @@ from .common import (
)
if _mod_utils.is_module_available("requests"):
import requests
@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
......@@ -197,3 +206,143 @@ class TestLoadWithoutExtension(PytorchTestCase):
sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Verify that ``sox_io_backend.info`` can query file-like objects."""

    # (extension, expected bits_per_sample) pairs shared by every test below.
    _FORMATS = [
        ('wav', 32),
        ('mp3', 0),
        ('flac', 24),
        ('vorbis', 0),
        ('amb', 32),
    ]
    # Properties of the audio fixture generated with sox.
    _SAMPLE_RATE = 16000
    _NUM_CHANNELS = 2

    @staticmethod
    def _format_hint(ext):
        """Return the ``format`` argument for ``info``; these tests only pass one for mp3."""
        return ext if ext in ['mp3'] else None

    def _gen_audio(self, file_name, duration):
        """Generate the audio fixture in the temp directory and return its path."""
        path = self.get_temp_path(file_name)
        sox_utils.gen_audio_file(
            path, self._SAMPLE_RATE, num_channels=self._NUM_CHANNELS,
            duration=duration)
        return path

    def _assert_sinfo(self, sinfo, ext, bits_per_sample, duration):
        """Check the queried metadata against the fixture properties."""
        assert sinfo.sample_rate == self._SAMPLE_RATE
        assert sinfo.num_channels == self._NUM_CHANNELS
        if ext not in ['mp3', 'vorbis']:  # these container formats do not have length info
            assert sinfo.num_frames == self._SAMPLE_RATE * duration
        assert sinfo.bits_per_sample == bits_per_sample

    @parameterized.expand(_FORMATS)
    def test_fileobj(self, ext, bits_per_sample):
        """Querying audio via file object works"""
        duration = 3
        path = self._gen_audio(f'test.{ext}', duration)

        with open(path, 'rb') as fileobj:
            sinfo = sox_io_backend.info(fileobj, format=self._format_hint(ext))
        self._assert_sinfo(sinfo, ext, bits_per_sample, duration)

    def _test_bytesio(self, ext, bits_per_sample, duration):
        """Query audio that has been fully loaded into an ``io.BytesIO`` buffer."""
        path = self._gen_audio(f'test.{ext}', duration)

        with open(path, 'rb') as file_:
            fileobj = io.BytesIO(file_.read())
        sinfo = sox_io_backend.info(fileobj, format=self._format_hint(ext))
        self._assert_sinfo(sinfo, ext, bits_per_sample, duration)

    @parameterized.expand(_FORMATS)
    def test_bytesio(self, ext, bits_per_sample):
        """Querying audio via BytesIO object works"""
        self._test_bytesio(ext, bits_per_sample, duration=3)

    @parameterized.expand(_FORMATS)
    def test_bytesio_tiny(self, ext, bits_per_sample):
        """Querying audio via BytesIO object works for small data"""
        self._test_bytesio(ext, bits_per_sample, duration=1 / 1600)

    @parameterized.expand(_FORMATS)
    def test_tarfile(self, ext, bits_per_sample):
        """Querying audio stored in a tar archive via file-like object works"""
        duration = 3
        audio_file = f'test.{ext}'
        audio_path = self._gen_audio(audio_file, duration)
        # NOTE: despite the ".tar.gz" suffix, ``tarfile.TarFile`` in plain 'w'
        # mode writes an uncompressed archive.
        archive_path = self.get_temp_path('archive.tar.gz')

        with tarfile.TarFile(archive_path, 'w') as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, 'r') as tarobj:
            # The extracted member is only readable while the archive is open.
            fileobj = tarobj.extractfile(audio_file)
            sinfo = sox_io_backend.info(fileobj, format=self._format_hint(ext))
        self._assert_sinfo(sinfo, ext, bits_per_sample, duration)
@skipIfNoExtension
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
    """Verify that ``sox_io_backend.info`` can query a streamed HTTP response."""

    @parameterized.expand([
        ('wav', 32),
        ('mp3', 0),
        ('flac', 24),
        ('vorbis', 0),
        ('amb', 32),
    ])
    def test_requests(self, ext, bits_per_sample):
        """Querying compressed audio via requests works"""
        expected_rate, expected_channels, seconds = 16000, 2, 3
        # These tests pass an explicit format only for mp3.
        hint = ext if ext in ['mp3'] else None
        # mp3 and vorbis containers do not carry length information.
        has_length = ext not in ['mp3', 'vorbis']

        file_name = f'test.{ext}'
        sox_utils.gen_audio_file(
            self.get_temp_path(file_name),
            expected_rate,
            num_channels=expected_channels,
            duration=seconds,
        )

        # ``resp.raw`` is the undecoded response stream, handed to ``info``
        # as the file-like object.
        with requests.get(self.get_url(file_name), stream=True) as resp:
            sinfo = sox_io_backend.info(resp.raw, format=hint)

        assert sinfo.sample_rate == expected_rate
        assert sinfo.num_channels == expected_channels
        if has_length:
            assert sinfo.num_frames == expected_rate * seconds
        assert sinfo.bits_per_sample == bits_per_sample
......@@ -55,10 +55,12 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.
Args:
filepath (str or pathlib.Path): Path to audio file.
This function also handles ``pathlib.Path`` objects, but is annotated as ``str``
for the consistency with "sox_io" backend, which has a restriction on type annotation
for TorchScript compiler compatibility.
filepath (path-like object or file-like object):
Source of audio data.
Note:
* This argument is intentionally annotated as ``str`` only,
for the consistency with "sox_io" backend, which has a restriction
on type annotation due to TorchScript compiler compatibility.
format (str, optional):
Not used. PySoundFile does not accept format hint.
......
......@@ -10,6 +10,26 @@ import torchaudio
from .common import AudioMetaData
@torch.jit.unused
def _info(
        filepath: str,
        format: Optional[str] = None,
) -> AudioMetaData:
    """Query audio metadata from a path-like or file-like source.

    This is the non-TorchScript implementation (marked ``torch.jit.unused``).

    Args:
        filepath: Path-like object, or any object exposing a ``read`` attribute.
            Annotated as ``str`` although other types are handled here.
        format: Optional format hint forwarded to the underlying sox call.

    Returns:
        AudioMetaData: Metadata of the given audio.
    """
    is_file_like = hasattr(filepath, 'read')
    if not is_file_like:
        # ``os.fspath`` converts ``pathlib.Path`` and other path-like objects
        # to the plain string the registered op is called with.
        signal = torch.ops.torchaudio.sox_io_get_info(os.fspath(filepath), format)
        return AudioMetaData(
            signal.get_sample_rate(),
            signal.get_num_frames(),
            signal.get_num_channels(),
            signal.get_bits_per_sample(),
        )

    # The file-object binding returns a plain tuple ordered as
    # (sample_rate, num_channels, num_frames, bits_per_sample), whereas
    # AudioMetaData takes num_frames before num_channels.
    rate, channels, frames, bits = torchaudio._torchaudio.get_info_fileobj(
        filepath, format)
    return AudioMetaData(rate, frames, channels, bits)
@_mod_utils.requires_module('torchaudio._torchaudio')
def info(
filepath: str,
......@@ -18,9 +38,21 @@ def info(
"""Get signal information of an audio file.
Args:
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects,
but is annotated as ``str`` for TorchScript compatibility.
filepath (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript,
(e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.
Note:
* When the input type is file-like object, this function cannot
get the correct length (``num_frames``) for certain formats,
such as ``mp3`` and ``vorbis``.
In this case, the value of ``num_frames`` is ``0``.
* This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
format (str, optional):
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
......@@ -29,11 +61,14 @@ def info(
Returns:
AudioMetaData: Metadata of the given audio.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if not torch.jit.is_scripting():
return _info(filepath, format)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels(),
sinfo.get_bits_per_sample())
return AudioMetaData(
sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample())
@_mod_utils.requires_module('torchaudio._torchaudio')
......
......@@ -100,6 +100,10 @@ PYBIND11_MODULE(_torchaudio, m) {
"get_info",
&torch::audio::get_info,
"Gets information about an audio file");
m.def(
"get_info_fileobj",
&torchaudio::sox_io::get_info_fileobj,
"Get metadata of audio in file object.");
m.def(
"load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj,
......
......@@ -36,7 +36,7 @@ int64_t SignalInfo::getBitsPerSample() const {
return bits_per_sample;
}
c10::intrusive_ptr<SignalInfo> get_info(
c10::intrusive_ptr<SignalInfo> get_info_file(
const std::string& path,
c10::optional<std::string>& format) {
SoxFormat sf(sox_open_read(
......@@ -149,6 +149,56 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std::tuple<int64_t, int64_t, int64_t, int64_t> get_info_fileobj(
    py::object fileobj,
    c10::optional<std::string>& format) {
  // Query audio metadata from a Python file-like object.
  //
  // Returns (sample_rate, num_channels, num_frames, bits_per_sample).
  //
  // Only the head of the stream is fetched into an in-memory buffer, which
  // libsox then opens. Opening a file in libsox also reads its header, and
  // two functions may touch the underlying buffer while doing so:
  //  * `auto_detect_format`
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
  //  * the `startread` handler of the detected format
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
  // The handler of a particular format lives in
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
  // e.g. for vorbis:
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
  //
  // `auto_detect_format` needs only 256 bytes, but a format-dependent
  // `startread` handler may need more. The vorbis header is unbounded in
  // size, though typically at most 4kB:
  //
  //   "The header size is unbounded, although for streaming a rule-of-thumb
  //   of 4kB or less is recommended (and Xiph.Org's Vorbis encoder follows
  //   this suggestion)."
  //
  // See: https://xiph.org/vorbis/doc/Vorbis_I_spec.html
  const uint64_t header_capacity = 4096;
  // If the file is shorter than 256, then libsox cannot read the header.
  const uint64_t min_open_size = 256;

  // Zero-filled storage; declared before `sf` so it outlives the sox handle.
  std::string header(header_capacity, '\0');
  char* header_ptr = const_cast<char*>(header.data());
  const uint64_t bytes_read =
      read_fileobj(&fileobj, header_capacity, header_ptr);

  // Never report fewer than `min_open_size` bytes to libsox; the remainder
  // of the buffer is zero padding within `header_capacity`.
  uint64_t open_size = bytes_read;
  if (!(open_size > min_open_size)) {
    open_size = min_open_size;
  }

  const char* filetype = nullptr;
  if (format.has_value()) {
    filetype = format.value().c_str();
  }

  SoxFormat sf(sox_open_mem_read(
      header_ptr,
      open_size,
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/filetype));

  // In case of streamed data, length can be 0
  validate_input_file(sf, /*check_length=*/false);

  const auto sample_rate = static_cast<int64_t>(sf->signal.rate);
  const auto num_channels = static_cast<int64_t>(sf->signal.channels);
  // `signal.length` is divided by the channel count to get frames.
  const auto num_frames =
      static_cast<int64_t>(sf->signal.length / sf->signal.channels);
  const auto bits_per_sample =
      static_cast<int64_t>(sf->encoding.bits_per_sample);
  return std::make_tuple(
      sample_rate, num_channels, num_frames, bits_per_sample);
}
std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
py::object fileobj,
c10::optional<int64_t>& frame_offset,
......
......@@ -28,7 +28,7 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t getBitsPerSample() const;
};
c10::intrusive_ptr<SignalInfo> get_info(
c10::intrusive_ptr<SignalInfo> get_info_file(
const std::string& path,
c10::optional<std::string>& format);
......@@ -50,6 +50,10 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std::tuple<int64_t, int64_t, int64_t, int64_t> get_info_fileobj(
py::object fileobj,
c10::optional<std::string>& format);
std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
py::object fileobj,
c10::optional<int64_t>& frame_offset,
......
......@@ -47,7 +47,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
"get_bits_per_sample",
&torchaudio::sox_io::SignalInfo::getBitsPerSample);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
m.def(
"torchaudio::sox_io_load_audio_file("
"str path,"
......
......@@ -317,5 +317,37 @@ sox_encodinginfo_t get_encodinginfo(
/*opposite_endian=*/sox_false};
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
// Fill `buffer` with up to `size` bytes read from a Python file-like object.
//
// `fileobj->read(n)` is called repeatedly until `size` bytes have been
// collected or an empty chunk is returned.
//
// Args:
//   fileobj: Python object exposing `read(n)` whose result converts to bytes.
//   size:    Maximum number of bytes to fetch; `buffer` must have room for it.
//   buffer:  Destination memory, written sequentially from its start.
//
// Returns the number of bytes actually written to `buffer` (<= size).
// Throws std::runtime_error if a single `read` call returns more bytes than
// were requested, since copying them would overflow `buffer`.
uint64_t read_fileobj(py::object* fileobj, const uint64_t size, char* buffer) {
  uint64_t num_read = 0;
  while (num_read < size) {
    // Ask only for what is still missing so the copy below cannot overrun.
    auto request = size - num_read;
    auto chunk = static_cast<std::string>(
        static_cast<py::bytes>(fileobj->attr("read")(request)));
    auto chunk_len = chunk.length();
    if (chunk_len == 0) {
      // An empty chunk is treated as end of data.
      break;
    }
    if (chunk_len > request) {
      std::ostringstream message;
      message
          << "Requested up to " << request << " bytes but, "
          << "received " << chunk_len << " bytes. "
          << "The given object does not conform to read protocol of file object.";
      throw std::runtime_error(message.str());
    }
    memcpy(buffer, chunk.data(), chunk_len);
    buffer += chunk_len;
    num_read += chunk_len;
  }
  return num_read;
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_utils
} // namespace torchaudio
......@@ -4,6 +4,10 @@
#include <sox.h>
#include <torch/script.h>
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
namespace torchaudio {
namespace sox_utils {
......@@ -127,6 +131,12 @@ sox_encodinginfo_t get_encodinginfo(
const caffe2::TypeMeta dtype,
c10::optional<double>& compression);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
uint64_t read_fileobj(py::object* fileobj, uint64_t size, char* buffer);
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_utils
} // namespace torchaudio
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment