Commit 49287210 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Cleaning up private methods (#3030)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3030

A part of StreamWriter refactoring

(Note: this ignores all push blocking failures!)

Reviewed By: hwangjeff

Differential Revision: D42905959

fbshipit-source-id: ba8add3ce549c70c3775640840e41ace06b0ef65
parent f663cb28
......@@ -666,36 +666,26 @@ void StreamWriter::validate_stream(int i, enum AVMediaType type) {
av_get_media_type_string(type));
}
void StreamWriter::process_frame(
AVFrame* src_frame,
std::unique_ptr<FilterGraph>& filter,
AVFrame* dst_frame,
AVCodecContextPtr& c,
AVStream* st) {
int ret = filter->add_frame(src_frame);
while (ret >= 0) {
ret = filter->get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
encode_frame(nullptr, c, st);
}
break;
}
if (ret >= 0) {
encode_frame(dst_frame, c, st);
}
av_frame_unref(dst_frame);
}
}
namespace {
void StreamWriter::encode_frame(
///
/// Encode the given AVFrame data
///
/// @param frame Frame data to encode
/// @param format Output format context
/// @param stream Output stream in the output format context
/// @param codec Encoding context
/// @param packet Temporaly packet used during encoding.
void encode_frame(
AVFrame* frame,
AVCodecContextPtr& c,
AVStream* st) {
int ret = avcodec_send_frame(c, frame);
AVFormatContext* format,
AVStream* stream,
AVCodecContext* codec,
AVPacket* packet) {
int ret = avcodec_send_frame(codec, frame);
TORCH_CHECK(ret >= 0, "Failed to encode frame (", av_err2string(ret), ").");
while (ret >= 0) {
ret = avcodec_receive_packet(c, pkt);
ret = avcodec_receive_packet(codec, packet);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
// Note:
......@@ -710,7 +700,7 @@ void StreamWriter::encode_frame(
// An alternative is to use `av_write_frame` functoin, but in that case
// client code is responsible for ordering packets, which makes it
// complicated to use StreamWriter
ret = av_interleaved_write_frame(pFormatContext, nullptr);
ret = av_interleaved_write_frame(format, nullptr);
TORCH_CHECK(
ret >= 0, "Failed to flush packet (", av_err2string(ret), ").");
}
......@@ -725,20 +715,43 @@ void StreamWriter::encode_frame(
// https://github.com/pytorch/audio/issues/2790
// If this is not set, the last frame is not properly saved, as
// the encoder cannot figure out when the packet should finish.
if (pkt->duration == 0 && c->codec_type == AVMEDIA_TYPE_VIDEO) {
if (packet->duration == 0 && codec->codec_type == AVMEDIA_TYPE_VIDEO) {
// 1 means that 1 frame (in codec time base, which is the frame rate)
// This has to be set before av_packet_rescale_ts bellow.
pkt->duration = 1;
packet->duration = 1;
}
av_packet_rescale_ts(pkt, c->time_base, st->time_base);
pkt->stream_index = st->index;
av_packet_rescale_ts(packet, codec->time_base, stream->time_base);
packet->stream_index = stream->index;
ret = av_interleaved_write_frame(pFormatContext, pkt);
ret = av_interleaved_write_frame(format, packet);
TORCH_CHECK(ret >= 0, "Failed to write packet (", av_err2string(ret), ").");
}
}
namespace {
void process_frame(
AVFrame* src_frame,
std::unique_ptr<FilterGraph>& filter,
AVFrame* dst_frame,
AVFormatContext* format,
AVStream* stream,
AVCodecContextPtr& codec,
AVPacket* packet) {
int ret = filter->add_frame(src_frame);
while (ret >= 0) {
ret = filter->get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
encode_frame(nullptr, format, stream, codec, packet);
}
break;
}
if (ret >= 0) {
encode_frame(dst_frame, format, stream, codec, packet);
}
av_frame_unref(dst_frame);
}
}
void validate_audio_input(
enum AVSampleFormat fmt,
AVCodecContext* ctx,
......@@ -812,6 +825,7 @@ void validate_video_input(
") (NCHW format). Found ",
t.sizes());
}
} // namespace
void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
......@@ -853,9 +867,16 @@ void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
if (os.filter) {
process_frame(
os.src_frame, os.filter, os.dst_frame, os.codec_ctx, os.stream);
os.src_frame,
os.filter,
os.dst_frame,
pFormatContext,
os.stream,
os.codec_ctx,
pkt);
} else {
encode_frame(os.src_frame, os.codec_ctx, os.stream);
encode_frame(
os.src_frame, pFormatContext, os.stream, os.codec_ctx, pkt);
}
}
});
......@@ -940,7 +961,7 @@ void StreamWriter::write_interlaced_video_cuda(
}
os.src_frame->pts = os.num_frames;
os.num_frames += 1;
encode_frame(os.src_frame, os.codec_ctx, os.stream);
encode_frame(os.src_frame, pFormatContext, os.stream, os.codec_ctx, pkt);
}
}
......@@ -971,7 +992,7 @@ void StreamWriter::write_planar_video_cuda(
}
os.src_frame->pts = os.num_frames;
os.num_frames += 1;
encode_frame(os.src_frame, os.codec_ctx, os.stream);
encode_frame(os.src_frame, pFormatContext, os.stream, os.codec_ctx, pkt);
}
}
#endif
......@@ -1020,9 +1041,15 @@ void StreamWriter::write_interlaced_video(
if (os.filter) {
process_frame(
os.src_frame, os.filter, os.dst_frame, os.codec_ctx, os.stream);
os.src_frame,
os.filter,
os.dst_frame,
pFormatContext,
os.stream,
os.codec_ctx,
pkt);
} else {
encode_frame(os.src_frame, os.codec_ctx, os.stream);
encode_frame(os.src_frame, pFormatContext, os.stream, os.codec_ctx, pkt);
}
}
}
......@@ -1080,9 +1107,15 @@ void StreamWriter::write_planar_video(
if (os.filter) {
process_frame(
os.src_frame, os.filter, os.dst_frame, os.codec_ctx, os.stream);
os.src_frame,
os.filter,
os.dst_frame,
pFormatContext,
os.stream,
os.codec_ctx,
pkt);
} else {
encode_frame(os.src_frame, os.codec_ctx, os.stream);
encode_frame(os.src_frame, pFormatContext, os.stream, os.codec_ctx, pkt);
}
}
}
......@@ -1095,9 +1128,16 @@ void StreamWriter::flush() {
void StreamWriter::flush_stream(OutputStream& os) {
if (os.filter) {
process_frame(nullptr, os.filter, os.dst_frame, os.codec_ctx, os.stream);
process_frame(
nullptr,
os.filter,
os.dst_frame,
pFormatContext,
os.stream,
os.codec_ctx,
pkt);
} else {
encode_frame(nullptr, os.codec_ctx, os.stream);
encode_frame(nullptr, pFormatContext, os.stream, os.codec_ctx, pkt);
}
}
} // namespace io
......
......@@ -210,13 +210,6 @@ class StreamWriter {
const torch::Tensor& chunk,
bool pad_extra = true);
#endif
void process_frame(
AVFrame* src_frame,
std::unique_ptr<FilterGraph>& filter,
AVFrame* dst_frame,
AVCodecContextPtr& c,
AVStream* st);
void encode_frame(AVFrame* dst_frame, AVCodecContextPtr& c, AVStream* st);
void flush_stream(OutputStream& os);
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment