Commit fce6180c authored by Moto Hira, committed by Facebook GitHub Bot
Browse files

Tweak OutputStream implementation (#3122)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3122

- Remove manual tracking of num_frames
- Remove unnecessary dispatch in AudioOutputStream

Reviewed By: nateanl

Differential Revision: D43685746

fbshipit-source-id: a7e62a81549fb62ad0caa3b741655eba3bc5e250
parent 0bf00d20
......@@ -34,6 +34,7 @@ AVFramePtr get_audio_frame(
AVCodecContext* codec_ctx,
int default_frame_size = 10000) {
AVFramePtr frame{};
frame->pts = 0;
frame->format = src_fmt;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = codec_ctx->sample_rate;
......@@ -121,31 +122,23 @@ void AudioOutputStream::write_chunk(const torch::Tensor& waveform) {
AVRational time_base{1, codec_ctx->sample_rate};
using namespace torch::indexing;
AT_DISPATCH_ALL_TYPES(waveform.scalar_type(), "write_audio_frames", [&] {
for (int64_t i = 0; i < waveform.size(0); i += frame_capacity) {
auto chunk = waveform.index({Slice(i, i + frame_capacity), Slice()});
auto num_valid_frames = chunk.size(0);
auto byte_size = chunk.numel() * chunk.element_size();
chunk = chunk.reshape({-1}).contiguous();
// TODO: make writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334
TORCH_CHECK(
av_frame_is_writable(src_frame),
"Internal Error: frame is not writable.");
memcpy(
src_frame->data[0],
static_cast<void*>(chunk.data_ptr<scalar_t>()),
byte_size);
src_frame->pts =
av_rescale_q(num_frames, time_base, codec_ctx->time_base);
src_frame->nb_samples = num_valid_frames;
num_frames += num_valid_frames;
process_frame(src_frame);
}
});
for (int64_t i = 0; i < waveform.size(0); i += frame_capacity) {
auto chunk = waveform.index({Slice(i, i + frame_capacity), Slice()});
auto num_frames = chunk.size(0);
auto byte_size = chunk.numel() * chunk.element_size();
chunk = chunk.reshape({-1}).contiguous();
// TODO: make writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334
TORCH_CHECK(
av_frame_is_writable(src_frame),
"Internal Error: frame is not writable.");
memcpy(src_frame->data[0], chunk.data_ptr(), byte_size);
src_frame->pts += av_rescale_q(num_frames, time_base, codec_ctx->time_base);
src_frame->nb_samples = num_frames;
process_frame(src_frame);
}
}
} // namespace torchaudio::io
......@@ -9,8 +9,7 @@ OutputStream::OutputStream(
: codec_ctx(codec_ctx_),
encoder(format_ctx, codec_ctx),
filter(std::move(filter_)),
dst_frame(),
num_frames(0) {}
dst_frame() {}
void OutputStream::process_frame(AVFrame* src) {
int ret = filter.add_frame(src);
......
......@@ -16,8 +16,6 @@ struct OutputStream {
FilterGraph filter;
// frame that output from FilterGraph is written
AVFramePtr dst_frame;
// The number of samples written so far
int64_t num_frames;
OutputStream(
AVFormatContext* format_ctx,
......
......@@ -219,6 +219,7 @@ AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
av_err2string(ret),
").");
}
frame->pts = 0;
return frame;
}
......
......@@ -52,8 +52,7 @@ VideoOutputStream::VideoOutputStream(
// Walks every frame yielded by converter.convert(frames), stamps a
// presentation timestamp (pts) on it, and hands it to process_frame.
//
// NOTE(review): this listing appears to be a rendered diff in which the
// removed and the added lines are shown together without +/- markers; confirm
// against the actual file which of the pts statements below are live.
void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
for (const auto& frame : converter.convert(frames)) {
// Presumably the pre-change version: pts is taken from the num_frames
// counter, which is then advanced by one per frame.
frame->pts = num_frames;
num_frames += 1;
// Presumably the post-change version: pts is advanced in place on the frame
// object instead of being tracked in a separate counter.
// NOTE(review): the increment runs before process_frame, so a frame whose
// pts starts at 0 would be submitted with pts == 1 — confirm this is intended.
frame->pts += 1;
process_frame(frame);
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment