Commit f4d94cab authored by Jeff Hwang, committed by Facebook GitHub Bot
Browse files

Add frame writing API to StreamWriter (#3244)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3244

Adds methods to `StreamWriter` that allow for passing in `AVFrame` instances rather than tensors.

Reviewed By: mthrok

Differential Revision: D44589256

fbshipit-source-id: f100e0d349708482b873a9a4bae1eaf5eb65301a
parent d69e8857
...@@ -726,7 +726,8 @@ EncodeProcess get_audio_encode_process( ...@@ -726,7 +726,8 @@ EncodeProcess get_audio_encode_process(
const c10::optional<int>& encoder_sample_rate, const c10::optional<int>& encoder_sample_rate,
const c10::optional<int>& encoder_num_channels, const c10::optional<int>& encoder_num_channels,
const c10::optional<CodecConfig>& codec_config, const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) { const c10::optional<std::string>& filter_desc,
bool disable_converter) {
// 1. Check the source format, rate and channels // 1. Check the source format, rate and channels
TORCH_CHECK( TORCH_CHECK(
src_sample_rate > 0, src_sample_rate > 0,
...@@ -736,7 +737,13 @@ EncodeProcess get_audio_encode_process( ...@@ -736,7 +737,13 @@ EncodeProcess get_audio_encode_process(
src_num_channels > 0, src_num_channels > 0,
"The number of channels must be positive. Found: ", "The number of channels must be positive. Found: ",
src_num_channels); src_num_channels);
const AVSampleFormat src_fmt = get_src_sample_fmt(format); // Note that disable_converter = true indicates that the caller is looking to
// directly supply frames and bypass tensor conversion. Therefore, in this
// case, restrictions on the format to support tensor inputs do not apply, and
// so we directly get the format via FFmpeg.
const AVSampleFormat src_fmt = (disable_converter)
? av_get_sample_fmt(format.c_str())
: get_src_sample_fmt(format);
const auto src_ch_layout = const auto src_ch_layout =
static_cast<uint64_t>(av_get_default_channel_layout(src_num_channels)); static_cast<uint64_t>(av_get_default_channel_layout(src_num_channels));
...@@ -791,7 +798,9 @@ EncodeProcess get_audio_encode_process( ...@@ -791,7 +798,9 @@ EncodeProcess get_audio_encode_process(
// 7. Instantiate Converter // 7. Instantiate Converter
TensorConverter converter{ TensorConverter converter{
AVMEDIA_TYPE_AUDIO, src_frame, src_frame->nb_samples}; (disable_converter) ? AVMEDIA_TYPE_UNKNOWN : AVMEDIA_TYPE_AUDIO,
src_frame,
src_frame->nb_samples};
// 8. encoder // 8. encoder
// Note: get_stream modifies AVFormatContext and adds new stream. // Note: get_stream modifies AVFormatContext and adds new stream.
...@@ -830,7 +839,8 @@ EncodeProcess get_video_encode_process( ...@@ -830,7 +839,8 @@ EncodeProcess get_video_encode_process(
const c10::optional<int>& encoder_height, const c10::optional<int>& encoder_height,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config, const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) { const c10::optional<std::string>& filter_desc,
bool disable_converter) {
// 1. Check the source format, rate and resolution // 1. Check the source format, rate and resolution
TORCH_CHECK( TORCH_CHECK(
std::isfinite(frame_rate) && frame_rate > 0, std::isfinite(frame_rate) && frame_rate > 0,
...@@ -838,7 +848,13 @@ EncodeProcess get_video_encode_process( ...@@ -838,7 +848,13 @@ EncodeProcess get_video_encode_process(
frame_rate); frame_rate);
TORCH_CHECK(src_width > 0, "width must be positive. Found: ", src_width); TORCH_CHECK(src_width > 0, "width must be positive. Found: ", src_width);
TORCH_CHECK(src_height > 0, "height must be positive. Found: ", src_height); TORCH_CHECK(src_height > 0, "height must be positive. Found: ", src_height);
const AVPixelFormat src_fmt = get_src_pix_fmt(format); // Note that disable_converter = true indicates that the caller is looking to
// directly supply frames and bypass tensor conversion. Therefore, in this
// case, restrictions on the format to support tensor inputs do not apply, and
// so we directly get the format via FFmpeg.
const AVPixelFormat src_fmt = (disable_converter)
? av_get_pix_fmt(format.c_str())
: get_src_pix_fmt(format);
const AVRational src_rate = av_d2q(frame_rate, 1 << 24); const AVRational src_rate = av_d2q(frame_rate, 1 << 24);
// 2. Fetch codec from default or override // 2. Fetch codec from default or override
...@@ -914,7 +930,9 @@ EncodeProcess get_video_encode_process( ...@@ -914,7 +930,9 @@ EncodeProcess get_video_encode_process(
}(); }();
// 7. Converter // 7. Converter
TensorConverter converter{AVMEDIA_TYPE_VIDEO, src_frame}; TensorConverter converter{
(disable_converter) ? AVMEDIA_TYPE_UNKNOWN : AVMEDIA_TYPE_VIDEO,
src_frame};
// 8. encoder // 8. encoder
// Note: get_stream modifies AVFormatContext and adds new stream. // Note: get_stream modifies AVFormatContext and adds new stream.
......
...@@ -28,10 +28,9 @@ class EncodeProcess { ...@@ -28,10 +28,9 @@ class EncodeProcess {
void process(const torch::Tensor& tensor, const c10::optional<double>& pts); void process(const torch::Tensor& tensor, const c10::optional<double>& pts);
void flush();
private:
void process_frame(AVFrame* src); void process_frame(AVFrame* src);
void flush();
}; };
EncodeProcess get_audio_encode_process( EncodeProcess get_audio_encode_process(
...@@ -45,7 +44,8 @@ EncodeProcess get_audio_encode_process( ...@@ -45,7 +44,8 @@ EncodeProcess get_audio_encode_process(
const c10::optional<int>& encoder_sample_rate, const c10::optional<int>& encoder_sample_rate,
const c10::optional<int>& encoder_num_channels, const c10::optional<int>& encoder_num_channels,
const c10::optional<CodecConfig>& codec_config, const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc); const c10::optional<std::string>& filter_desc,
bool disable_converter = false);
EncodeProcess get_video_encode_process( EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
...@@ -61,6 +61,7 @@ EncodeProcess get_video_encode_process( ...@@ -61,6 +61,7 @@ EncodeProcess get_video_encode_process(
const c10::optional<int>& encoder_height, const c10::optional<int>& encoder_height,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config, const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc); const c10::optional<std::string>& filter_desc,
bool disable_converter = false);
}; // namespace torchaudio::io }; // namespace torchaudio::io
...@@ -117,6 +117,72 @@ void StreamWriter::add_video_stream( ...@@ -117,6 +117,72 @@ void StreamWriter::add_video_stream(
filter_desc)); filter_desc));
} }
// Register an audio output stream that is fed through `write_frame` rather
// than with tensors. The arguments are forwarded unchanged to
// `get_audio_encode_process`, with the tensor converter switched off.
// Streams can only be added while the output has not been opened yet.
void StreamWriter::add_audio_frame_stream(
    int sample_rate,
    int num_channels,
    const std::string& format,
    const c10::optional<std::string>& encoder,
    const c10::optional<OptionDict>& encoder_option,
    const c10::optional<std::string>& encoder_format,
    const c10::optional<int>& encoder_sample_rate,
    const c10::optional<int>& encoder_num_channels,
    const c10::optional<CodecConfig>& codec_config,
    const c10::optional<std::string>& filter_desc) {
  TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
  // Every output stream is expected to have exactly one encode process.
  TORCH_INTERNAL_ASSERT(
      pFormatContext->nb_streams == processes.size(),
      "The number of encode process and the number of output streams do not match.");

  // The caller supplies frames directly, so tensor conversion is bypassed.
  constexpr bool disable_converter = true;
  processes.emplace_back(get_audio_encode_process(
      pFormatContext,
      sample_rate,
      num_channels,
      format,
      encoder,
      encoder_option,
      encoder_format,
      encoder_sample_rate,
      encoder_num_channels,
      codec_config,
      filter_desc,
      disable_converter));
}
// Register a video output stream that is fed through `write_frame` rather
// than with tensors. The arguments are forwarded unchanged to
// `get_video_encode_process`, with the tensor converter switched off.
// Streams can only be added while the output has not been opened yet.
void StreamWriter::add_video_frame_stream(
    double frame_rate,
    int width,
    int height,
    const std::string& format,
    const c10::optional<std::string>& encoder,
    const c10::optional<OptionDict>& encoder_option,
    const c10::optional<std::string>& encoder_format,
    const c10::optional<double>& encoder_frame_rate,
    const c10::optional<int>& encoder_width,
    const c10::optional<int>& encoder_height,
    const c10::optional<std::string>& hw_accel,
    const c10::optional<CodecConfig>& codec_config,
    const c10::optional<std::string>& filter_desc) {
  TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
  // Every output stream is expected to have exactly one encode process.
  TORCH_INTERNAL_ASSERT(
      pFormatContext->nb_streams == processes.size(),
      "The number of encode process and the number of output streams do not match.");

  // The caller supplies frames directly, so tensor conversion is bypassed.
  constexpr bool disable_converter = true;
  processes.emplace_back(get_video_encode_process(
      pFormatContext,
      frame_rate,
      width,
      height,
      format,
      encoder,
      encoder_option,
      encoder_format,
      encoder_frame_rate,
      encoder_width,
      encoder_height,
      hw_accel,
      codec_config,
      filter_desc,
      disable_converter));
}
void StreamWriter::set_metadata(const OptionDict& metadata) { void StreamWriter::set_metadata(const OptionDict& metadata) {
av_dict_free(&pFormatContext->metadata); av_dict_free(&pFormatContext->metadata);
for (auto const& [key, value] : metadata) { for (auto const& [key, value] : metadata) {
...@@ -226,6 +292,17 @@ void StreamWriter::write_video_chunk( ...@@ -226,6 +292,17 @@ void StreamWriter::write_video_chunk(
processes[i].process(frames, pts); processes[i].process(frames, pts);
} }
// Hand a caller-provided frame to the encode process of output stream `i`.
// Raises if the output has not been opened or if `i` is not a valid stream
// index.
void StreamWriter::write_frame(int i, AVFrame* frame) {
  TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");

  // Valid indices are [0, nb_streams).
  const bool index_in_range =
      i >= 0 && i < static_cast<int>(pFormatContext->nb_streams);
  TORCH_CHECK(
      index_in_range,
      "Invalid stream index. Index must be in range of [0, ",
      pFormatContext->nb_streams,
      "). Found: ",
      i);

  auto& target = processes[i];
  target.process_frame(frame);
}
void StreamWriter::flush() { void StreamWriter::flush() {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?"); TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
for (auto& p : processes) { for (auto& p : processes) {
......
...@@ -160,6 +160,42 @@ class StreamWriter { ...@@ -160,6 +160,42 @@ class StreamWriter {
const c10::optional<std::string>& hw_accel = c10::nullopt, const c10::optional<std::string>& hw_accel = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt, const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt); const c10::optional<std::string>& filter_desc = c10::nullopt);
/// @cond
/// Add output audio frame stream.
/// Allows for writing frames rather than tensors via `write_frame`.
///
/// The arguments are forwarded to `get_audio_encode_process` with tensor
/// conversion disabled. In that mode `format` is resolved with FFmpeg's
/// `av_get_sample_fmt`, so the format restrictions that exist to support
/// tensor inputs do not apply.
///
/// The stream must be added before the output is opened; otherwise an error
/// is raised.
///
/// See `add_audio_stream` for more detail on input parameters.
void add_audio_frame_stream(
int sample_rate,
int num_channels,
const std::string& format,
const c10::optional<std::string>& encoder = c10::nullopt,
const c10::optional<OptionDict>& encoder_option = c10::nullopt,
const c10::optional<std::string>& encoder_format = c10::nullopt,
const c10::optional<int>& encoder_sample_rate = c10::nullopt,
const c10::optional<int>& encoder_num_channels = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
/// Add output video frame stream.
/// Allows for writing frames rather than tensors via `write_frame`.
///
/// The arguments are forwarded to `get_video_encode_process` with tensor
/// conversion disabled. In that mode `format` is resolved with FFmpeg's
/// `av_get_pix_fmt`, so the format restrictions that exist to support
/// tensor inputs do not apply.
///
/// The stream must be added before the output is opened; otherwise an error
/// is raised.
///
/// See `add_video_stream` for more detail on input parameters.
void add_video_frame_stream(
double frame_rate,
int width,
int height,
const std::string& format,
const c10::optional<std::string>& encoder = c10::nullopt,
const c10::optional<OptionDict>& encoder_option = c10::nullopt,
const c10::optional<std::string>& encoder_format = c10::nullopt,
const c10::optional<double>& encoder_frame_rate = c10::nullopt,
const c10::optional<int>& encoder_width = c10::nullopt,
const c10::optional<int>& encoder_height = c10::nullopt,
const c10::optional<std::string>& hw_accel = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
/// @endcond
/// Set file-level metadata /// Set file-level metadata
/// @param metadata metadata. /// @param metadata metadata.
void set_metadata(const OptionDict& metadata); void set_metadata(const OptionDict& metadata);
...@@ -215,6 +251,12 @@ class StreamWriter { ...@@ -215,6 +251,12 @@ class StreamWriter {
int i, int i,
const torch::Tensor& frames, const torch::Tensor& frames,
const c10::optional<double>& pts = c10::nullopt); const c10::optional<double>& pts = c10::nullopt);
/// @cond
/// Write frame to stream.
/// The output must have been opened beforehand; otherwise an error is raised.
/// @param i Stream index. Must be in the range [0, number of output streams);
/// an out-of-range index raises an error.
/// @param frame Frame to write. It is handed to the encode process of stream
/// `i`.
void write_frame(int i, AVFrame* frame);
/// @endcond
/// Flush the frames from encoders and write the frames to the destination. /// Flush the frames from encoders and write the frames to the destination.
void flush(); void flush();
}; };
......
...@@ -343,6 +343,26 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) { ...@@ -343,6 +343,26 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
} }
} }
////////////////////////////////////////////////////////////////////////////////
// Unknown (for supporting frame writing)
////////////////////////////////////////////////////////////////////////////////
std::pair<InitFunc, ConvertFunc> get_frame_func() {
InitFunc init_func = [](const torch::Tensor& tensor,
AVFrame* buffer) -> torch::Tensor {
TORCH_CHECK(
false,
"This shouldn't have been called. "
"If you intended to write frames, please select a stream that supports doing so.");
};
ConvertFunc convert_func = [](const torch::Tensor& tensor, AVFrame* buffer) {
TORCH_CHECK(
false,
"This shouldn't have been called. "
"If you intended to write frames, please select a stream that supports doing so.");
};
return {init_func, convert_func};
}
} // namespace } // namespace
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
...@@ -358,6 +378,9 @@ TensorConverter::TensorConverter(AVMediaType type, AVFrame* buf, int buf_size) ...@@ -358,6 +378,9 @@ TensorConverter::TensorConverter(AVMediaType type, AVFrame* buf, int buf_size)
case AVMEDIA_TYPE_VIDEO: case AVMEDIA_TYPE_VIDEO:
std::tie(init_func, convert_func) = get_video_func(buffer); std::tie(init_func, convert_func) = get_video_func(buffer);
break; break;
case AVMEDIA_TYPE_UNKNOWN:
std::tie(init_func, convert_func) = get_frame_func();
break;
default: default:
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
false, "Unsupported media type: ", av_get_media_type_string(type)); false, "Unsupported media type: ", av_get_media_type_string(type));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment