Commit 51767917 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Refactor StreamWriterCustomIO (#3319)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3319

* Merge the source with StreamWriter
* Add docstrings
* Move CustomIO to detail::CustomOutput to prepare for adding CustomInput

Reviewed By: hwangjeff

Differential Revision: D45481807

fbshipit-source-id: 4a9ac8a57acda47b126f8ae18e607b72919f9988
parent 51cc1cbf
......@@ -21,7 +21,6 @@ set(
stream_writer/encoder.cpp
stream_writer/packet_writer.cpp
stream_writer/stream_writer.cpp
stream_writer/stream_writer_custom_io.cpp
stream_writer/tensor_converter.cpp
compat.cpp
)
......
......@@ -343,5 +343,46 @@ void StreamWriter::flush() {
int StreamWriter::num_output_streams() {
return static_cast<int>(processes.size() + packet_writers.size());
}
////////////////////////////////////////////////////////////////////////////////
// StreamWriterCustomIO
////////////////////////////////////////////////////////////////////////////////
namespace detail {
namespace {
AVIOContext* get_io_context(
void* opaque,
int buffer_size,
int (*write_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence)) {
unsigned char* buffer = static_cast<unsigned char*>(av_malloc(buffer_size));
TORCH_CHECK(buffer, "Failed to allocate buffer.");
AVIOContext* io_ctx = avio_alloc_context(
buffer, buffer_size, 1, opaque, nullptr, write_packet, seek);
if (!io_ctx) {
av_freep(&buffer);
TORCH_CHECK(false, "Failed to allocate AVIOContext.");
}
return io_ctx;
}
} // namespace
CustomOutput::CustomOutput(
void* opaque,
int buffer_size,
int (*write_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence))
: io_ctx(get_io_context(opaque, buffer_size, write_packet, seek)) {}
} // namespace detail
StreamWriterCustomIO::StreamWriterCustomIO(
void* opaque,
const c10::optional<std::string>& format,
int buffer_size,
int (*write_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence))
: CustomOutput(opaque, buffer_size, write_packet, seek),
StreamWriter(io_ctx, format) {}
} // namespace io
} // namespace torchaudio
......@@ -10,6 +10,10 @@
namespace torchaudio {
namespace io {
////////////////////////////////////////////////////////////////////////////////
// StreamWriter
////////////////////////////////////////////////////////////////////////////////
///
/// Encode and write audio/video streams chunk by chunk
///
......@@ -287,5 +291,43 @@ class StreamWriter {
int num_output_streams();
};
////////////////////////////////////////////////////////////////////////////////
// StreamWriterCustomIO
////////////////////////////////////////////////////////////////////////////////
/// @cond
namespace detail {
struct CustomOutput {
AVIOContextPtr io_ctx;
CustomOutput(
void* opaque,
int buffer_size,
int (*write_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence));
};
} // namespace detail
/// @endcond
/// Construct StreamWriter with custom write and seek functions.
///
/// @param opaque Custom data used by write_packet and seek functions.
/// @param format Specify output format.
/// @param buffer_size The size of the intermediate buffer, which FFmpeg uses to
/// pass data to write_packet function.
/// @param write_packet Custom write function that is called from FFmpeg to
/// actually write data to the custom destination.
/// @param seek Optional seek function that is used to seek the destination.
struct StreamWriterCustomIO : private detail::CustomOutput,
public StreamWriter {
StreamWriterCustomIO(
void* opaque,
const c10::optional<std::string>& format,
int buffer_size,
int (*write_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence) = nullptr);
};
} // namespace io
} // namespace torchaudio
#include "torchaudio/csrc/ffmpeg/stream_writer/stream_writer_custom_io.h"
namespace torchaudio::io {
AVIOContext* get_io_context(
void* opaque,
int buffer_size,
int (*write_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence)) {
unsigned char* buffer = static_cast<unsigned char*>(av_malloc(buffer_size));
TORCH_CHECK(buffer, "Failed to allocate buffer.");
AVIOContext* io_ctx = avio_alloc_context(
buffer, buffer_size, 1, opaque, nullptr, write_packet, seek);
if (!io_ctx) {
av_freep(&buffer);
TORCH_CHECK(false, "Failed to allocate AVIOContext.");
}
return io_ctx;
}
CustomIO::CustomIO(
void* opaque,
int buffer_size,
int (*write_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence))
: io_ctx(get_io_context(opaque, buffer_size, write_packet, seek)) {}
StreamWriterCustomIO::StreamWriterCustomIO(
void* opaque,
const c10::optional<std::string>& format,
int buffer_size,
int (*write_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence))
: CustomIO(opaque, buffer_size, write_packet, seek),
StreamWriter(io_ctx, format) {}
} // namespace torchaudio::io
#pragma once
#include "torchaudio/csrc/ffmpeg/ffmpeg.h"
#include "torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h"
namespace torchaudio::io {
struct CustomIO {
AVIOContextPtr io_ctx;
CustomIO(
void* opaque,
int buffer_size,
int (*write_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence));
};
struct StreamWriterCustomIO : private CustomIO, public StreamWriter {
StreamWriterCustomIO(
void* opaque,
const c10::optional<std::string>& format,
int buffer_size,
int (*write_packet)(void* opaque, uint8_t* buf, int buf_size),
int64_t (*seek)(void* opaque, int64_t offset, int whence));
};
} // namespace torchaudio::io
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment