Commit fce6180c authored by Moto Hira, committed by Facebook GitHub Bot
Browse files

Tweak OutputStream implementation (#3122)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3122

- Remove manual tracking of num_frames
- Remove unnecessary dispatch in AudioOutputStream

Reviewed By: nateanl

Differential Revision: D43685746

fbshipit-source-id: a7e62a81549fb62ad0caa3b741655eba3bc5e250
parent 0bf00d20
...@@ -34,6 +34,7 @@ AVFramePtr get_audio_frame( ...@@ -34,6 +34,7 @@ AVFramePtr get_audio_frame(
AVCodecContext* codec_ctx, AVCodecContext* codec_ctx,
int default_frame_size = 10000) { int default_frame_size = 10000) {
AVFramePtr frame{}; AVFramePtr frame{};
frame->pts = 0;
frame->format = src_fmt; frame->format = src_fmt;
frame->channel_layout = codec_ctx->channel_layout; frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = codec_ctx->sample_rate; frame->sample_rate = codec_ctx->sample_rate;
...@@ -121,31 +122,23 @@ void AudioOutputStream::write_chunk(const torch::Tensor& waveform) { ...@@ -121,31 +122,23 @@ void AudioOutputStream::write_chunk(const torch::Tensor& waveform) {
AVRational time_base{1, codec_ctx->sample_rate}; AVRational time_base{1, codec_ctx->sample_rate};
using namespace torch::indexing; using namespace torch::indexing;
AT_DISPATCH_ALL_TYPES(waveform.scalar_type(), "write_audio_frames", [&] { for (int64_t i = 0; i < waveform.size(0); i += frame_capacity) {
for (int64_t i = 0; i < waveform.size(0); i += frame_capacity) { auto chunk = waveform.index({Slice(i, i + frame_capacity), Slice()});
auto chunk = waveform.index({Slice(i, i + frame_capacity), Slice()}); auto num_frames = chunk.size(0);
auto num_valid_frames = chunk.size(0); auto byte_size = chunk.numel() * chunk.element_size();
auto byte_size = chunk.numel() * chunk.element_size(); chunk = chunk.reshape({-1}).contiguous();
chunk = chunk.reshape({-1}).contiguous();
// TODO: make writable
// TODO: make writable // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334 TORCH_CHECK(
TORCH_CHECK( av_frame_is_writable(src_frame),
av_frame_is_writable(src_frame), "Internal Error: frame is not writable.");
"Internal Error: frame is not writable.");
memcpy(src_frame->data[0], chunk.data_ptr(), byte_size);
memcpy( src_frame->pts += av_rescale_q(num_frames, time_base, codec_ctx->time_base);
src_frame->data[0], src_frame->nb_samples = num_frames;
static_cast<void*>(chunk.data_ptr<scalar_t>()), process_frame(src_frame);
byte_size); }
src_frame->pts =
av_rescale_q(num_frames, time_base, codec_ctx->time_base);
src_frame->nb_samples = num_valid_frames;
num_frames += num_valid_frames;
process_frame(src_frame);
}
});
} }
} // namespace torchaudio::io } // namespace torchaudio::io
...@@ -9,8 +9,7 @@ OutputStream::OutputStream( ...@@ -9,8 +9,7 @@ OutputStream::OutputStream(
: codec_ctx(codec_ctx_), : codec_ctx(codec_ctx_),
encoder(format_ctx, codec_ctx), encoder(format_ctx, codec_ctx),
filter(std::move(filter_)), filter(std::move(filter_)),
dst_frame(), dst_frame() {}
num_frames(0) {}
void OutputStream::process_frame(AVFrame* src) { void OutputStream::process_frame(AVFrame* src) {
int ret = filter.add_frame(src); int ret = filter.add_frame(src);
......
...@@ -16,8 +16,6 @@ struct OutputStream { ...@@ -16,8 +16,6 @@ struct OutputStream {
FilterGraph filter; FilterGraph filter;
// frame that output from FilterGraph is written // frame that output from FilterGraph is written
AVFramePtr dst_frame; AVFramePtr dst_frame;
// The number of samples written so far
int64_t num_frames;
OutputStream( OutputStream(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
......
...@@ -219,6 +219,7 @@ AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) { ...@@ -219,6 +219,7 @@ AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
av_err2string(ret), av_err2string(ret),
")."); ").");
} }
frame->pts = 0;
return frame; return frame;
} }
......
...@@ -52,8 +52,7 @@ VideoOutputStream::VideoOutputStream( ...@@ -52,8 +52,7 @@ VideoOutputStream::VideoOutputStream(
void VideoOutputStream::write_chunk(const torch::Tensor& frames) { void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
for (const auto& frame : converter.convert(frames)) { for (const auto& frame : converter.convert(frames)) {
frame->pts = num_frames; frame->pts += 1;
num_frames += 1;
process_frame(frame); process_frame(frame);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment