Commit f4d94cab authored by Jeff Hwang's avatar Jeff Hwang Committed by Facebook GitHub Bot
Browse files

Add frame writing API to StreamWriter (#3244)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3244

Adds methods to `StreamWriter` that allow for passing in `AVFrame` instances rather than tensors.

Reviewed By: mthrok

Differential Revision: D44589256

fbshipit-source-id: f100e0d349708482b873a9a4bae1eaf5eb65301a
parent d69e8857
......@@ -726,7 +726,8 @@ EncodeProcess get_audio_encode_process(
const c10::optional<int>& encoder_sample_rate,
const c10::optional<int>& encoder_num_channels,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
const c10::optional<std::string>& filter_desc,
bool disable_converter) {
// 1. Check the source format, rate and channels
TORCH_CHECK(
src_sample_rate > 0,
......@@ -736,7 +737,13 @@ EncodeProcess get_audio_encode_process(
src_num_channels > 0,
"The number of channels must be positive. Found: ",
src_num_channels);
const AVSampleFormat src_fmt = get_src_sample_fmt(format);
// Note that disable_converter = true indicates that the caller is looking to
// directly supply frames and bypass tensor conversion. Therefore, in this
// case, restrictions on the format to support tensor inputs do not apply, and
// so we directly get the format via FFmpeg.
const AVSampleFormat src_fmt = (disable_converter)
? av_get_sample_fmt(format.c_str())
: get_src_sample_fmt(format);
const auto src_ch_layout =
static_cast<uint64_t>(av_get_default_channel_layout(src_num_channels));
......@@ -791,7 +798,9 @@ EncodeProcess get_audio_encode_process(
// 7. Instantiate Converter
TensorConverter converter{
AVMEDIA_TYPE_AUDIO, src_frame, src_frame->nb_samples};
(disable_converter) ? AVMEDIA_TYPE_UNKNOWN : AVMEDIA_TYPE_AUDIO,
src_frame,
src_frame->nb_samples};
// 8. encoder
// Note: get_stream modifies AVFormatContext and adds new stream.
......@@ -830,7 +839,8 @@ EncodeProcess get_video_encode_process(
const c10::optional<int>& encoder_height,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
const c10::optional<std::string>& filter_desc,
bool disable_converter) {
// 1. Checkc the source format, rate and resolution
TORCH_CHECK(
std::isfinite(frame_rate) && frame_rate > 0,
......@@ -838,7 +848,13 @@ EncodeProcess get_video_encode_process(
frame_rate);
TORCH_CHECK(src_width > 0, "width must be positive. Found: ", src_width);
TORCH_CHECK(src_height > 0, "height must be positive. Found: ", src_height);
const AVPixelFormat src_fmt = get_src_pix_fmt(format);
// Note that disable_converter = true indicates that the caller is looking to
// directly supply frames and bypass tensor conversion. Therefore, in this
// case, restrictions on the format to support tensor inputs do not apply, and
// so we directly get the format via FFmpeg.
const AVPixelFormat src_fmt = (disable_converter)
? av_get_pix_fmt(format.c_str())
: get_src_pix_fmt(format);
const AVRational src_rate = av_d2q(frame_rate, 1 << 24);
// 2. Fetch codec from default or override
......@@ -914,7 +930,9 @@ EncodeProcess get_video_encode_process(
}();
// 7. Converter
TensorConverter converter{AVMEDIA_TYPE_VIDEO, src_frame};
TensorConverter converter{
(disable_converter) ? AVMEDIA_TYPE_UNKNOWN : AVMEDIA_TYPE_VIDEO,
src_frame};
// 8. encoder
// Note: get_stream modifies AVFormatContext and adds new stream.
......
......@@ -28,10 +28,9 @@ class EncodeProcess {
void process(const torch::Tensor& tensor, const c10::optional<double>& pts);
void flush();
private:
void process_frame(AVFrame* src);
void flush();
};
EncodeProcess get_audio_encode_process(
......@@ -45,7 +44,8 @@ EncodeProcess get_audio_encode_process(
const c10::optional<int>& encoder_sample_rate,
const c10::optional<int>& encoder_num_channels,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc);
const c10::optional<std::string>& filter_desc,
bool disable_converter = false);
EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx,
......@@ -61,6 +61,7 @@ EncodeProcess get_video_encode_process(
const c10::optional<int>& encoder_height,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc);
const c10::optional<std::string>& filter_desc,
bool disable_converter = false);
}; // namespace torchaudio::io
......@@ -117,6 +117,72 @@ void StreamWriter::add_video_stream(
filter_desc));
}
void StreamWriter::add_audio_frame_stream(
int sample_rate,
int num_channels,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<int>& encoder_sample_rate,
const c10::optional<int>& encoder_num_channels,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_audio_encode_process(
pFormatContext,
sample_rate,
num_channels,
format,
encoder,
encoder_option,
encoder_format,
encoder_sample_rate,
encoder_num_channels,
codec_config,
filter_desc,
true));
}
void StreamWriter::add_video_frame_stream(
double frame_rate,
int width,
int height,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<double>& encoder_frame_rate,
const c10::optional<int>& encoder_width,
const c10::optional<int>& encoder_height,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_video_encode_process(
pFormatContext,
frame_rate,
width,
height,
format,
encoder,
encoder_option,
encoder_format,
encoder_frame_rate,
encoder_width,
encoder_height,
hw_accel,
codec_config,
filter_desc,
true));
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
av_dict_free(&pFormatContext->metadata);
for (auto const& [key, value] : metadata) {
......@@ -226,6 +292,17 @@ void StreamWriter::write_video_chunk(
processes[i].process(frames, pts);
}
void StreamWriter::write_frame(int i, AVFrame* frame) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK(
0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
"Invalid stream index. Index must be in range of [0, ",
pFormatContext->nb_streams,
"). Found: ",
i);
processes[i].process_frame(frame);
}
void StreamWriter::flush() {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
for (auto& p : processes) {
......
......@@ -160,6 +160,42 @@ class StreamWriter {
const c10::optional<std::string>& hw_accel = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
/// @cond
/// Add output audio frame stream.
/// Allows for writing frames rather than tensors via `write_frame`.
///
/// See `add_audio_stream` for more detail on input parameters.
void add_audio_frame_stream(
int sample_rate,
int num_channels,
const std::string& format,
const c10::optional<std::string>& encoder = c10::nullopt,
const c10::optional<OptionDict>& encoder_option = c10::nullopt,
const c10::optional<std::string>& encoder_format = c10::nullopt,
const c10::optional<int>& encoder_sample_rate = c10::nullopt,
const c10::optional<int>& encoder_num_channels = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
/// Add output video frame stream.
/// Allows for writing frames rather than tensors via `write_frame`.
///
/// See `add_video_stream` for more detail on input parameters.
void add_video_frame_stream(
double frame_rate,
int width,
int height,
const std::string& format,
const c10::optional<std::string>& encoder = c10::nullopt,
const c10::optional<OptionDict>& encoder_option = c10::nullopt,
const c10::optional<std::string>& encoder_format = c10::nullopt,
const c10::optional<double>& encoder_frame_rate = c10::nullopt,
const c10::optional<int>& encoder_width = c10::nullopt,
const c10::optional<int>& encoder_height = c10::nullopt,
const c10::optional<std::string>& hw_accel = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
/// @endcond
/// Set file-level metadata
/// @param metadata metadata.
void set_metadata(const OptionDict& metadata);
......@@ -215,6 +251,12 @@ class StreamWriter {
int i,
const torch::Tensor& frames,
const c10::optional<double>& pts = c10::nullopt);
/// @cond
/// Write frame to stream.
/// @param i Stream index.
/// @param frame Frame to write.
void write_frame(int i, AVFrame* frame);
/// @endcond
/// Flush the frames from encoders and write the frames to the destination.
void flush();
};
......
......@@ -343,6 +343,26 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
}
}
////////////////////////////////////////////////////////////////////////////////
// Unknown (for supporting frame writing)
////////////////////////////////////////////////////////////////////////////////
std::pair<InitFunc, ConvertFunc> get_frame_func() {
InitFunc init_func = [](const torch::Tensor& tensor,
AVFrame* buffer) -> torch::Tensor {
TORCH_CHECK(
false,
"This shouldn't have been called. "
"If you intended to write frames, please select a stream that supports doing so.");
};
ConvertFunc convert_func = [](const torch::Tensor& tensor, AVFrame* buffer) {
TORCH_CHECK(
false,
"This shouldn't have been called. "
"If you intended to write frames, please select a stream that supports doing so.");
};
return {init_func, convert_func};
}
} // namespace
////////////////////////////////////////////////////////////////////////////////
......@@ -358,6 +378,9 @@ TensorConverter::TensorConverter(AVMediaType type, AVFrame* buf, int buf_size)
case AVMEDIA_TYPE_VIDEO:
std::tie(init_func, convert_func) = get_video_func(buffer);
break;
case AVMEDIA_TYPE_UNKNOWN:
std::tie(init_func, convert_func) = get_frame_func();
break;
default:
TORCH_INTERNAL_ASSERT(
false, "Unsupported media type: ", av_get_media_type_string(type));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment