Commit 0dd59e0d authored by moto, committed by Facebook GitHub Bot

Make StreamReader return PTS (#2975)

Summary:
This commit makes `StreamReader` report the PTS (presentation time stamp) of each returned chunk as well.

Example

```python
from torchaudio.io import StreamReader

s = StreamReader(...)
s.add_video_stream(...)
for (video_chunk, ) in s.stream():
    # video_chunk works like a torch.Tensor but carries an extra PTS attribute
    print(video_chunk.pts)  # reports the PTS of the first frame of the video chunk.
```

For backward compatibility, we introduce `_ChunkTensor`, a composition of
Tensor and metadata that works like a normal tensor in PyTorch operations.

The implementation of `_ChunkTensor` is based on [TrivialTensorViaComposition](https://github.com/albanD/subclass_zoo/blob/0eeb1d68fb59879029c610bc407f2997ae43ba0a/trivial_tensors.py#L83).
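Condensed from the implementation added later in this diff, the pattern looks roughly like the following sketch (`_Wrapped` is a hypothetical name; the actual `ChunkTensor` additionally uses `@dataclass` to generate its constructor):

```python
import torch
from torch.utils._pytree import tree_map


class _Wrapped(torch.Tensor):
    # Route every operation through __torch_dispatch__ only.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem, *_):
        # The wrapper mirrors the shape/dtype of the wrapped tensor.
        return super().__new__(cls, elem)

    def __init__(self, elem, pts):
        self._elem = elem  # the actual frame data
        self.pts = pts  # the attached metadata

    @classmethod
    def __torch_dispatch__(cls, func, _, args=(), kwargs=None):
        # Unwrap wrapper arguments so ops run on plain Tensors,
        # which also means results come back as plain Tensors.
        def unwrap(t):
            return t._elem if isinstance(t, cls) else t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
```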

It was also suggested to attach the metadata directly to the Tensor object,
but the possibility of a collision between torchaudio's metadata and new attributes
introduced in PyTorch cannot be ignored, so we use a Tensor subclass implementation.

If any unexpected issue arises from a metadata attribute name collision, client code
can fetch the bare Tensor and continue.
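As a concrete illustration, mirroring the new unit test (dummy data; the PTS value is arbitrary), any PyTorch operation on the chunk yields a plain `torch.Tensor`, so a no-op such as `clone()` recovers a bare tensor:

```python
import torch
from torchaudio.io._stream_reader import ChunkTensor

# Build a ChunkTensor directly, as the unit test does, with dummy data.
chunk = ChunkTensor(torch.randn(256, 2), 16.0)
print(chunk.pts)  # 16.0, the attached metadata

# Ops are dispatched to the wrapped tensor and return plain Tensors,
# so cloning drops the metadata wrapper.
bare = chunk.clone()
assert isinstance(bare, torch.Tensor)
assert not hasattr(bare, "pts")
```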

Pull Request resolved: https://github.com/pytorch/audio/pull/2975

Reviewed By: hwangjeff

Differential Revision: D42526945

Pulled By: mthrok

fbshipit-source-id: b4e9422e914ff328421b975120460f3001268f35
parent de628226
......@@ -52,11 +52,17 @@ Methods
Support Structures
==================
{%- for item in ["StreamReaderSourceStream", "StreamReaderSourceAudioStream", "StreamReaderSourceVideoStream", "StreamReaderOutputStream"] %}
{%- for item in [
"ChunkTensor",
"SourceStream",
"SourceAudioStream",
"SourceVideoStream",
"OutputStream",
] %}
{{ item | underline("-") }}
.. autoclass:: torchaudio.io.{{item}}()
.. autoclass:: torchaudio.io._stream_reader.{{item}}()
:members:
{%- endfor %}
......
import io
import torch
import torchaudio
from parameterized import parameterized, parameterized_class
......@@ -17,12 +19,46 @@ from torchaudio_unittest.common_utils import (
)
if is_ffmpeg_available():
from torchaudio.io import (
StreamReader,
StreamReaderSourceAudioStream,
StreamReaderSourceStream,
StreamReaderSourceVideoStream,
)
from torchaudio.io import StreamReader, StreamWriter
from torchaudio.io._stream_reader import ChunkTensor, SourceAudioStream, SourceStream, SourceVideoStream
@skipIfNoFFmpeg
class ChunkTensorTest(TorchaudioTestCase):
def test_chunktensor(self):
"""ChunkTensor serves as a replacement of tensor"""
data = torch.randn((256, 2))
pts = 16.0
c = ChunkTensor(data, pts)
assert c.pts == pts
self.assertEqual(c, data)
# method
sum_ = c.sum()
assert isinstance(sum_, torch.Tensor)
self.assertEqual(sum_, data.sum())
# function form
min_ = torch.min(c)
assert isinstance(min_, torch.Tensor)
self.assertEqual(min_, torch.min(data))
# attribute
t = c.T
assert isinstance(t, torch.Tensor)
self.assertEqual(t, data.T)
# in-place op
c[0] = 0
self.assertEqual(c, data)
# pass to other C++ code
buffer = io.BytesIO()
w = StreamWriter(buffer, format="wav")
w.add_audio_stream(8000, 2)
with w.open():
w.write_audio_chunk(0, c)
################################################################################
......@@ -109,7 +145,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
base_metadata = {}
expected = [
StreamReaderSourceVideoStream(
SourceVideoStream(
media_type="video",
codec="h264",
codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10",
......@@ -126,7 +162,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
height=180,
frame_rate=25.0,
),
StreamReaderSourceAudioStream(
SourceAudioStream(
media_type="audio",
codec="aac",
codec_long_name="AAC (Advanced Audio Coding)",
......@@ -142,7 +178,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
sample_rate=8000.0,
num_channels=2,
),
StreamReaderSourceStream(
SourceStream(
media_type="subtitle",
codec="mov_text",
codec_long_name="MOV text",
......@@ -155,7 +191,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
"language": "eng",
},
),
StreamReaderSourceVideoStream(
SourceVideoStream(
media_type="video",
codec="h264",
codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10",
......@@ -172,7 +208,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
height=270,
frame_rate=29.97002997002997,
),
StreamReaderSourceAudioStream(
SourceAudioStream(
media_type="audio",
codec="aac",
codec_long_name="AAC (Advanced Audio Coding)",
......@@ -188,7 +224,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
sample_rate=16000.0,
num_channels=2,
),
StreamReaderSourceStream(
SourceStream(
media_type="subtitle",
codec="mov_text",
codec_long_name="MOV text",
......@@ -605,6 +641,59 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
assert video.shape == torch.Size([30, 3, 270, 480])
assert audio.shape == torch.Size([896, 2])
@parameterized.expand([(1,), (3,), (5,), (10,)])
def test_video_pts(self, fpc):
"""PTS values of the first frame are reported in .pts attribute"""
rate, num_frames = 30000 / 1001, 390
ref_pts = [i / rate for i in range(0, num_frames, fpc)]
s = StreamReader(self.get_src())
s.add_video_stream(fpc)
pts = [video.pts for video, in s.stream()]
self.assertEqual(pts, ref_pts)
@parameterized.expand([(256,), (512,), (1024,), (4086,)])
def test_audio_pts(self, fpc):
"""PTS values of the first frame are reported in .pts attribute"""
rate, num_frames = 16000, 208896
ref_pts = [i / rate for i in range(0, num_frames, fpc)]
s = StreamReader(self.get_src())
s.add_audio_stream(fpc, buffer_chunk_size=-1)
pts = [audio.pts for audio, in s.stream()]
self.assertEqual(pts, ref_pts)
def test_pts_unchunked_process_all(self):
"""PTS is zero when loading the entire media with unchunked buffer"""
s = StreamReader(self.get_src())
s.add_audio_stream(-1, buffer_chunk_size=-1)
s.add_video_stream(-1, buffer_chunk_size=-1)
s.process_all_packets()
audio, video = s.pop_chunks()
assert audio.pts == 0.0
assert video.pts == 0.0
assert audio.size(0) == 208896
assert video.size(0) == 390
def test_pts_unchunked(self):
"""PTS grows proportionally to the number of frames decoded"""
s = StreamReader(self.get_src())
s.add_audio_stream(-1, buffer_chunk_size=-1)
s.add_video_stream(-1, buffer_chunk_size=-1)
num_audio_frames, num_video_frames = 0, 0
while num_audio_frames < 208896 and num_video_frames < 390:
s.process_packet()
audio, video = s.pop_chunks()
if audio is None and video is None:
continue
if audio is not None:
assert audio.pts == num_audio_frames / 16000
num_audio_frames += audio.size(0)
if video is not None:
assert video.pts == num_video_frames * 1001 / 30000
num_video_frames += video.size(0)
def _to_fltp(original):
"""Convert Tensor to float32 with value range [-1, 1]"""
......
#pragma once
#include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
namespace torchaudio {
namespace ffmpeg {
......@@ -21,9 +22,9 @@ class Buffer {
//////////////////////////////////////////////////////////////////////////////
// Modifiers
//////////////////////////////////////////////////////////////////////////////
virtual void push_frame(AVFrame* frame) = 0;
virtual void push_frame(AVFrame* frame, double pts) = 0;
virtual c10::optional<torch::Tensor> pop_chunk() = 0;
virtual c10::optional<Chunk> pop_chunk() = 0;
virtual void flush() = 0;
};
......
......@@ -5,23 +5,33 @@ namespace torchaudio {
namespace ffmpeg {
namespace detail {
ChunkedBuffer::ChunkedBuffer(int frames_per_chunk, int num_chunks)
: frames_per_chunk(frames_per_chunk), num_chunks(num_chunks) {}
ChunkedBuffer::ChunkedBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration)
: frame_duration(frame_duration),
frames_per_chunk(frames_per_chunk),
num_chunks(num_chunks) {}
ChunkedAudioBuffer::ChunkedAudioBuffer(int frames_per_chunk, int num_chunks)
: ChunkedBuffer(frames_per_chunk, num_chunks) {}
ChunkedAudioBuffer::ChunkedAudioBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration)
: ChunkedBuffer(frames_per_chunk, num_chunks, frame_duration) {}
ChunkedVideoBuffer::ChunkedVideoBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration,
const torch::Device& device_)
: ChunkedBuffer(frames_per_chunk, num_chunks), device(device_) {}
: ChunkedBuffer(frames_per_chunk, num_chunks, frame_duration),
device(device_) {}
bool ChunkedBuffer::is_ready() const {
return num_buffered_frames >= frames_per_chunk;
}
void ChunkedBuffer::push_tensor(torch::Tensor frame) {
void ChunkedBuffer::push_tensor(torch::Tensor frame, double pts_) {
using namespace torch::indexing;
// Note:
// Audio tensors contain multiple frames while video tensors contain only
......@@ -60,6 +70,7 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) {
num_buffered_frames += append;
// frame = frame[append:]
frame = frame.index({Slice(append)});
pts_ += double(append) * frame_duration;
}
// 2. Return if the number of input frames is smaller than the empty buffer.
......@@ -83,6 +94,7 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) {
int64_t start = i * frames_per_chunk;
// chunk = frame[i*frames_per_chunk:(i+1) * frames_per_chunk]
auto chunk = frame.index({Slice(start, start + frames_per_chunk)});
double pts_val = pts_ + double(start) * frame_duration;
int64_t chunk_size = chunk.size(0);
TORCH_INTERNAL_ASSERT(
chunk_size <= frames_per_chunk,
......@@ -95,6 +107,7 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) {
chunk = temp;
}
chunks.push_back(chunk);
pts.push_back(pts_val);
num_buffered_frames += chunk_size;
// Trim if num_chunks > 0
......@@ -109,26 +122,28 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) {
}
}
c10::optional<torch::Tensor> ChunkedBuffer::pop_chunk() {
c10::optional<Chunk> ChunkedBuffer::pop_chunk() {
using namespace torch::indexing;
if (!num_buffered_frames) {
return {};
}
torch::Tensor ret = chunks.front();
torch::Tensor chunk = chunks.front();
double pts_val = pts.front();
chunks.pop_front();
pts.pop_front();
if (num_buffered_frames < frames_per_chunk) {
ret = ret.index({Slice(None, num_buffered_frames)});
chunk = chunk.index({Slice(None, num_buffered_frames)});
}
num_buffered_frames -= ret.size(0);
return c10::optional<torch::Tensor>{ret};
num_buffered_frames -= chunk.size(0);
return {Chunk{chunk, pts_val}};
}
void ChunkedAudioBuffer::push_frame(AVFrame* frame) {
push_tensor(convert_audio(frame));
void ChunkedAudioBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_audio(frame), pts_);
}
void ChunkedVideoBuffer::push_frame(AVFrame* frame) {
push_tensor(convert_image(frame, device));
void ChunkedVideoBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_image(frame, device), pts_);
}
void ChunkedBuffer::flush() {
......
......@@ -13,6 +13,10 @@ namespace detail {
class ChunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks;
// Time stamps corresponding to the first frame of each chunk
std::deque<double> pts;
// Duration of one frame, used to recalculate the PTS of audio samples
double frame_duration;
// The number of frames to return as a chunk
// If <0, then the user wants to receive all the frames
......@@ -25,21 +29,24 @@ class ChunkedBuffer : public Buffer {
int64_t num_buffered_frames = 0;
protected:
ChunkedBuffer(int frames_per_chunk, int num_chunks);
ChunkedBuffer(int frames_per_chunk, int num_chunks, double frame_duration);
void push_tensor(torch::Tensor frame);
void push_tensor(torch::Tensor frame, double pts);
public:
bool is_ready() const override;
void flush() override;
c10::optional<torch::Tensor> pop_chunk() override;
c10::optional<Chunk> pop_chunk() override;
};
class ChunkedAudioBuffer : public ChunkedBuffer {
public:
ChunkedAudioBuffer(int frames_per_chunk, int num_chunks);
ChunkedAudioBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration);
void push_frame(AVFrame* frame) override;
void push_frame(AVFrame* frame, double pts) override;
};
class ChunkedVideoBuffer : public ChunkedBuffer {
......@@ -49,9 +56,10 @@ class ChunkedVideoBuffer : public ChunkedBuffer {
ChunkedVideoBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration,
const torch::Device& device);
void push_frame(AVFrame* frame) override;
void push_frame(AVFrame* frame, double pts) override;
};
} // namespace detail
......
......@@ -12,27 +12,30 @@ bool UnchunkedBuffer::is_ready() const {
return chunks.size() > 0;
}
void UnchunkedBuffer::push_tensor(const torch::Tensor& t) {
void UnchunkedBuffer::push_tensor(const torch::Tensor& t, double pts_) {
if (chunks.size() == 0) {
pts = pts_;
}
chunks.push_back(t);
}
void UnchunkedAudioBuffer::push_frame(AVFrame* frame) {
push_tensor(convert_audio(frame));
void UnchunkedAudioBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_audio(frame), pts_);
}
void UnchunkedVideoBuffer::push_frame(AVFrame* frame) {
push_tensor(convert_image(frame, device));
void UnchunkedVideoBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_image(frame, device), pts_);
}
c10::optional<torch::Tensor> UnchunkedBuffer::pop_chunk() {
c10::optional<Chunk> UnchunkedBuffer::pop_chunk() {
if (chunks.size() == 0) {
return {};
}
auto ret =
auto frames =
torch::cat(std::vector<torch::Tensor>{chunks.begin(), chunks.end()}, 0);
chunks.clear();
return {ret};
return {Chunk{frames, pts}};
}
void UnchunkedBuffer::flush() {
......
......@@ -16,19 +16,20 @@ namespace detail {
class UnchunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks;
double pts = -1.;
protected:
void push_tensor(const torch::Tensor& t);
void push_tensor(const torch::Tensor& t, double pts);
public:
bool is_ready() const override;
c10::optional<torch::Tensor> pop_chunk() override;
c10::optional<Chunk> pop_chunk() override;
void flush() override;
};
class UnchunkedAudioBuffer : public UnchunkedBuffer {
public:
void push_frame(AVFrame* frame) override;
void push_frame(AVFrame* frame, double pts) override;
};
class UnchunkedVideoBuffer : public UnchunkedBuffer {
......@@ -37,7 +38,7 @@ class UnchunkedVideoBuffer : public UnchunkedBuffer {
public:
explicit UnchunkedVideoBuffer(const torch::Device& device);
void push_frame(AVFrame* frame) override;
void push_frame(AVFrame* frame, double pts) override;
};
} // namespace detail
......
......@@ -11,6 +11,7 @@ std::unique_ptr<Buffer> get_buffer(
AVMediaType type,
int frames_per_chunk,
int num_chunks,
double frame_duration,
const torch::Device& device) {
TORCH_CHECK(
frames_per_chunk > 0 || frames_per_chunk == -1,
......@@ -31,11 +32,11 @@ std::unique_ptr<Buffer> get_buffer(
// Chunked Mode
if (frames_per_chunk > 0) {
if (type == AVMEDIA_TYPE_AUDIO) {
return std::unique_ptr<Buffer>(
new detail::ChunkedAudioBuffer(frames_per_chunk, num_chunks));
return std::unique_ptr<Buffer>(new detail::ChunkedAudioBuffer(
frames_per_chunk, num_chunks, frame_duration));
} else {
return std::unique_ptr<Buffer>(
new detail::ChunkedVideoBuffer(frames_per_chunk, num_chunks, device));
return std::unique_ptr<Buffer>(new detail::ChunkedVideoBuffer(
frames_per_chunk, num_chunks, frame_duration, device));
}
} else { // unchunked mode
if (type == AVMEDIA_TYPE_AUDIO) {
......@@ -91,10 +92,12 @@ Sink::Sink(
filter_description(filter_description_.value_or(
codecpar->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")),
filter(get_filter_graph(input_time_base_, codecpar_, filter_description)),
output_time_base(filter->get_output_timebase()),
buffer(get_buffer(
codecpar_->codec_type,
frames_per_chunk,
num_chunks,
double(output_time_base.num) / output_time_base.den,
device)) {}
// 0: some kind of success
......@@ -109,7 +112,9 @@ int Sink::process_frame(AVFrame* pFrame) {
return 0;
}
if (ret >= 0) {
buffer->push_frame(frame);
double pts =
double(frame->pts * output_time_base.num) / output_time_base.den;
buffer->push_frame(frame, pts);
}
av_frame_unref(frame);
}
......
......@@ -15,6 +15,8 @@ class Sink {
AVCodecParameters* codecpar;
std::string filter_description;
std::unique_ptr<FilterGraph> filter;
// time_base of filter graph output, used for PTS calc
AVRational output_time_base;
public:
std::unique_ptr<Buffer> buffer;
......
......@@ -151,7 +151,7 @@ int StreamProcessor::send_frame(AVFrame* pFrame) {
////////////////////////////////////////////////////////////////////////////////
// Retrieval
////////////////////////////////////////////////////////////////////////////////
c10::optional<torch::Tensor> StreamProcessor::pop_chunk(KeyType key) {
c10::optional<Chunk> StreamProcessor::pop_chunk(KeyType key) {
return sinks.at(key).buffer->pop_chunk();
}
......
......@@ -4,6 +4,7 @@
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/decoder.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/sink.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
#include <map>
namespace torchaudio {
......@@ -95,7 +96,7 @@ class StreamProcessor {
//////////////////////////////////////////////////////////////////////////////
public:
// Get the chunk from the given filter result
c10::optional<torch::Tensor> pop_chunk(KeyType key);
c10::optional<Chunk> pop_chunk(KeyType key);
};
} // namespace ffmpeg
......
......@@ -385,8 +385,21 @@ int StreamReader::drain() {
std::vector<c10::optional<torch::Tensor>> StreamReader::pop_chunks() {
std::vector<c10::optional<torch::Tensor>> ret;
for (auto& c : pop_chunks_with_metadata()) {
if (c) {
ret.emplace_back(c->frames);
} else {
ret.emplace_back();
}
}
return ret;
}
std::vector<c10::optional<Chunk>> StreamReader::pop_chunks_with_metadata() {
std::vector<c10::optional<Chunk>> ret;
ret.reserve(num_out_streams());
for (auto& i : stream_indices) {
ret.push_back(processors[i.first]->pop_chunk(i.second));
ret.emplace_back(processors[i.first]->pop_chunk(i.second));
}
return ret;
}
......
......@@ -263,6 +263,10 @@ class StreamReader {
std::vector<c10::optional<torch::Tensor>> pop_chunks();
///@}
/// Pop one chunk from each output stream if it is available.
/// TODO: merge this into pop_chunks
std::vector<c10::optional<Chunk>> pop_chunks_with_metadata();
};
} // namespace ffmpeg
......
......@@ -103,5 +103,18 @@ int64_t StreamReaderBinding::fill_buffer(
return 0;
}
std::vector<c10::optional<ChunkData>> StreamReaderBinding::pop_chunks() {
std::vector<c10::optional<ChunkData>> ret;
ret.reserve(static_cast<size_t>(num_out_streams()));
for (auto& c : StreamReader::pop_chunks_with_metadata()) {
if (c) {
ret.emplace_back(std::forward_as_tuple(c->frames, c->pts));
} else {
ret.emplace_back();
}
}
return ret;
}
} // namespace ffmpeg
} // namespace torchaudio
......@@ -61,6 +61,8 @@ using OutInfo = std::tuple<
std::string // filter description
>;
using ChunkData = std::tuple<torch::Tensor, double>;
// Structure to implement a wrapper API around StreamReader, which is more
// suitable for binding the code (i.e., it receives/returns primitives)
struct StreamReaderBinding : public StreamReader,
......@@ -78,6 +80,8 @@ struct StreamReaderBinding : public StreamReader,
int64_t fill_buffer(
const c10::optional<double>& timeout = c10::optional<double>(),
const double backoff = 10.);
std::vector<c10::optional<ChunkData>> pop_chunks();
};
} // namespace ffmpeg
......
......@@ -104,5 +104,20 @@ struct OutputStreamInfo {
std::string filter_description;
};
/// Stores decoded frames and metadata
struct Chunk {
/// Audio/video frames.
///
/// For audio, the shape is ``[time, num_channels]``, and the ``dtype``
/// depends on output stream configurations.
///
/// For video, the shape is ``[time, channel, height, width]``, and
/// the ``dtype`` is ``torch.uint8``.
torch::Tensor frames;
///
/// Presentation time stamp of the first frame, in seconds.
double pts;
};
} // namespace ffmpeg
} // namespace torchaudio
......@@ -2,10 +2,6 @@ import torchaudio
_STREAM_READER = [
"StreamReader",
"StreamReaderSourceStream",
"StreamReaderSourceAudioStream",
"StreamReaderSourceVideoStream",
"StreamReaderOutputStream",
]
_STREAM_WRITER = [
......
......@@ -83,10 +83,11 @@ def _load_audio(
option: Dict[str, str] = {}
s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option)
s.process_all_packets()
waveform = s.pop_chunks()[0]
if waveform is None:
chunk = s.pop_chunks()[0]
if chunk is None:
raise RuntimeError("Failed to decode audio.")
assert waveform is not None
assert chunk is not None
waveform = chunk[0]
if channels_first:
waveform = waveform.T
return waveform, sample_rate
......
......@@ -5,16 +5,21 @@ from typing import BinaryIO, Dict, Iterator, Optional, Tuple, Union
import torch
import torchaudio
from torch.utils._pytree import tree_map
__all__ = [
"StreamReader",
]
@dataclass
class StreamReaderSourceStream:
class SourceStream:
"""The metadata of a source stream, returned by :meth:`~torchaudio.io.StreamReader.get_src_stream_info`.
This class is used when representing streams of media type other than `audio` or `video`.
When source stream is `audio` or `video` type, :class:`StreamReaderSourceAudioStream` and
:class:`StreamReaderSourceVideoStream`, which reports additional media-specific attributes,
When the source stream is of `audio` or `video` type, :class:`SourceAudioStream` and
:class:`SourceVideoStream`, which report additional media-specific attributes,
are used respectively.
"""
......@@ -65,12 +70,12 @@ class StreamReaderSourceStream:
@dataclass
class StreamReaderSourceAudioStream(StreamReaderSourceStream):
class SourceAudioStream(SourceStream):
"""The metadata of an audio source stream, returned by :meth:`~torchaudio.io.StreamReader.get_src_stream_info`.
This class is used when representing an audio stream.
In addition to the attributes reported by :class:`StreamReaderSourceStream`,
In addition to the attributes reported by :class:`SourceStream`,
the following attributes are reported.
"""
......@@ -81,12 +86,12 @@ class StreamReaderSourceAudioStream(StreamReaderSourceStream):
@dataclass
class StreamReaderSourceVideoStream(StreamReaderSourceStream):
class SourceVideoStream(SourceStream):
"""The metadata of a video source stream, returned by :meth:`~torchaudio.io.StreamReader.get_src_stream_info`.
This class is used when representing a video stream.
In addition to the attributes reported by :class:`StreamReaderSourceStream`,
In addition to the attributes reported by :class:`SourceStream`,
the following attributes are reported.
"""
......@@ -127,7 +132,7 @@ def _parse_si(i):
bps = i[_BPS]
metadata = i[_METADATA]
if media_type == "audio":
return StreamReaderSourceAudioStream(
return SourceAudioStream(
media_type=media_type,
codec=codec_name,
codec_long_name=codec_long_name,
......@@ -140,7 +145,7 @@ def _parse_si(i):
num_channels=i[_NUM_CHANNELS],
)
if media_type == "video":
return StreamReaderSourceVideoStream(
return SourceVideoStream(
media_type=media_type,
codec=codec_name,
codec_long_name=codec_long_name,
......@@ -153,7 +158,7 @@ def _parse_si(i):
height=i[_HEIGHT],
frame_rate=i[_FRAME_RATE],
)
return StreamReaderSourceStream(
return SourceStream(
media_type=media_type,
codec=codec_name,
codec_long_name=codec_long_name,
......@@ -166,7 +171,7 @@ def _parse_si(i):
@dataclass
class StreamReaderOutputStream:
class OutputStream:
"""Output stream configured on :class:`StreamReader`,
returned by :meth:`~torchaudio.io.StreamReader.get_out_stream_info`.
"""
......@@ -178,7 +183,7 @@ class StreamReaderOutputStream:
def _parse_oi(i):
return StreamReaderOutputStream(i[0], i[1])
return OutputStream(i[0], i[1])
def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]):
......@@ -206,6 +211,78 @@ def _get_vfilter_desc(frame_rate: Optional[float], width: Optional[int], height:
return ",".join(descs) if descs else None
# Base class for ChunkTensor
# Based on TrivialTensorViaComposition
# https://github.com/albanD/subclass_zoo/blob/0eeb1d68fb59879029c610bc407f2997ae43ba0a/trivial_tensors.py#L83
class ChunkTensorBase(torch.Tensor):
__torch_function__ = torch._C._disabled_torch_function_impl
@staticmethod
def __new__(cls, _elem, *_):
return super().__new__(cls, _elem)
@classmethod
def __torch_dispatch__(cls, func, _, args=(), kwargs=None):
def unwrap(t):
return t._elem if isinstance(t, cls) else t
return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
@dataclass
class ChunkTensor(ChunkTensorBase):
"""Decoded media frames with metadata.
An instance of this class represents decoded video/audio frames with
metadata, and the instance itself behaves like :py:class:`~torch.Tensor`.
Client code can pass an instance of this class as if it were a
:py:class:`~torch.Tensor`, or call the methods defined on the
:py:class:`~torch.Tensor` class.
Example:
>>> # Define input streams
>>> reader = StreamReader(...)
>>> reader.add_audio_stream(frames_per_chunk=4000, sample_rate=8000)
>>> reader.add_video_stream(frames_per_chunk=7, frame_rate=28)
>>> # Decode the streams and fetch frames
>>> reader.fill_buffer()
>>> audio_chunk, video_chunk = reader.pop_chunks()
>>> # Access metadata
>>> (audio_chunk.pts, video_chunk.pts)
(0.0, 0.0)
>>>
>>> # The second time the PTS is different
>>> reader.fill_buffer()
>>> audio_chunk, video_chunk = reader.pop_chunks()
>>> (audio_chunk.pts, video_chunk.pts)
(0.5, 0.25)
>>> # Call PyTorch ops on chunk
>>> audio_chunk.shape
torch.Size([4000, 2])
>>> power = torch.pow(video_chunk, 2)
>>>
>>> # the result is a plain torch.Tensor class
>>> type(power)
<class 'torch.Tensor'>
>>>
>>> # Metadata is not available on the result
>>> power.pts
AttributeError: 'Tensor' object has no attribute 'pts'
"""
# Keep it private for now
_elem: torch.Tensor
pts: float
"""Presentation time stamp of the first frame in the chunk.
Unit: second.
"""
def _format_doc(**kwargs):
def decorator(obj):
obj.__doc__ = obj.__doc__.format(**kwargs)
......@@ -223,8 +300,8 @@ _frames_per_chunk = """Number of frames returned as one chunk.
_buffer_chunk_size = """Internal buffer size.
When the number of chunks buffered exceeds this number, old frames are
dropped. For example, if `frames_per_chunk` is 5 and `buffer_chunk_size` is
3, then frames older than 15 are dropped.
dropped. For example, if ``frames_per_chunk`` is 5 and ``buffer_chunk_size`` is
3, then frames older than ``15`` are dropped.
Providing ``-1`` disables this behavior.
Default: ``3``."""
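To make the dropping rule concrete, here is a hypothetical configuration (the source path is assumed) that keeps at most 3 chunks of 5 frames each:

```python
from torchaudio.io import StreamReader

s = StreamReader("input.mp4")  # hypothetical source
# Buffer at most 3 chunks of 5 frames; older frames are dropped.
s.add_video_stream(frames_per_chunk=5, buffer_chunk_size=3)
# buffer_chunk_size=-1 would instead retain every decoded frame.
```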
......@@ -249,22 +326,28 @@ _decoder_option = """Options passed to decoder.
Mapping from str to str. (Default: ``None``)
To list decoder options for a decoder, you can use the
`ffmpeg -h decoder=<DECODER>` command.
``ffmpeg -h decoder=<DECODER>`` command.
|
In addition to decoder-specific options, you can also pass options related
to multithreading. They are effective only if the decoder supports them.
If neither of them is provided, StreamReader defaults to a single thread.
- ``"threads"``: The number of threads (in str) or the value ``"0"``
to let FFmpeg decides based on its heuristics.
- ``"thread_type"``: Which multithreading method to use.
The valid values are ``"frame"`` or ``"slice"``.
Note that sach decoder supports different set of methods.
If not provided, a default value is used.
- ``"frame"``: Decode more than one frame at once.
Each thread handles one frame.
This will increase decoding delay by one frame per thread
- ``"slice"``: Decode more than one part of a single frame at once.
``"threads"``: The number of threads (in str).
Providing the value ``"0"`` will let FFmpeg decide based on its heuristics.
``"thread_type"``: Which multithreading method to use.
The valid values are ``"frame"`` or ``"slice"``.
Note that each decoder supports a different set of methods.
If not provided, a default value is used.
- ``"frame"``: Decode more than one frame at once.
Each thread handles one frame.
This will increase decoding delay by one frame per thread.
- ``"slice"``: Decode more than one part of a single frame at once.
|
"""
......@@ -433,7 +516,7 @@ class StreamReader:
"""
return self._be.get_metadata()
def get_src_stream_info(self, i: int) -> torchaudio.io.StreamReaderSourceStream:
def get_src_stream_info(self, i: int) -> SourceStream:
"""Get the metadata of source stream
Args:
......@@ -443,7 +526,7 @@ class StreamReader:
"""
return _parse_si(self._be.get_src_stream_info(i))
def get_out_stream_info(self, i: int) -> torchaudio.io.StreamReaderOutputStream:
def get_out_stream_info(self, i: int) -> OutputStream:
"""Get the metadata of output stream
Args:
......@@ -748,15 +831,21 @@ class StreamReader:
"""Returns true if all the output streams have at least one chunk filled."""
return self._be.is_buffer_ready()
def pop_chunks(self) -> Tuple[Optional[torch.Tensor]]:
def pop_chunks(self) -> Tuple[Optional[ChunkTensor]]:
"""Pop one chunk from all the output stream buffers.
Returns:
Tuple[Optional[Tensor]]:
Tuple[Optional[ChunkTensor]]:
Buffer contents.
If a buffer does not contain any frame, then `None` is returned instead.
"""
return self._be.pop_chunks()
ret = []
for chunk in self._be.pop_chunks():
if chunk is None:
ret.append(None)
else:
ret.append(ChunkTensor(chunk[0], chunk[1]))
return ret
def fill_buffer(self, timeout: Optional[float] = None, backoff: float = 10.0) -> int:
"""Keep processing packets until all buffers have at least one chunk
......@@ -783,7 +872,7 @@ class StreamReader:
def stream(
self, timeout: Optional[float] = None, backoff: float = 10.0
) -> Iterator[Tuple[Optional[torch.Tensor], ...]]:
) -> Iterator[Tuple[Optional[ChunkTensor], ...]]:
"""Return an iterator that generates output tensors
Arguments:
......@@ -794,7 +883,7 @@ class StreamReader:
:py:func:`~StreamReader.process_packet`. (Default: ``10.0``)
Returns:
Iterator[Tuple[Optional[torch.Tensor], ...]]:
Iterator[Tuple[Optional[ChunkTensor], ...]]:
Iterator that yields a tuple of chunks that correspond to the output
streams defined by client code.
If an output stream is exhausted, then the chunk Tensor is substituted
......