Commit 8d2f6f8d authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Support overwriting PTS in StreamWriter (#3135)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3135

Reviewed By: xiaohui-zhang

Differential Revision: D43724273

Pulled By: mthrok

fbshipit-source-id: 9b52823618948945a26e57d5b3deccbf5f9268c1
parent 3212a257
......@@ -59,6 +59,7 @@ class ChunkTensorTest(TorchaudioTestCase):
w.add_audio_stream(8000, 2)
with w.open():
w.write_audio_chunk(0, c)
w.write_audio_chunk(0, c, c.pts)
################################################################################
......
import math
import torch
import torchaudio
......@@ -435,3 +437,53 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
num_samples += chunk.size(0)
print(chunk.pts, expected)
assert abs(chunk.pts - expected) < 1e-10
@parameterized.expand(
    [
        (10, 100),
        (15, 150),
        (24, 240),
        (25, 200),
        (30, 300),
        (50, 500),
        (60, 600),
        # Converting a PTS value goes through a float <-> int conversion, which
        # can introduce rounding error. The popular 29.97 Hz frame rate is
        # included here as a spot-check for that.
        (30000 / 1001, 10010),
    ]
)
def test_video_pts_overwrite(self, frame_rate, num_frames):
    """PTS values passed to ``write_video_chunk`` are the ones read back."""
    ext = "mp4"
    filename = f"test.{ext}"
    height = width = 8

    # Write `num_frames` chunks of a single blank frame each, giving every
    # chunk an explicit timestamp of `index / frame_rate`.
    dst = self.get_dst(filename)
    writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
    writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
    frame = torch.zeros((1, 3, height, width), dtype=torch.uint8)
    expected_pts = [index / frame_rate for index in range(num_frames)]
    with writer.open():
        for timestamp in expected_pts:
            writer.write_video_chunk(0, frame, timestamp)

    # Read the timestamps back from the written file.
    if self.test_fileobj:
        dst.flush()
    reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
    reader.add_video_stream(1)
    actual_pts = []
    for (chunk,) in reader.stream():
        actual_pts.append(chunk.pts)

    assert len(actual_pts) == len(expected_pts)
    # torch provides isclose, but converting the floats to tensors could itself
    # introduce a discrepancy, so compare plain floats with math.isclose.
    for actual, expected in zip(actual_pts, expected_pts):
        assert math.isclose(actual, expected)
......@@ -218,8 +218,8 @@ AVFramePtr get_audio_frame(
AVFramePtr frame{};
frame->pts = 0;
frame->format = src_fmt;
// note: channels attribute is not required for encoding, but TensorConverter
// refers to it
// Note: `channels` attribute is not required for encoding, but
// TensorConverter refers to it
frame->channels = num_channels;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = sample_rate;
......@@ -461,6 +461,10 @@ AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
av_err2string(ret),
").");
}
// Note: `nb_samples` attribute is not used for video, but we set it
// anyways so that we can make the logic of PTS increment agnostic to
// audio and video.
frame->nb_samples = 1;
frame->pts = 0;
return frame;
}
......@@ -511,7 +515,10 @@ EncodeProcess::EncodeProcess(
src_frame(get_video_frame(format, codec_ctx)),
converter(AVMEDIA_TYPE_VIDEO, src_frame) {}
void EncodeProcess::process(AVMediaType type, const torch::Tensor& tensor) {
void EncodeProcess::process(
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts) {
TORCH_CHECK(
codec_ctx->codec_type == type,
"Attempted to write ",
......@@ -519,16 +526,18 @@ void EncodeProcess::process(AVMediaType type, const torch::Tensor& tensor) {
" to ",
av_get_media_type_string(codec_ctx->codec_type),
" stream.");
AVRational codec_tb = codec_ctx->time_base;
if (pts) {
AVRational tb = codec_ctx->time_base;
auto val = static_cast<int64_t>(std::round(pts.value() * tb.den / tb.num));
if (src_frame->pts > val) {
TORCH_WARN_ONCE(
"The provided PTS value is smaller than the next expected value.");
}
src_frame->pts = val;
}
for (const auto& frame : converter.convert(tensor)) {
process_frame(frame);
if (type == AVMEDIA_TYPE_VIDEO) {
frame->pts += 1;
} else {
AVRational sr_tb{1, codec_ctx->sample_rate};
frame->pts += av_rescale_q(frame->nb_samples, sr_tb, codec_tb);
}
frame->pts += frame->nb_samples;
}
}
......
......@@ -39,7 +39,10 @@ class EncodeProcess {
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel);
void process(AVMediaType type, const torch::Tensor& tensor);
void process(
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts);
void process_frame(AVFrame* src);
......
......@@ -200,7 +200,10 @@ void StreamWriter::close() {
is_open = false;
}
void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
void StreamWriter::write_audio_chunk(
int i,
const torch::Tensor& waveform,
const c10::optional<double>& pts) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
......@@ -208,10 +211,13 @@ void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
processes.size(),
"). Found: ",
i);
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform);
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform, pts);
}
void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
void StreamWriter::write_video_chunk(
int i,
const torch::Tensor& frames,
const c10::optional<double>& pts) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
......@@ -219,7 +225,7 @@ void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
processes.size(),
"). Found: ",
i);
processes[i].process(AVMEDIA_TYPE_VIDEO, frames);
processes[i].process(AVMEDIA_TYPE_VIDEO, frames, pts);
}
void StreamWriter::flush() {
......
......@@ -162,14 +162,42 @@ class StreamWriter {
/// @param i Stream index.
/// @param frames Waveform tensor. Shape: ``(frame, channel)``.
/// The ``dtype`` must match what was passed to ``add_audio_stream()`` method.
void write_audio_chunk(int i, const torch::Tensor& chunk);
/// @param pts
/// @parblock
/// Presentation timestamp. If provided, it overwrites the PTS of
/// the first frame with the provided one. Otherwise, PTS are incremented per
/// an inverse of sample rate. The provided value is expected to be no smaller
/// than the next PTS value tracked internally; a smaller value triggers a
/// warning.
///
/// __NOTE__: The provided value is converted to integer value expressed
/// in basis of sample rate.
/// Therefore, it is rounded to the nearest value of ``n / sample_rate``.
/// @endparblock
void write_audio_chunk(
int i,
const torch::Tensor& frames,
const c10::optional<double>& pts = {});
/// Write video data
/// @param i Stream index.
/// @param frames Video/image tensor. Shape: ``(time, channel, height,
/// width)``. The ``dtype`` must be ``torch.uint8``. The shape ``(height,
/// width and the number of channels)`` must match what was configured when
/// calling ``add_video_stream()``.
void write_video_chunk(int i, const torch::Tensor& chunk);
/// @param pts
/// @parblock
/// Presentation timestamp. If provided, it overwrites the PTS of
/// the first frame with the provided one. Otherwise, PTS are incremented per
/// an inverse of frame rate. The provided value is expected to be no smaller
/// than the next PTS value tracked internally; a smaller value triggers a
/// warning.
///
/// __NOTE__: The provided value is converted to integer value expressed
/// in basis of frame rate.
/// Therefore, it is rounded to the nearest value of ``n / frame_rate``.
/// @endparblock
void write_video_chunk(
int i,
const torch::Tensor& frames,
const c10::optional<double>& pts = {});
/// Flush the frames from encoders and write the frames to the destination.
void flush();
};
......
......@@ -275,17 +275,24 @@ class StreamWriter:
self._s.close()
self._is_open = False
def write_audio_chunk(self, i: int, chunk: torch.Tensor):
def write_audio_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
"""Write audio data
Args:
i (int): Stream index.
chunk (Tensor): Waveform tensor. Shape: `(frame, channel)`.
The ``dtype`` must match what was passed to :py:meth:`add_audio_stream` method.
pts (float or None, optional): If provided, overwrite the presentation timestamp.
.. note::
The provided value is converted to integer value expressed in basis of
sample rate. Therefore, it is rounded to the nearest value of
``n / sample_rate``.
"""
self._s.write_audio_chunk(i, chunk)
self._s.write_audio_chunk(i, chunk, pts)
def write_video_chunk(self, i: int, chunk: torch.Tensor):
def write_video_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
"""Write video/image data
Args:
......@@ -295,8 +302,15 @@ class StreamWriter:
The ``dtype`` must be ``torch.uint8``.
The shape (height, width and the number of channels) must match
what was configured when calling :py:meth:`add_video_stream`
pts (float or None, optional): If provided, overwrite the presentation timestamp.
.. note::
The provided value is converted to integer value expressed in basis of
frame rate. Therefore, it is rounded to the nearest value of
``n / frame_rate``.
"""
self._s.write_video_chunk(i, chunk)
self._s.write_video_chunk(i, chunk, pts)
def flush(self):
"""Flush the frames from encoders and write the frames to the destination."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment