Commit d3c9295c authored by mthrok, committed by Facebook GitHub Bot

Remove Tensor binding from StreamReader (#3093)

Summary:
Remove the Tensor input support from StreamReader

Follow up of https://github.com/pytorch/audio/pull/3086

Pull Request resolved: https://github.com/pytorch/audio/pull/3093

Reviewed By: xiaohui-zhang

Differential Revision: D43526066

Pulled By: mthrok

fbshipit-source-id: 57ba4866c413649173e1c2c3b23ba7de3231b7bc
parent a26c2f27
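Migration note: code that previously handed a one-dimensional `uint8` Tensor to `StreamReader` can pass the same bytes through a file-like object instead. The sketch below is illustrative only (the file path and chunk size are placeholders, not part of this commit):

```python
import io

from torchaudio.io import StreamReader

# Load the media bytes into memory ("test.wav" is a placeholder path).
with open("test.wav", "rb") as f:
    data = f.read()

# Before: StreamReader(torch.frombuffer(data, dtype=torch.uint8))
# After:  wrap the in-memory bytes in a file-like object.
reader = StreamReader(io.BytesIO(data))
reader.add_audio_stream(frames_per_chunk=4096)
reader.process_all_packets()
(chunk,) = reader.pop_chunks()  # first decoded audio chunk
```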
@@ -65,7 +65,7 @@ class ChunkTensorTest(TorchaudioTestCase):
# Helper decorator and Mixin to duplicate the tests for fileobj
_media_source = parameterized_class(
("test_type",),
[("str",), ("fileobj",), ("tensor",)],
[("str",), ("fileobj",)],
class_name_func=lambda cls, _, params: f'{cls.__name__}_{params["test_type"]}',
)
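For reference, `parameterized_class` duplicates the decorated test class once per source type listed above. A minimal, self-contained sketch of the mechanism (class and test names here are made up, not from the test suite):

```python
import unittest

from parameterized import parameterized_class


@parameterized_class(
    ("test_type",),
    [("str",), ("fileobj",)],
    class_name_func=lambda cls, _, params: f'{cls.__name__}_{params["test_type"]}',
)
class ExampleSourceTest(unittest.TestCase):
    def test_type_is_set(self):
        # Two classes are generated, ExampleSourceTest_str and
        # ExampleSourceTest_fileobj, each with its own `test_type` attribute.
        self.assertIn(self.test_type, ("str", "fileobj"))
```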
@@ -83,10 +83,6 @@ class _MediaSourceMixin:
self.src = path
elif self.test_type == "fileobj":
self.src = open(path, "rb")
elif self.test_type == "tensor":
with open(path, "rb") as fileobj:
data = fileobj.read()
self.src = torch.frombuffer(data, dtype=torch.uint8)
return self.src
def tearDown(self):
......
@@ -18,7 +18,6 @@ set(
stream_reader/stream_reader.cpp
stream_reader/stream_reader_wrapper.cpp
stream_reader/stream_reader_binding.cpp
stream_reader/stream_reader_tensor_binding.cpp
stream_writer/stream_writer.cpp
stream_writer/stream_writer_binding.cpp
utils.cpp
......
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader_wrapper.h>
namespace torchaudio {
namespace io {
namespace {
//////////////////////////////////////////////////////////////////////////////
// TensorIndexer
//////////////////////////////////////////////////////////////////////////////
// Helper structure to keep track of how far into the Tensor the decoder has read
struct TensorIndexer {
torch::Tensor src;
size_t index = 0;
const uint8_t* data;
const size_t numel;
AVIOContextPtr pAVIO;
TensorIndexer(const torch::Tensor& src, int buffer_size);
};
static int read_function(void* opaque, uint8_t* buf, int buf_size) {
TensorIndexer* tensorobj = static_cast<TensorIndexer*>(opaque);
int num_read = FFMIN(tensorobj->numel - tensorobj->index, buf_size);
if (num_read == 0) {
return AVERROR_EOF;
}
uint8_t* head = const_cast<uint8_t*>(tensorobj->data) + tensorobj->index;
memcpy(buf, head, num_read);
tensorobj->index += num_read;
return num_read;
}
static int64_t seek_function(void* opaque, int64_t offset, int whence) {
TensorIndexer* tensorobj = static_cast<TensorIndexer*>(opaque);
if (whence == AVSEEK_SIZE) {
return static_cast<int64_t>(tensorobj->numel);
}
if (whence == SEEK_SET) {
tensorobj->index = offset;
} else if (whence == SEEK_CUR) {
tensorobj->index += offset;
} else if (whence == SEEK_END) {
tensorobj->index = tensorobj->numel + offset;
} else {
TORCH_CHECK(false, "[INTERNAL ERROR] Unexpected whence value: ", whence);
}
return static_cast<int64_t>(tensorobj->index);
}
AVIOContext* get_io_context(TensorIndexer* opaque, int buffer_size) {
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
TORCH_CHECK(buffer, "Failed to allocate buffer.");
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
0,
static_cast<void*>(opaque),
&read_function,
nullptr,
&seek_function);
if (!av_io_ctx) {
av_freep(&buffer);
TORCH_CHECK(av_io_ctx, "Failed to initialize AVIOContext.");
}
return av_io_ctx;
}
TensorIndexer::TensorIndexer(const torch::Tensor& src, int buffer_size)
: src(src),
data([&]() -> uint8_t* {
TORCH_CHECK(
src.is_contiguous(), "The input Tensor must be contiguous.");
TORCH_CHECK(
src.dtype() == torch::kUInt8,
"The input Tensor must be uint8 type. Found: ",
src.dtype());
TORCH_CHECK(
src.device().type() == c10::DeviceType::CPU,
"The input Tensor must be on CPU. Found: ",
src.device().str());
TORCH_CHECK(
src.dim() == 1, "The input Tensor must be 1D. Found: ", src.dim());
return src.data_ptr<uint8_t>();
}()),
numel(src.numel()),
pAVIO(get_io_context(this, buffer_size)) {}
//////////////////////////////////////////////////////////////////////////////
// StreamReaderTensorBinding
//////////////////////////////////////////////////////////////////////////////
// Wrapper structure that implements the StreamReader binding API on top of a Tensor input
struct StreamReaderTensorBinding : protected TensorIndexer,
public StreamReaderBinding {
StreamReaderTensorBinding(
const torch::Tensor& src,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
int buffer_size);
};
StreamReaderTensorBinding::StreamReaderTensorBinding(
const torch::Tensor& src,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
int buffer_size)
: TensorIndexer(src, buffer_size),
StreamReaderBinding(pAVIO, format, option) {}
using S = const c10::intrusive_ptr<StreamReaderTensorBinding>&;
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.class_<StreamReaderTensorBinding>("ffmpeg_StreamReaderTensor")
.def(torch::init<>([](const torch::Tensor& src,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
int64_t buffer_size) {
TORCH_WARN_ONCE(
"Using Tensor as byte string buffer is deprecated, and will be removed in 2.1, please pass the data using I/O object.")
return c10::make_intrusive<StreamReaderTensorBinding>(
src, format, option, static_cast<int>(buffer_size));
}))
.def("num_src_streams", [](S self) { return self->num_src_streams(); })
.def("num_out_streams", [](S self) { return self->num_out_streams(); })
.def("get_metadata", [](S self) { return self->get_metadata(); })
.def(
"get_src_stream_info",
[](S s, int64_t i) { return s->get_src_stream_info(i); })
.def(
"get_out_stream_info",
[](S s, int64_t i) { return s->get_out_stream_info(i); })
.def(
"find_best_audio_stream",
[](S s) { return s->find_best_audio_stream(); })
.def(
"find_best_video_stream",
[](S s) { return s->find_best_video_stream(); })
.def("seek", [](S s, double t, int64_t mode) { return s->seek(t, mode); })
.def(
"add_audio_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_option) {
s->add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
decoder_option);
})
.def(
"add_video_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_option,
const c10::optional<std::string>& hw_accel) {
s->add_video_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
decoder_option,
hw_accel);
})
.def("remove_stream", [](S s, int64_t i) { s->remove_stream(i); })
.def(
"process_packet",
[](S s, const c10::optional<double>& timeout, const double backoff) {
return s->process_packet(timeout, backoff);
})
.def("process_all_packets", [](S s) { s->process_all_packets(); })
.def(
"fill_buffer",
[](S s, const c10::optional<double>& timeout, const double backoff) {
return s->fill_buffer(timeout, backoff);
})
.def("is_buffer_ready", [](S s) { return s->is_buffer_ready(); })
.def("pop_chunks", [](S s) { return s->pop_chunks(); });
}
} // namespace
} // namespace io
} // namespace torchaudio
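The deleted `read_function` / `seek_function` callbacks above are what exposed the Tensor's memory to FFmpeg through an `AVIOContext`. With the binding removed, the same behaviour comes from any file-like object whose `read` and `seek` follow the usual semantics. The Python sketch below mirrors the callback logic purely for illustration; `BufferReader` is a made-up name and nothing like it ships with torchaudio (in practice `io.BytesIO` already provides this):

```python
import os


class BufferReader:
    """Minimal file-like wrapper over in-memory bytes, mirroring the
    semantics of the removed read_function / seek_function callbacks."""

    def __init__(self, data: bytes):
        self.data = data
        self.index = 0

    def read(self, n: int) -> bytes:
        # Like read_function: return at most `n` bytes from the current
        # position; an empty result signals EOF.
        chunk = self.data[self.index : self.index + n]
        self.index += len(chunk)
        return chunk

    def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
        # Like seek_function: SEEK_SET / SEEK_CUR / SEEK_END move the index,
        # and the new absolute position is returned. (AVSEEK_SIZE, which
        # FFmpeg uses to query the total size, has no direct counterpart in
        # the file-object protocol and is omitted here.)
        if whence == os.SEEK_SET:
            self.index = offset
        elif whence == os.SEEK_CUR:
            self.index += offset
        elif whence == os.SEEK_END:
            self.index = len(self.data) + offset
        else:
            raise ValueError(f"Unexpected whence value: {whence}")
        return self.index
```

An object like this satisfies the `read`/`seek` contract documented in the `StreamReader` docstring below.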
@@ -388,7 +388,7 @@ class StreamReader:
For the detailed usage of this class, please refer to the tutorial.
Args:
src (str, file-like object or Tensor): The media source.
src (str, file-like object): The media source.
If string-type, it must be a resource indicator that FFmpeg can
handle. This includes a file path, URL, device identifier or
filter expression. The supported value depends on the FFmpeg found
@@ -401,9 +401,6 @@ class StreamReader:
of codec detection. The signature of the `seek` method must be
`seek(offset: int, whence: int) -> int`.
If Tensor, it is interpreted as a byte buffer.
It must be one-dimensional, of type ``torch.uint8``.
Please refer to the following for the expected signature and behavior
of the `read` and `seek` methods.
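A hedged usage sketch of the two source types the docstring still covers (the file name is a placeholder):

```python
from torchaudio.io import StreamReader

# String source: anything FFmpeg can resolve (file path, URL, device, ...).
reader_from_path = StreamReader("example.mp4")

# File-like source: any object with a `read` method; a `seek` method
# improves format detection, as described above.
reader_from_fileobj = StreamReader(open("example.mp4", "rb"))
```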
@@ -465,8 +462,6 @@
torch._C._log_api_usage_once("torchaudio.io.StreamReader")
if isinstance(src, str):
self._be = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, option)
elif isinstance(src, torch.Tensor):
self._be = torch.classes.torchaudio.ffmpeg_StreamReaderTensor(src, format, option, buffer_size)
elif hasattr(src, "read"):
self._be = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, option, buffer_size)
else:
......
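For context on the branch removed just above: the `ffmpeg_StreamReaderTensor` class registered by the deleted `TORCH_LIBRARY_FRAGMENT` block was reached through `torch.classes`, roughly as sketched below. This is illustrative only; the path and buffer size are placeholders, it assumes torchaudio's FFmpeg extension has been loaded so the class is registered, and it no longer works after this commit:

```python
import torch
import torchaudio.io  # assumed to load the FFmpeg extension that registers the class

with open("test.wav", "rb") as f:
    data = torch.frombuffer(f.read(), dtype=torch.uint8)

# The removed TorchBind class took (src, format, option, buffer_size).
be = torch.classes.torchaudio.ffmpeg_StreamReaderTensor(data, None, None, 4096)
print(be.num_src_streams())
```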