Commit c5a43372 authored by Moto Hira, committed by Facebook GitHub Bot

Support in-memory decoding via Tensor wrapper in StreamReader (#2694)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2694

This commit adds the Tensor type as an input to `StreamReader`.
The Tensor is interpreted as a byte-string buffer.
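
For illustration, a minimal usage sketch (the file name is a placeholder; this assumes a torchaudio build with FFmpeg support and the `torchaudio.io.StreamReader` API):

import torch
from torchaudio.io import StreamReader

# Read the encoded media into memory and wrap the bytes in a 1-D uint8 Tensor.
with open("example.mp4", "rb") as f:  # placeholder path
    data = torch.frombuffer(f.read(), dtype=torch.uint8)

# Decoding now happens entirely from the in-memory buffer.
reader = StreamReader(data)
reader.add_basic_audio_stream(frames_per_chunk=4096)
for (chunk,) in reader.stream():
    print(chunk.shape)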

Reviewed By: hwangjeff

Differential Revision: D39467630

fbshipit-source-id: 6369eed5e16fbb657568bf6bb80d703483d72f8e
parent 30c7077b
@@ -28,9 +28,9 @@ if is_ffmpeg_available():
 ################################################################################
 # Helper decorator and Mixin to duplicate the tests for fileobj
 _media_source = parameterized_class(
-    ("test_fileobj",),
-    [(False,), (True,)],
-    class_name_func=lambda cls, _, params: f'{cls.__name__}{"_fileobj" if params["test_fileobj"] else "_path"}',
+    ("test_type",),
+    [("str",), ("fileobj",), ("tensor",)],
+    class_name_func=lambda cls, _, params: f'{cls.__name__}_{params["test_type"]}',
 )
@@ -40,16 +40,24 @@ class _MediaSourceMixin:
         self.src = None

     def get_src(self, path):
-        if not self.test_fileobj:
-            return path
         if self.src is not None:
-            raise ValueError("get_video_asset can be called only once.")
-        self.src = open(path, "rb")
+            raise ValueError("get_src can be called only once.")
+        if self.test_type == "str":
+            self.src = path
+        elif self.test_type == "fileobj":
+            self.src = open(path, "rb")
+        elif self.test_type == "tensor":
+            with open(path, "rb") as fileobj:
+                data = fileobj.read()
+            self.src = torch.frombuffer(data, dtype=torch.uint8)
+            print(self.src.data_ptr())
+            print(len(data))
+            print(self.src.shape)
         return self.src

     def tearDown(self):
-        if self.src is not None:
+        if self.test_type == "fileobj" and self.src is not None:
             self.src.close()
         super().tearDown()
@@ -482,12 +490,12 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
         # provide the matching dtype
         self._test_wav(src, original, fmt=fmt)
         # use the internal dtype ffmpeg picks
-        if self.test_fileobj:
+        if self.test_type == "fileobj":
             src.seek(0)
         self._test_wav(src, original, fmt=None)
         # convert to float32
         expected = _to_fltp(original)
-        if self.test_fileobj:
+        if self.test_type == "fileobj":
             src.seek(0)
         self._test_wav(src, expected, fmt="fltp")
@@ -517,7 +525,7 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
         for t in range(10, 20):
             expected = original[t:, :]
-            if self.test_fileobj:
+            if self.test_type == "fileobj":
                 src.seek(0)
             s = StreamReader(src)
             s.add_audio_stream(frames_per_chunk=-1)
...
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader_tensor_binding.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
// Read callback for AVIOContext: copies up to buf_size bytes from the source
// Tensor into FFmpeg's buffer and advances the read position.
static int read_function(void* opaque, uint8_t* buf, int buf_size) {
TensorIndexer* tensorobj = static_cast<TensorIndexer*>(opaque);
int num_read = FFMIN(tensorobj->numel - tensorobj->index, buf_size);
if (num_read == 0) {
return AVERROR_EOF;
}
uint8_t* head = const_cast<uint8_t*>(tensorobj->data) + tensorobj->index;
memcpy(buf, head, num_read);
tensorobj->index += num_read;
return num_read;
}
// Seek callback for AVIOContext: moves the read position. AVSEEK_SIZE is a
// special whence value with which FFmpeg queries the total stream size.
static int64_t seek_function(void* opaque, int64_t offset, int whence) {
TensorIndexer* tensorobj = static_cast<TensorIndexer*>(opaque);
if (whence == AVSEEK_SIZE) {
return static_cast<int64_t>(tensorobj->numel);
}
if (whence == SEEK_SET) {
tensorobj->index = offset;
} else if (whence == SEEK_CUR) {
tensorobj->index += offset;
} else if (whence == SEEK_END) {
tensorobj->index = tensorobj->numel + offset;
} else {
TORCH_CHECK(false, "[INTERNAL ERROR] Unexpected whence value: ", whence);
}
return static_cast<int64_t>(tensorobj->index);
}
// Allocates an AVIOContext that reads from the given TensorIndexer through
// the callbacks above. The internal buffer is allocated with av_malloc.
AVIOContext* get_io_context(TensorIndexer* opaque, int buffer_size) {
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
TORCH_CHECK(buffer, "Failed to allocate buffer.");
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
0,
static_cast<void*>(opaque),
&read_function,
nullptr,
&seek_function);
if (!av_io_ctx) {
av_freep(&buffer);
TORCH_CHECK(av_io_ctx, "Failed to initialize AVIOContext.");
}
return av_io_ctx;
}
std::string get_id(const torch::Tensor& src) {
std::stringstream ss;
ss << "Tensor <" << static_cast<const void*>(src.data_ptr<uint8_t>()) << ">";
return ss.str();
}
} // namespace
// Validates the input Tensor (contiguous, 1-D, uint8, on CPU), caches its
// data pointer and length, then builds the custom AVIOContext.
TensorIndexer::TensorIndexer(const torch::Tensor& src, int buffer_size)
: src(src),
data([&]() -> uint8_t* {
TORCH_CHECK(
src.is_contiguous(), "The input Tensor must be contiguous.");
TORCH_CHECK(
src.dtype() == torch::kUInt8,
"The input Tensor must be uint8 type. Found: ",
src.dtype());
TORCH_CHECK(
src.device().type() == c10::DeviceType::CPU,
"The input Tensor must be on CPU. Found: ",
src.device().str());
TORCH_CHECK(
src.dim() == 1, "The input Tensor must be 1D. Found: ", src.dim());
return src.data_ptr<uint8_t>();
}()),
numel(src.numel()),
pAVIO(get_io_context(this, buffer_size)) {}
StreamReaderTensorBinding::StreamReaderTensorBinding(
const torch::Tensor& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option,
int buffer_size)
: TensorIndexer(src, buffer_size),
StreamReaderBinding(
get_input_format_context(get_id(src), device, option, pAVIO)) {}
namespace {
c10::intrusive_ptr<StreamReaderTensorBinding> init(
const torch::Tensor& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option,
int64_t buffer_size) {
return c10::make_intrusive<StreamReaderTensorBinding>(
src, device, option, static_cast<int>(buffer_size));
}
using S = const c10::intrusive_ptr<StreamReaderTensorBinding>&;
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.class_<StreamReaderTensorBinding>("ffmpeg_StreamReaderTensor")
.def(torch::init<>(init))
.def("num_src_streams", [](S self) { return self->num_src_streams(); })
.def("num_out_streams", [](S self) { return self->num_out_streams(); })
.def("get_metadata", [](S self) { return self->get_metadata(); })
.def(
"get_src_stream_info",
[](S s, int64_t i) { return s->get_src_stream_info(i); })
.def(
"get_out_stream_info",
[](S s, int64_t i) { return s->get_out_stream_info(i); })
.def(
"find_best_audio_stream",
[](S s) { return s->find_best_audio_stream(); })
.def(
"find_best_video_stream",
[](S s) { return s->find_best_video_stream(); })
.def("seek", [](S s, double t) { return s->seek(t); })
.def(
"add_audio_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_option) {
s->add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
decoder_option);
})
.def(
"add_video_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_option,
const c10::optional<std::string>& hw_accel) {
s->add_video_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
decoder_option,
hw_accel);
})
.def("remove_stream", [](S s, int64_t i) { s->remove_stream(i); })
.def(
"process_packet",
[](S s, const c10::optional<double>& timeout, const double backoff) {
return s->process_packet(timeout, backoff);
})
.def("process_all_packets", [](S s) { s->process_all_packets(); })
.def("is_buffer_ready", [](S s) { return s->is_buffer_ready(); })
.def("pop_chunks", [](S s) { return s->pop_chunks(); });
}
} // namespace
} // namespace ffmpeg
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
// Helper structure that keeps track of how far the decoding has progressed
// through the source Tensor.
struct TensorIndexer {
torch::Tensor src;
size_t index = 0;
const uint8_t* data;
const size_t numel;
AVIOContextPtr pAVIO;
TensorIndexer(const torch::Tensor& src, int buffer_size);
};
// Structure that implements a wrapper API around StreamReader, which is more
// suitable for binding the code (i.e. it receives/returns primitives).
struct StreamReaderTensorBinding : protected TensorIndexer,
public StreamReaderBinding {
StreamReaderTensorBinding(
const torch::Tensor& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option,
int buffer_size);
};
} // namespace ffmpeg
} // namespace torchaudio
@@ -283,7 +283,7 @@ class StreamReader:
     For the detailed usage of this class, please refer to the tutorial.

     Args:
-        src (str or file-like object): The media source.
+        src (str, file-like object or Tensor): The media source.
             If string-type, it must be a resource indicator that FFmpeg can
             handle. This includes a file path, URL, device identifier or
             filter expression. The supported value depends on the FFmpeg found
@@ -296,6 +296,9 @@ class StreamReader:
             of codec detection. The signature of the `seek` method must be
             `seek(offset: int, whence: int) -> int`.

+            If Tensor, it is interpreted as a byte buffer.
+            It must be one-dimensional, of type ``torch.uint8``.
+
             Please refer to the following for the expected signature and behavior
             of the `read` and `seek` methods.
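
As the new C++ binding in this commit also enforces, the Tensor source must be contiguous and reside on CPU. A minimal sketch of constructing a valid source (the path is a placeholder):

import torch
from torchaudio.io import StreamReader

with open("input.wav", "rb") as f:  # placeholder path
    src = torch.frombuffer(bytearray(f.read()), dtype=torch.uint8)  # 1-D, uint8, CPU, contiguous
reader = StreamReader(src)

Wrapping the bytes in a bytearray is optional; it only avoids the warning that torch.frombuffer may emit for read-only buffers.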
@@ -350,10 +353,12 @@ class StreamReader:
     ):
         if isinstance(src, str):
             self._be = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, option)
+        elif isinstance(src, torch.Tensor):
+            self._be = torch.classes.torchaudio.ffmpeg_StreamReaderTensor(src, format, option, buffer_size)
         elif hasattr(src, "read"):
             self._be = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, option, buffer_size)
         else:
-            raise ValueError("`src` must be either string or file-like object.")
+            raise ValueError("`src` must be either string, Tensor or file-like object.")

         i = self._be.find_best_audio_stream()
         self._default_audio_stream = None if i < 0 else i
...