Commit 49287210 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Cleaning up private methods (#3030)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3030

A part of StreamWriter refactoring

(Note: this ignores all push blocking failures!)

Reviewed By: hwangjeff

Differential Revision: D42905959

fbshipit-source-id: ba8add3ce549c70c3775640840e41ace06b0ef65
parent f663cb28
...@@ -666,36 +666,26 @@ void StreamWriter::validate_stream(int i, enum AVMediaType type) { ...@@ -666,36 +666,26 @@ void StreamWriter::validate_stream(int i, enum AVMediaType type) {
av_get_media_type_string(type)); av_get_media_type_string(type));
} }
void StreamWriter::process_frame( namespace {
AVFrame* src_frame,
std::unique_ptr<FilterGraph>& filter,
AVFrame* dst_frame,
AVCodecContextPtr& c,
AVStream* st) {
int ret = filter->add_frame(src_frame);
while (ret >= 0) {
ret = filter->get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
encode_frame(nullptr, c, st);
}
break;
}
if (ret >= 0) {
encode_frame(dst_frame, c, st);
}
av_frame_unref(dst_frame);
}
}
void StreamWriter::encode_frame( ///
/// Encode the given AVFrame data
///
/// @param frame Frame data to encode
/// @param format Output format context
/// @param stream Output stream in the output format context
/// @param codec Encoding context
/// @param packet Temporary packet used during encoding.
void encode_frame(
AVFrame* frame, AVFrame* frame,
AVCodecContextPtr& c, AVFormatContext* format,
AVStream* st) { AVStream* stream,
int ret = avcodec_send_frame(c, frame); AVCodecContext* codec,
AVPacket* packet) {
int ret = avcodec_send_frame(codec, frame);
TORCH_CHECK(ret >= 0, "Failed to encode frame (", av_err2string(ret), ")."); TORCH_CHECK(ret >= 0, "Failed to encode frame (", av_err2string(ret), ").");
while (ret >= 0) { while (ret >= 0) {
ret = avcodec_receive_packet(c, pkt); ret = avcodec_receive_packet(codec, packet);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) { if (ret == AVERROR_EOF) {
// Note: // Note:
...@@ -710,7 +700,7 @@ void StreamWriter::encode_frame( ...@@ -710,7 +700,7 @@ void StreamWriter::encode_frame(
// An alternative is to use `av_write_frame` function, but in that case // An alternative is to use `av_write_frame` function, but in that case
// client code is responsible for ordering packets, which makes it // client code is responsible for ordering packets, which makes it
// complicated to use StreamWriter // complicated to use StreamWriter
ret = av_interleaved_write_frame(pFormatContext, nullptr); ret = av_interleaved_write_frame(format, nullptr);
TORCH_CHECK( TORCH_CHECK(
ret >= 0, "Failed to flush packet (", av_err2string(ret), ")."); ret >= 0, "Failed to flush packet (", av_err2string(ret), ").");
} }
...@@ -725,20 +715,43 @@ void StreamWriter::encode_frame( ...@@ -725,20 +715,43 @@ void StreamWriter::encode_frame(
// https://github.com/pytorch/audio/issues/2790 // https://github.com/pytorch/audio/issues/2790
// If this is not set, the last frame is not properly saved, as // If this is not set, the last frame is not properly saved, as
// the encoder cannot figure out when the packet should finish. // the encoder cannot figure out when the packet should finish.
if (pkt->duration == 0 && c->codec_type == AVMEDIA_TYPE_VIDEO) { if (packet->duration == 0 && codec->codec_type == AVMEDIA_TYPE_VIDEO) {
// 1 means that 1 frame (in codec time base, which is the frame rate) // 1 means that 1 frame (in codec time base, which is the frame rate)
// This has to be set before av_packet_rescale_ts below. // This has to be set before av_packet_rescale_ts below.
pkt->duration = 1; packet->duration = 1;
} }
av_packet_rescale_ts(pkt, c->time_base, st->time_base); av_packet_rescale_ts(packet, codec->time_base, stream->time_base);
pkt->stream_index = st->index; packet->stream_index = stream->index;
ret = av_interleaved_write_frame(pFormatContext, pkt); ret = av_interleaved_write_frame(format, packet);
TORCH_CHECK(ret >= 0, "Failed to write packet (", av_err2string(ret), ")."); TORCH_CHECK(ret >= 0, "Failed to write packet (", av_err2string(ret), ").");
} }
} }
namespace { void process_frame(
AVFrame* src_frame,
std::unique_ptr<FilterGraph>& filter,
AVFrame* dst_frame,
AVFormatContext* format,
AVStream* stream,
AVCodecContextPtr& codec,
AVPacket* packet) {
int ret = filter->add_frame(src_frame);
while (ret >= 0) {
ret = filter->get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
encode_frame(nullptr, format, stream, codec, packet);
}
break;
}
if (ret >= 0) {
encode_frame(dst_frame, format, stream, codec, packet);
}
av_frame_unref(dst_frame);
}
}
void validate_audio_input( void validate_audio_input(
enum AVSampleFormat fmt, enum AVSampleFormat fmt,
AVCodecContext* ctx, AVCodecContext* ctx,
...@@ -812,6 +825,7 @@ void validate_video_input( ...@@ -812,6 +825,7 @@ void validate_video_input(
") (NCHW format). Found ", ") (NCHW format). Found ",
t.sizes()); t.sizes());
} }
} // namespace } // namespace
void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) { void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
...@@ -853,9 +867,16 @@ void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) { ...@@ -853,9 +867,16 @@ void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
if (os.filter) { if (os.filter) {
process_frame( process_frame(
os.src_frame, os.filter, os.dst_frame, os.codec_ctx, os.stream); os.src_frame,
os.filter,
os.dst_frame,
pFormatContext,
os.stream,
os.codec_ctx,
pkt);
} else { } else {
encode_frame(os.src_frame, os.codec_ctx, os.stream); encode_frame(
os.src_frame, pFormatContext, os.stream, os.codec_ctx, pkt);
} }
} }
}); });
...@@ -940,7 +961,7 @@ void StreamWriter::write_interlaced_video_cuda( ...@@ -940,7 +961,7 @@ void StreamWriter::write_interlaced_video_cuda(
} }
os.src_frame->pts = os.num_frames; os.src_frame->pts = os.num_frames;
os.num_frames += 1; os.num_frames += 1;
encode_frame(os.src_frame, os.codec_ctx, os.stream); encode_frame(os.src_frame, pFormatContext, os.stream, os.codec_ctx, pkt);
} }
} }
...@@ -971,7 +992,7 @@ void StreamWriter::write_planar_video_cuda( ...@@ -971,7 +992,7 @@ void StreamWriter::write_planar_video_cuda(
} }
os.src_frame->pts = os.num_frames; os.src_frame->pts = os.num_frames;
os.num_frames += 1; os.num_frames += 1;
encode_frame(os.src_frame, os.codec_ctx, os.stream); encode_frame(os.src_frame, pFormatContext, os.stream, os.codec_ctx, pkt);
} }
} }
#endif #endif
...@@ -1020,9 +1041,15 @@ void StreamWriter::write_interlaced_video( ...@@ -1020,9 +1041,15 @@ void StreamWriter::write_interlaced_video(
if (os.filter) { if (os.filter) {
process_frame( process_frame(
os.src_frame, os.filter, os.dst_frame, os.codec_ctx, os.stream); os.src_frame,
os.filter,
os.dst_frame,
pFormatContext,
os.stream,
os.codec_ctx,
pkt);
} else { } else {
encode_frame(os.src_frame, os.codec_ctx, os.stream); encode_frame(os.src_frame, pFormatContext, os.stream, os.codec_ctx, pkt);
} }
} }
} }
...@@ -1080,9 +1107,15 @@ void StreamWriter::write_planar_video( ...@@ -1080,9 +1107,15 @@ void StreamWriter::write_planar_video(
if (os.filter) { if (os.filter) {
process_frame( process_frame(
os.src_frame, os.filter, os.dst_frame, os.codec_ctx, os.stream); os.src_frame,
os.filter,
os.dst_frame,
pFormatContext,
os.stream,
os.codec_ctx,
pkt);
} else { } else {
encode_frame(os.src_frame, os.codec_ctx, os.stream); encode_frame(os.src_frame, pFormatContext, os.stream, os.codec_ctx, pkt);
} }
} }
} }
...@@ -1095,9 +1128,16 @@ void StreamWriter::flush() { ...@@ -1095,9 +1128,16 @@ void StreamWriter::flush() {
void StreamWriter::flush_stream(OutputStream& os) { void StreamWriter::flush_stream(OutputStream& os) {
if (os.filter) { if (os.filter) {
process_frame(nullptr, os.filter, os.dst_frame, os.codec_ctx, os.stream); process_frame(
nullptr,
os.filter,
os.dst_frame,
pFormatContext,
os.stream,
os.codec_ctx,
pkt);
} else { } else {
encode_frame(nullptr, os.codec_ctx, os.stream); encode_frame(nullptr, pFormatContext, os.stream, os.codec_ctx, pkt);
} }
} }
} // namespace io } // namespace io
......
...@@ -210,13 +210,6 @@ class StreamWriter { ...@@ -210,13 +210,6 @@ class StreamWriter {
const torch::Tensor& chunk, const torch::Tensor& chunk,
bool pad_extra = true); bool pad_extra = true);
#endif #endif
void process_frame(
AVFrame* src_frame,
std::unique_ptr<FilterGraph>& filter,
AVFrame* dst_frame,
AVCodecContextPtr& c,
AVStream* st);
void encode_frame(AVFrame* dst_frame, AVCodecContextPtr& c, AVStream* st);
void flush_stream(OutputStream& os); void flush_stream(OutputStream& os);
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment