Unverified Commit 41c76a17 authored by moto's avatar moto Committed by GitHub
Browse files

Support file-like object in info (#1108)

parent 22e7e877
from unittest.mock import patch from unittest.mock import patch
import warnings import warnings
import tarfile
import torch import torch
from torchaudio.backend import _soundfile_backend as soundfile_backend from torchaudio.backend import _soundfile_backend as soundfile_backend
...@@ -125,3 +126,65 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -125,3 +126,65 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert len(w) == 1 assert len(w) == 1
assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message) assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
assert info.bits_per_sample == 0 assert info.bits_per_sample == 0
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext, subtype, bits_per_sample):
"""Query audio via file-like object works"""
duration = 2
sample_rate = 16000
num_channels = 2
num_frames = sample_rate * duration
path = self.get_temp_path(f'test.{ext}')
data = torch.randn(num_frames, num_channels).numpy()
soundfile.write(path, data, sample_rate, subtype=subtype)
with open(path, 'rb') as fileobj:
info = soundfile_backend.info(fileobj)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj('wav', 'PCM_16', 16)
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj('flac', 'PCM_16', 16)
def _test_tarobj(self, ext, subtype, bits_per_sample):
"""Query compressed audio via file-like object works"""
duration = 2
sample_rate = 16000
num_channels = 2
num_frames = sample_rate * duration
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
data = torch.randn(num_frames, num_channels).numpy()
soundfile.write(audio_path, data, sample_rate, subtype=subtype)
with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
info = soundfile_backend.info(fileobj)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
def test_tarobj_wav(self):
"""Query compressed audio via file-like object works"""
self._test_tarobj('wav', 'PCM_16', 16)
@skipIfFormatNotSupported("FLAC")
def test_tarobj_flac(self):
"""Query compressed audio via file-like object works"""
self._test_tarobj('flac', 'PCM_16', 16)
import io
import itertools import itertools
from parameterized import parameterized import tarfile
from parameterized import parameterized
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
HttpServerMixin,
PytorchTestCase, PytorchTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension, skipIfNoExtension,
skipIfNoModule,
get_asset_path, get_asset_path,
get_wav_data, get_wav_data,
save_wav, save_wav,
...@@ -18,6 +23,10 @@ from .common import ( ...@@ -18,6 +23,10 @@ from .common import (
) )
if _mod_utils.is_module_available("requests"):
import requests
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase): class TestInfo(TempDirMixin, PytorchTestCase):
...@@ -197,3 +206,143 @@ class TestLoadWithoutExtension(PytorchTestCase): ...@@ -197,3 +206,143 @@ class TestLoadWithoutExtension(PytorchTestCase):
sinfo = sox_io_backend.info(path, format="mp3") sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000 assert sinfo.sample_rate == 16000
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
@parameterized.expand([
('wav', 32),
('mp3', 0),
('flac', 24),
('vorbis', 0),
('amb', 32),
])
def test_fileobj(self, ext, bits_per_sample):
"""Querying audio via file object works"""
sample_rate = 16000
num_channels = 2
duration = 3
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
duration=duration)
with open(path, 'rb') as fileobj:
sinfo = sox_io_backend.info(fileobj, format_)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
if ext not in ['mp3', 'vorbis']: # these container formats do not have length info
assert sinfo.num_frames == sample_rate * duration
assert sinfo.bits_per_sample == bits_per_sample
def _test_bytesio(self, ext, bits_per_sample, duration):
sample_rate = 16000
num_channels = 2
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
duration=duration)
with open(path, 'rb') as file_:
fileobj = io.BytesIO(file_.read())
sinfo = sox_io_backend.info(fileobj, format_)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
if ext not in ['mp3', 'vorbis']: # these container formats do not have length info
assert sinfo.num_frames == sample_rate * duration
assert sinfo.bits_per_sample == bits_per_sample
@parameterized.expand([
('wav', 32),
('mp3', 0),
('flac', 24),
('vorbis', 0),
('amb', 32),
])
def test_bytesio(self, ext, bits_per_sample):
"""Querying audio via ByteIO object works"""
self._test_bytesio(ext, bits_per_sample, duration=3)
@parameterized.expand([
('wav', 32),
('mp3', 0),
('flac', 24),
('vorbis', 0),
('amb', 32),
])
def test_bytesio_tiny(self, ext, bits_per_sample):
"""Querying audio via ByteIO object works for small data"""
self._test_bytesio(ext, bits_per_sample, duration=1 / 1600)
@parameterized.expand([
('wav', 32),
('mp3', 0),
('flac', 24),
('vorbis', 0),
('amb', 32),
])
def test_tarfile(self, ext, bits_per_sample):
"""Querying compressed audio via file-like object works"""
sample_rate = 16000
num_channels = 2
duration = 3
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=num_channels, duration=duration)
with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
sinfo = sox_io_backend.info(fileobj, format=format_)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
if ext not in ['mp3', 'vorbis']: # these container formats do not have length info
assert sinfo.num_frames == sample_rate * duration
assert sinfo.bits_per_sample == bits_per_sample
@skipIfNoExtension
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
@parameterized.expand([
('wav', 32),
('mp3', 0),
('flac', 24),
('vorbis', 0),
('amb', 32),
])
def test_requests(self, ext, bits_per_sample):
"""Querying compressed audio via requests works"""
sample_rate = 16000
num_channels = 2
duration = 3
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=num_channels, duration=duration)
url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
sinfo = sox_io_backend.info(resp.raw, format=format_)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
if ext not in ['mp3', 'vorbis']: # these container formats do not have length info
assert sinfo.num_frames == sample_rate * duration
assert sinfo.bits_per_sample == bits_per_sample
...@@ -55,10 +55,12 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData: ...@@ -55,10 +55,12 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file. """Get signal information of an audio file.
Args: Args:
filepath (str or pathlib.Path): Path to audio file. filepath (path-like object or file-like object):
This functionalso handles ``pathlib.Path`` objects, but is annotated as ``str`` Source of audio data.
for the consistency with "sox_io" backend, which has a restriction on type annotation Note:
for TorchScript compiler compatiblity. * This argument is intentionally annotated as ``str`` only,
for the consistency with "sox_io" backend, which has a restriction
on type annotation due to TorchScript compiler compatiblity.
format (str, optional): format (str, optional):
Not used. PySoundFile does not accept format hint. Not used. PySoundFile does not accept format hint.
......
...@@ -10,6 +10,26 @@ import torchaudio ...@@ -10,6 +10,26 @@ import torchaudio
from .common import AudioMetaData from .common import AudioMetaData
@torch.jit.unused
def _info(
filepath: str,
format: Optional[str] = None,
) -> AudioMetaData:
if hasattr(filepath, 'read'):
sinfo = torchaudio._torchaudio.get_info_fileobj(
filepath, format)
sample_rate, num_channels, num_frames, bits_per_sample = sinfo
return AudioMetaData(
sample_rate, num_frames, num_channels, bits_per_sample)
sinfo = torch.ops.torchaudio.sox_io_get_info(os.fspath(filepath), format)
return AudioMetaData(
sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample(),
)
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
def info( def info(
filepath: str, filepath: str,
...@@ -18,9 +38,21 @@ def info( ...@@ -18,9 +38,21 @@ def info(
"""Get signal information of an audio file. """Get signal information of an audio file.
Args: Args:
filepath (str or pathlib.Path): filepath (path-like object or file-like object):
Path to audio file. This function also handles ``pathlib.Path`` objects, Source of audio data. When the function is not compiled by TorchScript,
but is annotated as ``str`` for TorchScript compatibility. (e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.
Note:
* When the input type is file-like object, this function cannot
get the correct length (``num_samples``) for certain formats,
such as ``mp3`` and ``vorbis``.
In this case, the value of ``num_samples`` is ``0``.
* This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
format (str, optional): format (str, optional):
Override the format detection with the given format. Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format Providing the argument might help when libsox can not infer the format
...@@ -29,11 +61,14 @@ def info( ...@@ -29,11 +61,14 @@ def info(
Returns: Returns:
AudioMetaData: Metadata of the given audio. AudioMetaData: Metadata of the given audio.
""" """
# Cast to str in case type is `pathlib.Path` if not torch.jit.is_scripting():
filepath = str(filepath) return _info(filepath, format)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format) sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels(), return AudioMetaData(
sinfo.get_bits_per_sample()) sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample())
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
......
...@@ -100,6 +100,10 @@ PYBIND11_MODULE(_torchaudio, m) { ...@@ -100,6 +100,10 @@ PYBIND11_MODULE(_torchaudio, m) {
"get_info", "get_info",
&torch::audio::get_info, &torch::audio::get_info,
"Gets information about an audio file"); "Gets information about an audio file");
m.def(
"get_info_fileobj",
&torchaudio::sox_io::get_info_fileobj,
"Get metadata of audio in file object.");
m.def( m.def(
"load_audio_fileobj", "load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj, &torchaudio::sox_io::load_audio_fileobj,
......
...@@ -36,7 +36,7 @@ int64_t SignalInfo::getBitsPerSample() const { ...@@ -36,7 +36,7 @@ int64_t SignalInfo::getBitsPerSample() const {
return bits_per_sample; return bits_per_sample;
} }
c10::intrusive_ptr<SignalInfo> get_info( c10::intrusive_ptr<SignalInfo> get_info_file(
const std::string& path, const std::string& path,
c10::optional<std::string>& format) { c10::optional<std::string>& format) {
SoxFormat sf(sox_open_read( SoxFormat sf(sox_open_read(
...@@ -149,6 +149,56 @@ void save_audio_file( ...@@ -149,6 +149,56 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H #ifdef TORCH_API_INCLUDE_EXTENSION_H
std::tuple<int64_t, int64_t, int64_t, int64_t> get_info_fileobj(
py::object fileobj,
c10::optional<std::string>& format) {
// Prepare in-memory file object
// When libsox opens a file, it also reads the header.
// When opening a file there are two functions that might touch FILE* (and the
// underlying buffer).
// * `auto_detect_format`
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
// * `startread` handler of detected format.
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
// To see the handler of a particular format, go to
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
// For example, voribs can be found
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
//
// `auto_detect_format` function only requires 256 bytes, but format-dependant
// `startread` handler might require more data. In case of vorbis, the size of
// header is unbounded, but typically 4kB maximum.
//
// "The header size is unbounded, although for streaming a rule-of-thumb of
// 4kB or less is recommended (and Xiph.Org's Vorbis encoder follows this
// suggestion)."
//
// See:
// https://xiph.org/vorbis/doc/Vorbis_I_spec.html
auto capacity = 4096;
std::string buffer(capacity, '\0');
auto* buf = const_cast<char*>(buffer.data());
auto num_read = read_fileobj(&fileobj, capacity, buf);
// If the file is shorter than 256, then libsox cannot read the header.
auto buf_size = (num_read > 256) ? num_read : 256;
SoxFormat sf(sox_open_mem_read(
buf,
buf_size,
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0
validate_input_file(sf, /*check_length=*/false);
return std::make_tuple(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels),
static_cast<int64_t>(sf->encoding.bits_per_sample));
}
std::tuple<torch::Tensor, int64_t> load_audio_fileobj( std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
py::object fileobj, py::object fileobj,
c10::optional<int64_t>& frame_offset, c10::optional<int64_t>& frame_offset,
......
...@@ -28,7 +28,7 @@ struct SignalInfo : torch::CustomClassHolder { ...@@ -28,7 +28,7 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t getBitsPerSample() const; int64_t getBitsPerSample() const;
}; };
c10::intrusive_ptr<SignalInfo> get_info( c10::intrusive_ptr<SignalInfo> get_info_file(
const std::string& path, const std::string& path,
c10::optional<std::string>& format); c10::optional<std::string>& format);
...@@ -50,6 +50,10 @@ void save_audio_file( ...@@ -50,6 +50,10 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H #ifdef TORCH_API_INCLUDE_EXTENSION_H
std::tuple<int64_t, int64_t, int64_t, int64_t> get_info_fileobj(
py::object fileobj,
c10::optional<std::string>& format);
std::tuple<torch::Tensor, int64_t> load_audio_fileobj( std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
py::object fileobj, py::object fileobj,
c10::optional<int64_t>& frame_offset, c10::optional<int64_t>& frame_offset,
......
...@@ -47,7 +47,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { ...@@ -47,7 +47,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
"get_bits_per_sample", "get_bits_per_sample",
&torchaudio::sox_io::SignalInfo::getBitsPerSample); &torchaudio::sox_io::SignalInfo::getBitsPerSample);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info); m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
m.def( m.def(
"torchaudio::sox_io_load_audio_file(" "torchaudio::sox_io_load_audio_file("
"str path," "str path,"
......
...@@ -317,5 +317,37 @@ sox_encodinginfo_t get_encodinginfo( ...@@ -317,5 +317,37 @@ sox_encodinginfo_t get_encodinginfo(
/*opposite_endian=*/sox_false}; /*opposite_endian=*/sox_false};
} }
#ifdef TORCH_API_INCLUDE_EXTENSION_H
uint64_t read_fileobj(py::object* fileobj, const uint64_t size, char* buffer) {
uint64_t num_read = 0;
while (num_read < size) {
auto request = size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
if (chunk_len > request) {
std::ostringstream message;
message
<< "Requested up to " << request << " bytes but, "
<< "received " << chunk_len << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
}
std::cerr << "req: " << request << ", fetched: " << chunk_len << std::endl;
std::cerr << "buffer: " << (void*)buffer << std::endl;
memcpy(buffer, chunk.data(), chunk_len);
buffer += chunk_len;
num_read += chunk_len;
}
return num_read;
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_utils } // namespace sox_utils
} // namespace torchaudio } // namespace torchaudio
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
#include <sox.h> #include <sox.h>
#include <torch/script.h> #include <torch/script.h>
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
namespace torchaudio { namespace torchaudio {
namespace sox_utils { namespace sox_utils {
...@@ -127,6 +131,12 @@ sox_encodinginfo_t get_encodinginfo( ...@@ -127,6 +131,12 @@ sox_encodinginfo_t get_encodinginfo(
const caffe2::TypeMeta dtype, const caffe2::TypeMeta dtype,
c10::optional<double>& compression); c10::optional<double>& compression);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
uint64_t read_fileobj(py::object* fileobj, uint64_t size, char* buffer);
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_utils } // namespace sox_utils
} // namespace torchaudio } // namespace torchaudio
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment