Commit 0dd59e0d authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Make StreamReader return PTS (#2975)

Summary:
This commit makes `StreamReader` report PTS (presentation time stamp) of the returned chunk as well.

Example

```python
from torchaudio.io import StreamReader

s = StreamReader(...)
s.add_video_stream(...)
for (video_chunk, ) in s.stream():
    # video_chunk is Torch tensor type but has extra attribute of PTS
    print(video_chunk.pts)  # reports the PTS of the first frame of the video chunk.
```

For backward compatibility, we introduce `_ChunkTensor`, which is a composition
of a Tensor and metadata, but works like a normal tensor in PyTorch operations.

The implementation of `_ChunkTensor` is based on [TrivialTensorViaComposition](https://github.com/albanD/subclass_zoo/blob/0eeb1d68fb59879029c610bc407f2997ae43ba0a/trivial_tensors.py#L83).

It was also suggested to attach the metadata directly to the Tensor object,
but the possibility of a name collision between torchaudio's metadata and new attributes
introduced in PyTorch cannot be ignored, so we use a Tensor subclass implementation instead.

If any unexpected issue arises from a metadata attribute name collision, client code can
fetch the bare Tensor and continue.

Pull Request resolved: https://github.com/pytorch/audio/pull/2975

Reviewed By: hwangjeff

Differential Revision: D42526945

Pulled By: mthrok

fbshipit-source-id: b4e9422e914ff328421b975120460f3001268f35
parent de628226
...@@ -52,11 +52,17 @@ Methods ...@@ -52,11 +52,17 @@ Methods
Support Structures Support Structures
================== ==================
{%- for item in ["StreamReaderSourceStream", "StreamReaderSourceAudioStream", "StreamReaderSourceVideoStream", "StreamReaderOutputStream"] %} {%- for item in [
"ChunkTensor",
"SourceStream",
"SourceAudioStream",
"SourceVideoStream",
"OutputStream",
] %}
{{ item | underline("-") }} {{ item | underline("-") }}
.. autoclass:: torchaudio.io.{{item}}() .. autoclass:: torchaudio.io._stream_reader.{{item}}()
:members: :members:
{%- endfor %} {%- endfor %}
......
import io
import torch import torch
import torchaudio import torchaudio
from parameterized import parameterized, parameterized_class from parameterized import parameterized, parameterized_class
...@@ -17,12 +19,46 @@ from torchaudio_unittest.common_utils import ( ...@@ -17,12 +19,46 @@ from torchaudio_unittest.common_utils import (
) )
if is_ffmpeg_available(): if is_ffmpeg_available():
from torchaudio.io import ( from torchaudio.io import StreamReader, StreamWriter
StreamReader, from torchaudio.io._stream_reader import ChunkTensor, SourceAudioStream, SourceStream, SourceVideoStream
StreamReaderSourceAudioStream,
StreamReaderSourceStream,
StreamReaderSourceVideoStream, @skipIfNoFFmpeg
) class ChunkTensorTest(TorchaudioTestCase):
def test_chunktensor(self):
"""ChunkTensor serves as a replacement of tensor"""
data = torch.randn((256, 2))
pts = 16.0
c = ChunkTensor(data, pts)
assert c.pts == pts
self.assertEqual(c, data)
# method
sum_ = c.sum()
assert isinstance(sum_, torch.Tensor)
self.assertEqual(sum_, data.sum())
# function form
min_ = torch.min(c)
assert isinstance(min_, torch.Tensor)
self.assertEqual(min_, torch.min(data))
# attribute
t = c.T
assert isinstance(t, torch.Tensor)
self.assertEqual(t, data.T)
# in-place op
c[0] = 0
self.assertEqual(c, data)
# pass to other C++ code
buffer = io.BytesIO()
w = StreamWriter(buffer, format="wav")
w.add_audio_stream(8000, 2)
with w.open():
w.write_audio_chunk(0, c)
################################################################################ ################################################################################
...@@ -109,7 +145,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -109,7 +145,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
base_metadata = {} base_metadata = {}
expected = [ expected = [
StreamReaderSourceVideoStream( SourceVideoStream(
media_type="video", media_type="video",
codec="h264", codec="h264",
codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10", codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10",
...@@ -126,7 +162,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -126,7 +162,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
height=180, height=180,
frame_rate=25.0, frame_rate=25.0,
), ),
StreamReaderSourceAudioStream( SourceAudioStream(
media_type="audio", media_type="audio",
codec="aac", codec="aac",
codec_long_name="AAC (Advanced Audio Coding)", codec_long_name="AAC (Advanced Audio Coding)",
...@@ -142,7 +178,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -142,7 +178,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
sample_rate=8000.0, sample_rate=8000.0,
num_channels=2, num_channels=2,
), ),
StreamReaderSourceStream( SourceStream(
media_type="subtitle", media_type="subtitle",
codec="mov_text", codec="mov_text",
codec_long_name="MOV text", codec_long_name="MOV text",
...@@ -155,7 +191,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -155,7 +191,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
"language": "eng", "language": "eng",
}, },
), ),
StreamReaderSourceVideoStream( SourceVideoStream(
media_type="video", media_type="video",
codec="h264", codec="h264",
codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10", codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10",
...@@ -172,7 +208,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -172,7 +208,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
height=270, height=270,
frame_rate=29.97002997002997, frame_rate=29.97002997002997,
), ),
StreamReaderSourceAudioStream( SourceAudioStream(
media_type="audio", media_type="audio",
codec="aac", codec="aac",
codec_long_name="AAC (Advanced Audio Coding)", codec_long_name="AAC (Advanced Audio Coding)",
...@@ -188,7 +224,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -188,7 +224,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
sample_rate=16000.0, sample_rate=16000.0,
num_channels=2, num_channels=2,
), ),
StreamReaderSourceStream( SourceStream(
media_type="subtitle", media_type="subtitle",
codec="mov_text", codec="mov_text",
codec_long_name="MOV text", codec_long_name="MOV text",
...@@ -605,6 +641,59 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -605,6 +641,59 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
assert video.shape == torch.Size([30, 3, 270, 480]) assert video.shape == torch.Size([30, 3, 270, 480])
assert audio.shape == torch.Size([896, 2]) assert audio.shape == torch.Size([896, 2])
@parameterized.expand([(1,), (3,), (5,), (10,)])
def test_video_pts(self, fpc):
"""PTS values of the first frame are reported in .pts attribute"""
rate, num_frames = 30000 / 1001, 390
ref_pts = [i / rate for i in range(0, num_frames, fpc)]
s = StreamReader(self.get_src())
s.add_video_stream(fpc)
pts = [video.pts for video, in s.stream()]
self.assertEqual(pts, ref_pts)
@parameterized.expand([(256,), (512,), (1024,), (4086,)])
def test_audio_pts(self, fpc):
"""PTS values of the first frame are reported in .pts attribute"""
rate, num_frames = 16000, 208896
ref_pts = [i / rate for i in range(0, num_frames, fpc)]
s = StreamReader(self.get_src())
s.add_audio_stream(fpc, buffer_chunk_size=-1)
pts = [audio.pts for audio, in s.stream()]
self.assertEqual(pts, ref_pts)
def test_pts_unchunked_process_all(self):
"""PTS is zero when loading the entire media with unchunked buffer"""
s = StreamReader(self.get_src())
s.add_audio_stream(-1, buffer_chunk_size=-1)
s.add_video_stream(-1, buffer_chunk_size=-1)
s.process_all_packets()
audio, video = s.pop_chunks()
assert audio.pts == 0.0
assert video.pts == 0.0
assert audio.size(0) == 208896
assert video.size(0) == 390
def test_pts_unchunked(self):
"""PTS grows proportionally to the number of frames decoded"""
s = StreamReader(self.get_src())
s.add_audio_stream(-1, buffer_chunk_size=-1)
s.add_video_stream(-1, buffer_chunk_size=-1)
num_audio_frames, num_video_frames = 0, 0
while num_audio_frames < 208896 and num_video_frames < 390:
s.process_packet()
audio, video = s.pop_chunks()
if audio is None and video is None:
continue
if audio is not None:
assert audio.pts == num_audio_frames / 16000
num_audio_frames += audio.size(0)
if video is not None:
assert video.pts == num_video_frames * 1001 / 30000
num_video_frames += video.size(0)
def _to_fltp(original): def _to_fltp(original):
"""Convert Tensor to float32 with value range [-1, 1]""" """Convert Tensor to float32 with value range [-1, 1]"""
......
#pragma once #pragma once
#include <torch/torch.h> #include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h> #include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
namespace torchaudio { namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
...@@ -21,9 +22,9 @@ class Buffer { ...@@ -21,9 +22,9 @@ class Buffer {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Modifiers // Modifiers
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
virtual void push_frame(AVFrame* frame) = 0; virtual void push_frame(AVFrame* frame, double pts) = 0;
virtual c10::optional<torch::Tensor> pop_chunk() = 0; virtual c10::optional<Chunk> pop_chunk() = 0;
virtual void flush() = 0; virtual void flush() = 0;
}; };
......
...@@ -5,23 +5,33 @@ namespace torchaudio { ...@@ -5,23 +5,33 @@ namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
namespace detail { namespace detail {
ChunkedBuffer::ChunkedBuffer(int frames_per_chunk, int num_chunks) ChunkedBuffer::ChunkedBuffer(
: frames_per_chunk(frames_per_chunk), num_chunks(num_chunks) {} int frames_per_chunk,
int num_chunks,
double frame_duration)
: frame_duration(frame_duration),
frames_per_chunk(frames_per_chunk),
num_chunks(num_chunks) {}
ChunkedAudioBuffer::ChunkedAudioBuffer(int frames_per_chunk, int num_chunks) ChunkedAudioBuffer::ChunkedAudioBuffer(
: ChunkedBuffer(frames_per_chunk, num_chunks) {} int frames_per_chunk,
int num_chunks,
double frame_duration)
: ChunkedBuffer(frames_per_chunk, num_chunks, frame_duration) {}
ChunkedVideoBuffer::ChunkedVideoBuffer( ChunkedVideoBuffer::ChunkedVideoBuffer(
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
double frame_duration,
const torch::Device& device_) const torch::Device& device_)
: ChunkedBuffer(frames_per_chunk, num_chunks), device(device_) {} : ChunkedBuffer(frames_per_chunk, num_chunks, frame_duration),
device(device_) {}
bool ChunkedBuffer::is_ready() const { bool ChunkedBuffer::is_ready() const {
return num_buffered_frames >= frames_per_chunk; return num_buffered_frames >= frames_per_chunk;
} }
void ChunkedBuffer::push_tensor(torch::Tensor frame) { void ChunkedBuffer::push_tensor(torch::Tensor frame, double pts_) {
using namespace torch::indexing; using namespace torch::indexing;
// Note: // Note:
// Audio tensors contain multiple frames while video tensors contain only // Audio tensors contain multiple frames while video tensors contain only
...@@ -60,6 +70,7 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) { ...@@ -60,6 +70,7 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) {
num_buffered_frames += append; num_buffered_frames += append;
// frame = frame[append:] // frame = frame[append:]
frame = frame.index({Slice(append)}); frame = frame.index({Slice(append)});
pts_ += double(append) * frame_duration;
} }
// 2. Return if the number of input frames are smaller than the empty buffer. // 2. Return if the number of input frames are smaller than the empty buffer.
...@@ -83,6 +94,7 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) { ...@@ -83,6 +94,7 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) {
int64_t start = i * frames_per_chunk; int64_t start = i * frames_per_chunk;
// chunk = frame[i*frames_per_chunk:(i+1) * frames_per_chunk] // chunk = frame[i*frames_per_chunk:(i+1) * frames_per_chunk]
auto chunk = frame.index({Slice(start, start + frames_per_chunk)}); auto chunk = frame.index({Slice(start, start + frames_per_chunk)});
double pts_val = pts_ + double(start) * frame_duration;
int64_t chunk_size = chunk.size(0); int64_t chunk_size = chunk.size(0);
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
chunk_size <= frames_per_chunk, chunk_size <= frames_per_chunk,
...@@ -95,6 +107,7 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) { ...@@ -95,6 +107,7 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) {
chunk = temp; chunk = temp;
} }
chunks.push_back(chunk); chunks.push_back(chunk);
pts.push_back(pts_val);
num_buffered_frames += chunk_size; num_buffered_frames += chunk_size;
// Trim if num_chunks > 0 // Trim if num_chunks > 0
...@@ -109,26 +122,28 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) { ...@@ -109,26 +122,28 @@ void ChunkedBuffer::push_tensor(torch::Tensor frame) {
} }
} }
c10::optional<torch::Tensor> ChunkedBuffer::pop_chunk() { c10::optional<Chunk> ChunkedBuffer::pop_chunk() {
using namespace torch::indexing; using namespace torch::indexing;
if (!num_buffered_frames) { if (!num_buffered_frames) {
return {}; return {};
} }
torch::Tensor ret = chunks.front(); torch::Tensor chunk = chunks.front();
double pts_val = pts.front();
chunks.pop_front(); chunks.pop_front();
pts.pop_front();
if (num_buffered_frames < frames_per_chunk) { if (num_buffered_frames < frames_per_chunk) {
ret = ret.index({Slice(None, num_buffered_frames)}); chunk = chunk.index({Slice(None, num_buffered_frames)});
} }
num_buffered_frames -= ret.size(0); num_buffered_frames -= chunk.size(0);
return c10::optional<torch::Tensor>{ret}; return {Chunk{chunk, pts_val}};
} }
void ChunkedAudioBuffer::push_frame(AVFrame* frame) { void ChunkedAudioBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_audio(frame)); push_tensor(convert_audio(frame), pts_);
} }
void ChunkedVideoBuffer::push_frame(AVFrame* frame) { void ChunkedVideoBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_image(frame, device)); push_tensor(convert_image(frame, device), pts_);
} }
void ChunkedBuffer::flush() { void ChunkedBuffer::flush() {
......
...@@ -13,6 +13,10 @@ namespace detail { ...@@ -13,6 +13,10 @@ namespace detail {
class ChunkedBuffer : public Buffer { class ChunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here. // Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks; std::deque<torch::Tensor> chunks;
// Time stamps corresponding the first frame of each chunk
std::deque<double> pts;
// Duration of one frame, used to recalculate the PTS of audio samples
double frame_duration;
// The number of frames to return as a chunk // The number of frames to return as a chunk
// If <0, then user wants to receive all the frames // If <0, then user wants to receive all the frames
...@@ -25,21 +29,24 @@ class ChunkedBuffer : public Buffer { ...@@ -25,21 +29,24 @@ class ChunkedBuffer : public Buffer {
int64_t num_buffered_frames = 0; int64_t num_buffered_frames = 0;
protected: protected:
ChunkedBuffer(int frames_per_chunk, int num_chunks); ChunkedBuffer(int frames_per_chunk, int num_chunks, double frame_duration);
void push_tensor(torch::Tensor frame); void push_tensor(torch::Tensor frame, double pts);
public: public:
bool is_ready() const override; bool is_ready() const override;
void flush() override; void flush() override;
c10::optional<torch::Tensor> pop_chunk() override; c10::optional<Chunk> pop_chunk() override;
}; };
class ChunkedAudioBuffer : public ChunkedBuffer { class ChunkedAudioBuffer : public ChunkedBuffer {
public: public:
ChunkedAudioBuffer(int frames_per_chunk, int num_chunks); ChunkedAudioBuffer(
int frames_per_chunk,
int num_chunks,
double frame_duration);
void push_frame(AVFrame* frame) override; void push_frame(AVFrame* frame, double pts) override;
}; };
class ChunkedVideoBuffer : public ChunkedBuffer { class ChunkedVideoBuffer : public ChunkedBuffer {
...@@ -49,9 +56,10 @@ class ChunkedVideoBuffer : public ChunkedBuffer { ...@@ -49,9 +56,10 @@ class ChunkedVideoBuffer : public ChunkedBuffer {
ChunkedVideoBuffer( ChunkedVideoBuffer(
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
double frame_duration,
const torch::Device& device); const torch::Device& device);
void push_frame(AVFrame* frame) override; void push_frame(AVFrame* frame, double pts) override;
}; };
} // namespace detail } // namespace detail
......
...@@ -12,27 +12,30 @@ bool UnchunkedBuffer::is_ready() const { ...@@ -12,27 +12,30 @@ bool UnchunkedBuffer::is_ready() const {
return chunks.size() > 0; return chunks.size() > 0;
} }
void UnchunkedBuffer::push_tensor(const torch::Tensor& t) { void UnchunkedBuffer::push_tensor(const torch::Tensor& t, double pts_) {
if (chunks.size() == 0) {
pts = pts_;
}
chunks.push_back(t); chunks.push_back(t);
} }
void UnchunkedAudioBuffer::push_frame(AVFrame* frame) { void UnchunkedAudioBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_audio(frame)); push_tensor(convert_audio(frame), pts_);
} }
void UnchunkedVideoBuffer::push_frame(AVFrame* frame) { void UnchunkedVideoBuffer::push_frame(AVFrame* frame, double pts_) {
push_tensor(convert_image(frame, device)); push_tensor(convert_image(frame, device), pts_);
} }
c10::optional<torch::Tensor> UnchunkedBuffer::pop_chunk() { c10::optional<Chunk> UnchunkedBuffer::pop_chunk() {
if (chunks.size() == 0) { if (chunks.size() == 0) {
return {}; return {};
} }
auto ret = auto frames =
torch::cat(std::vector<torch::Tensor>{chunks.begin(), chunks.end()}, 0); torch::cat(std::vector<torch::Tensor>{chunks.begin(), chunks.end()}, 0);
chunks.clear(); chunks.clear();
return {ret}; return {Chunk{frames, pts}};
} }
void UnchunkedBuffer::flush() { void UnchunkedBuffer::flush() {
......
...@@ -16,19 +16,20 @@ namespace detail { ...@@ -16,19 +16,20 @@ namespace detail {
class UnchunkedBuffer : public Buffer { class UnchunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here. // Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks; std::deque<torch::Tensor> chunks;
double pts = -1.;
protected: protected:
void push_tensor(const torch::Tensor& t); void push_tensor(const torch::Tensor& t, double pts);
public: public:
bool is_ready() const override; bool is_ready() const override;
c10::optional<torch::Tensor> pop_chunk() override; c10::optional<Chunk> pop_chunk() override;
void flush() override; void flush() override;
}; };
class UnchunkedAudioBuffer : public UnchunkedBuffer { class UnchunkedAudioBuffer : public UnchunkedBuffer {
public: public:
void push_frame(AVFrame* frame) override; void push_frame(AVFrame* frame, double pts) override;
}; };
class UnchunkedVideoBuffer : public UnchunkedBuffer { class UnchunkedVideoBuffer : public UnchunkedBuffer {
...@@ -37,7 +38,7 @@ class UnchunkedVideoBuffer : public UnchunkedBuffer { ...@@ -37,7 +38,7 @@ class UnchunkedVideoBuffer : public UnchunkedBuffer {
public: public:
explicit UnchunkedVideoBuffer(const torch::Device& device); explicit UnchunkedVideoBuffer(const torch::Device& device);
void push_frame(AVFrame* frame) override; void push_frame(AVFrame* frame, double pts) override;
}; };
} // namespace detail } // namespace detail
......
...@@ -11,6 +11,7 @@ std::unique_ptr<Buffer> get_buffer( ...@@ -11,6 +11,7 @@ std::unique_ptr<Buffer> get_buffer(
AVMediaType type, AVMediaType type,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
double frame_duration,
const torch::Device& device) { const torch::Device& device) {
TORCH_CHECK( TORCH_CHECK(
frames_per_chunk > 0 || frames_per_chunk == -1, frames_per_chunk > 0 || frames_per_chunk == -1,
...@@ -31,11 +32,11 @@ std::unique_ptr<Buffer> get_buffer( ...@@ -31,11 +32,11 @@ std::unique_ptr<Buffer> get_buffer(
// Chunked Mode // Chunked Mode
if (frames_per_chunk > 0) { if (frames_per_chunk > 0) {
if (type == AVMEDIA_TYPE_AUDIO) { if (type == AVMEDIA_TYPE_AUDIO) {
return std::unique_ptr<Buffer>( return std::unique_ptr<Buffer>(new detail::ChunkedAudioBuffer(
new detail::ChunkedAudioBuffer(frames_per_chunk, num_chunks)); frames_per_chunk, num_chunks, frame_duration));
} else { } else {
return std::unique_ptr<Buffer>( return std::unique_ptr<Buffer>(new detail::ChunkedVideoBuffer(
new detail::ChunkedVideoBuffer(frames_per_chunk, num_chunks, device)); frames_per_chunk, num_chunks, frame_duration, device));
} }
} else { // unchunked mode } else { // unchunked mode
if (type == AVMEDIA_TYPE_AUDIO) { if (type == AVMEDIA_TYPE_AUDIO) {
...@@ -91,10 +92,12 @@ Sink::Sink( ...@@ -91,10 +92,12 @@ Sink::Sink(
filter_description(filter_description_.value_or( filter_description(filter_description_.value_or(
codecpar->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")), codecpar->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")),
filter(get_filter_graph(input_time_base_, codecpar_, filter_description)), filter(get_filter_graph(input_time_base_, codecpar_, filter_description)),
output_time_base(filter->get_output_timebase()),
buffer(get_buffer( buffer(get_buffer(
codecpar_->codec_type, codecpar_->codec_type,
frames_per_chunk, frames_per_chunk,
num_chunks, num_chunks,
double(output_time_base.num) / output_time_base.den,
device)) {} device)) {}
// 0: some kind of success // 0: some kind of success
...@@ -109,7 +112,9 @@ int Sink::process_frame(AVFrame* pFrame) { ...@@ -109,7 +112,9 @@ int Sink::process_frame(AVFrame* pFrame) {
return 0; return 0;
} }
if (ret >= 0) { if (ret >= 0) {
buffer->push_frame(frame); double pts =
double(frame->pts * output_time_base.num) / output_time_base.den;
buffer->push_frame(frame, pts);
} }
av_frame_unref(frame); av_frame_unref(frame);
} }
......
...@@ -15,6 +15,8 @@ class Sink { ...@@ -15,6 +15,8 @@ class Sink {
AVCodecParameters* codecpar; AVCodecParameters* codecpar;
std::string filter_description; std::string filter_description;
std::unique_ptr<FilterGraph> filter; std::unique_ptr<FilterGraph> filter;
// time_base of filter graph output, used for PTS calc
AVRational output_time_base;
public: public:
std::unique_ptr<Buffer> buffer; std::unique_ptr<Buffer> buffer;
......
...@@ -151,7 +151,7 @@ int StreamProcessor::send_frame(AVFrame* pFrame) { ...@@ -151,7 +151,7 @@ int StreamProcessor::send_frame(AVFrame* pFrame) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Retrieval // Retrieval
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
c10::optional<torch::Tensor> StreamProcessor::pop_chunk(KeyType key) { c10::optional<Chunk> StreamProcessor::pop_chunk(KeyType key) {
return sinks.at(key).buffer->pop_chunk(); return sinks.at(key).buffer->pop_chunk();
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <torchaudio/csrc/ffmpeg/ffmpeg.h> #include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/decoder.h> #include <torchaudio/csrc/ffmpeg/stream_reader/decoder.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/sink.h> #include <torchaudio/csrc/ffmpeg/stream_reader/sink.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
#include <map> #include <map>
namespace torchaudio { namespace torchaudio {
...@@ -95,7 +96,7 @@ class StreamProcessor { ...@@ -95,7 +96,7 @@ class StreamProcessor {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
public: public:
// Get the chunk from the given filter result // Get the chunk from the given filter result
c10::optional<torch::Tensor> pop_chunk(KeyType key); c10::optional<Chunk> pop_chunk(KeyType key);
}; };
} // namespace ffmpeg } // namespace ffmpeg
......
...@@ -385,8 +385,21 @@ int StreamReader::drain() { ...@@ -385,8 +385,21 @@ int StreamReader::drain() {
std::vector<c10::optional<torch::Tensor>> StreamReader::pop_chunks() { std::vector<c10::optional<torch::Tensor>> StreamReader::pop_chunks() {
std::vector<c10::optional<torch::Tensor>> ret; std::vector<c10::optional<torch::Tensor>> ret;
for (auto& c : pop_chunks_with_metadata()) {
if (c) {
ret.emplace_back(c->frames);
} else {
ret.emplace_back();
}
}
return ret;
}
std::vector<c10::optional<Chunk>> StreamReader::pop_chunks_with_metadata() {
std::vector<c10::optional<Chunk>> ret;
ret.reserve(num_out_streams());
for (auto& i : stream_indices) { for (auto& i : stream_indices) {
ret.push_back(processors[i.first]->pop_chunk(i.second)); ret.emplace_back(processors[i.first]->pop_chunk(i.second));
} }
return ret; return ret;
} }
......
...@@ -263,6 +263,10 @@ class StreamReader { ...@@ -263,6 +263,10 @@ class StreamReader {
std::vector<c10::optional<torch::Tensor>> pop_chunks(); std::vector<c10::optional<torch::Tensor>> pop_chunks();
///@} ///@}
/// Pop one chunk from each output stream if it is available.
/// TODO: merge this to pop_chunks
std::vector<c10::optional<Chunk>> pop_chunks_with_metadata();
}; };
} // namespace ffmpeg } // namespace ffmpeg
......
...@@ -103,5 +103,18 @@ int64_t StreamReaderBinding::fill_buffer( ...@@ -103,5 +103,18 @@ int64_t StreamReaderBinding::fill_buffer(
return 0; return 0;
} }
std::vector<c10::optional<ChunkData>> StreamReaderBinding::pop_chunks() {
std::vector<c10::optional<ChunkData>> ret;
ret.reserve(static_cast<size_t>(num_out_streams()));
for (auto& c : StreamReader::pop_chunks_with_metadata()) {
if (c) {
ret.emplace_back(std::forward_as_tuple(c->frames, c->pts));
} else {
ret.emplace_back();
}
}
return ret;
}
} // namespace ffmpeg } // namespace ffmpeg
} // namespace torchaudio } // namespace torchaudio
...@@ -61,6 +61,8 @@ using OutInfo = std::tuple< ...@@ -61,6 +61,8 @@ using OutInfo = std::tuple<
std::string // filter description std::string // filter description
>; >;
using ChunkData = std::tuple<torch::Tensor, double>;
// Structure to implement wrapper API around StreamReader, which is more // Structure to implement wrapper API around StreamReader, which is more
// suitable for Binding the code (i.e. it receives/returns pritimitves) // suitable for Binding the code (i.e. it receives/returns pritimitves)
struct StreamReaderBinding : public StreamReader, struct StreamReaderBinding : public StreamReader,
...@@ -78,6 +80,8 @@ struct StreamReaderBinding : public StreamReader, ...@@ -78,6 +80,8 @@ struct StreamReaderBinding : public StreamReader,
int64_t fill_buffer( int64_t fill_buffer(
const c10::optional<double>& timeout = c10::optional<double>(), const c10::optional<double>& timeout = c10::optional<double>(),
const double backoff = 10.); const double backoff = 10.);
std::vector<c10::optional<ChunkData>> pop_chunks();
}; };
} // namespace ffmpeg } // namespace ffmpeg
......
...@@ -104,5 +104,20 @@ struct OutputStreamInfo { ...@@ -104,5 +104,20 @@ struct OutputStreamInfo {
std::string filter_description; std::string filter_description;
}; };
/// Stores decoded frames and metadata
struct Chunk {
/// Audio/video frames.
///
/// For audio, the shape is ``[time, num_channels]``, and the ``dtype``
/// depends on output stream configurations.
///
/// For video, the shape is ``[time, channel, height, width]``, and
/// the ``dtype`` is ``torch.uint8``.
torch::Tensor frames;
///
/// Presentation time stamp of the first frame, in second.
double pts;
};
} // namespace ffmpeg } // namespace ffmpeg
} // namespace torchaudio } // namespace torchaudio
...@@ -2,10 +2,6 @@ import torchaudio ...@@ -2,10 +2,6 @@ import torchaudio
_STREAM_READER = [ _STREAM_READER = [
"StreamReader", "StreamReader",
"StreamReaderSourceStream",
"StreamReaderSourceAudioStream",
"StreamReaderSourceVideoStream",
"StreamReaderOutputStream",
] ]
_STREAM_WRITER = [ _STREAM_WRITER = [
......
...@@ -83,10 +83,11 @@ def _load_audio( ...@@ -83,10 +83,11 @@ def _load_audio(
option: Dict[str, str] = {} option: Dict[str, str] = {}
s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option) s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option)
s.process_all_packets() s.process_all_packets()
waveform = s.pop_chunks()[0] chunk = s.pop_chunks()[0]
if waveform is None: if chunk is None:
raise RuntimeError("Failed to decode audio.") raise RuntimeError("Failed to decode audio.")
assert waveform is not None assert chunk is not None
waveform = chunk[0]
if channels_first: if channels_first:
waveform = waveform.T waveform = waveform.T
return waveform, sample_rate return waveform, sample_rate
......
...@@ -5,16 +5,21 @@ from typing import BinaryIO, Dict, Iterator, Optional, Tuple, Union ...@@ -5,16 +5,21 @@ from typing import BinaryIO, Dict, Iterator, Optional, Tuple, Union
import torch import torch
import torchaudio import torchaudio
from torch.utils._pytree import tree_map
__all__ = [
"StreamReader",
]
@dataclass @dataclass
class StreamReaderSourceStream: class SourceStream:
"""The metadata of a source stream, returned by :meth:`~torchaudio.io.StreamReader.get_src_stream_info`. """The metadata of a source stream, returned by :meth:`~torchaudio.io.StreamReader.get_src_stream_info`.
This class is used when representing streams of media type other than `audio` or `video`. This class is used when representing streams of media type other than `audio` or `video`.
When source stream is `audio` or `video` type, :class:`StreamReaderSourceAudioStream` and When source stream is `audio` or `video` type, :class:`SourceAudioStream` and
:class:`StreamReaderSourceVideoStream`, which reports additional media-specific attributes, :class:`SourceVideoStream`, which reports additional media-specific attributes,
are used respectively. are used respectively.
""" """
...@@ -65,12 +70,12 @@ class StreamReaderSourceStream: ...@@ -65,12 +70,12 @@ class StreamReaderSourceStream:
@dataclass @dataclass
class StreamReaderSourceAudioStream(StreamReaderSourceStream): class SourceAudioStream(SourceStream):
"""The metadata of an audio source stream, returned by :meth:`~torchaudio.io.StreamReader.get_src_stream_info`. """The metadata of an audio source stream, returned by :meth:`~torchaudio.io.StreamReader.get_src_stream_info`.
This class is used when representing audio stream. This class is used when representing audio stream.
In addition to the attributes reported by :class:`StreamReaderSourceStream`, In addition to the attributes reported by :class:`SourceStream`,
the following attributes are reported. the following attributes are reported.
""" """
...@@ -81,12 +86,12 @@ class StreamReaderSourceAudioStream(StreamReaderSourceStream): ...@@ -81,12 +86,12 @@ class StreamReaderSourceAudioStream(StreamReaderSourceStream):
@dataclass @dataclass
class StreamReaderSourceVideoStream(StreamReaderSourceStream): class SourceVideoStream(SourceStream):
"""The metadata of a video source stream, returned by :meth:`~torchaudio.io.StreamReader.get_src_stream_info`. """The metadata of a video source stream, returned by :meth:`~torchaudio.io.StreamReader.get_src_stream_info`.
This class is used when representing video stream. This class is used when representing video stream.
In addition to the attributes reported by :class:`StreamReaderSourceStream`, In addition to the attributes reported by :class:`SourceStream`,
the following attributes are reported. the following attributes are reported.
""" """
...@@ -127,7 +132,7 @@ def _parse_si(i): ...@@ -127,7 +132,7 @@ def _parse_si(i):
bps = i[_BPS] bps = i[_BPS]
metadata = i[_METADATA] metadata = i[_METADATA]
if media_type == "audio": if media_type == "audio":
return StreamReaderSourceAudioStream( return SourceAudioStream(
media_type=media_type, media_type=media_type,
codec=codec_name, codec=codec_name,
codec_long_name=codec_long_name, codec_long_name=codec_long_name,
...@@ -140,7 +145,7 @@ def _parse_si(i): ...@@ -140,7 +145,7 @@ def _parse_si(i):
num_channels=i[_NUM_CHANNELS], num_channels=i[_NUM_CHANNELS],
) )
if media_type == "video": if media_type == "video":
return StreamReaderSourceVideoStream( return SourceVideoStream(
media_type=media_type, media_type=media_type,
codec=codec_name, codec=codec_name,
codec_long_name=codec_long_name, codec_long_name=codec_long_name,
...@@ -153,7 +158,7 @@ def _parse_si(i): ...@@ -153,7 +158,7 @@ def _parse_si(i):
height=i[_HEIGHT], height=i[_HEIGHT],
frame_rate=i[_FRAME_RATE], frame_rate=i[_FRAME_RATE],
) )
return StreamReaderSourceStream( return SourceStream(
media_type=media_type, media_type=media_type,
codec=codec_name, codec=codec_name,
codec_long_name=codec_long_name, codec_long_name=codec_long_name,
...@@ -166,7 +171,7 @@ def _parse_si(i): ...@@ -166,7 +171,7 @@ def _parse_si(i):
@dataclass @dataclass
class StreamReaderOutputStream: class OutputStream:
"""Output stream configured on :class:`StreamReader`, """Output stream configured on :class:`StreamReader`,
returned by :meth:`~torchaudio.io.StreamReader.get_out_stream_info`. returned by :meth:`~torchaudio.io.StreamReader.get_out_stream_info`.
""" """
...@@ -178,7 +183,7 @@ class StreamReaderOutputStream: ...@@ -178,7 +183,7 @@ class StreamReaderOutputStream:
def _parse_oi(i): def _parse_oi(i):
return StreamReaderOutputStream(i[0], i[1]) return OutputStream(i[0], i[1])
def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]): def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]):
...@@ -206,6 +211,78 @@ def _get_vfilter_desc(frame_rate: Optional[float], width: Optional[int], height: ...@@ -206,6 +211,78 @@ def _get_vfilter_desc(frame_rate: Optional[float], width: Optional[int], height:
return ",".join(descs) if descs else None return ",".join(descs) if descs else None
# Base class for ChunkTensor
# Based off of TrivialTensorViaComposition
# https://github.com/albanD/subclass_zoo/blob/0eeb1d68fb59879029c610bc407f2997ae43ba0a/trivial_tensors.py#L83
class ChunkTensorBase(torch.Tensor):
__torch_function__ = torch._C._disabled_torch_function_impl
@staticmethod
def __new__(cls, _elem, *_):
return super().__new__(cls, _elem)
@classmethod
def __torch_dispatch__(cls, func, _, args=(), kwargs=None):
def unwrap(t):
return t._elem if isinstance(t, cls) else t
return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
@dataclass
class ChunkTensor(ChunkTensorBase):
"""Decoded media frames with metadata.
The instance of this class represents the decoded video/audio frames with
metadata, and the instance itself behave like :py:class:`~torch.Tensor`.
Client codes can pass instance of this class as-if it's
:py:class:`~torch.Tensor` class, or call the methods defined on
:py:class:`~torch.Tensor` class.
Example:
>>> # Define input streams
>>> reader = StreamReader(...)
>>> reader.add_audio_stream(frames_per_chunk=4000, sample_rate=8000)
>>> reader.add_video_stream(frames_per_chunk=7, frame_rate=28)
>>> # Decode the streams and fetch frames
>>> reader.fill_buffer()
>>> audio_chunk, video_chunk = reader.pop_chunks()
>>> # Access metadata
>>> (audio_chunk.pts, video_chunks.pts)
(0.0, 0.0)
>>>
>>> # The second time the PTS is different
>>> reader.fill_buffer()
>>> audio_chunk, video_chunk = reader.pop_chunks()
>>> (audio_chunk.pts, video_chunks.pts)
(0.5, 0.25)
>>> # Call PyTorch ops on chunk
>>> audio_chunk.shape
torch.Size([4000, 2]
>>> power = torch.pow(video_chunk, 2)
>>>
>>> # the result is a plain torch.Tensor class
>>> type(power)
<class 'torch.Tensor'>
>>>
>>> # Metadata is not available on the result
>>> power.pts
AttributeError: 'Tensor' object has no attribute 'pts'
"""
# Keep it private for now
_elem: torch.Tensor
pts: float
"""Presentation time stamp of the first frame in the chunk.
Unit: second.
"""
def _format_doc(**kwargs): def _format_doc(**kwargs):
def decorator(obj): def decorator(obj):
obj.__doc__ = obj.__doc__.format(**kwargs) obj.__doc__ = obj.__doc__.format(**kwargs)
...@@ -223,8 +300,8 @@ _frames_per_chunk = """Number of frames returned as one chunk. ...@@ -223,8 +300,8 @@ _frames_per_chunk = """Number of frames returned as one chunk.
_buffer_chunk_size = """Internal buffer size. _buffer_chunk_size = """Internal buffer size.
When the number of chunks buffered exceeds this number, old frames are When the number of chunks buffered exceeds this number, old frames are
dropped. For example, if `frames_per_chunk` is 5 and `buffer_chunk_size` is dropped. For example, if ``frames_per_chunk`` is 5 and ``buffer_chunk_size`` is
3, then frames older than 15 are dropped. 3, then frames older than ``15`` are dropped.
Providing ``-1`` disables this behavior. Providing ``-1`` disables this behavior.
Default: ``3``.""" Default: ``3``."""
...@@ -249,22 +326,28 @@ _decoder_option = """Options passed to decoder. ...@@ -249,22 +326,28 @@ _decoder_option = """Options passed to decoder.
Mapping from str to str. (Default: ``None``) Mapping from str to str. (Default: ``None``)
To list decoder options for a decoder, you can use To list decoder options for a decoder, you can use
`ffmpeg -h decoder=<DECODER>` command. ``ffmpeg -h decoder=<DECODER>`` command.
|
In addition to decoder-specific options, you can also pass options related In addition to decoder-specific options, you can also pass options related
to multithreading. They are effective only if the decoder support them. to multithreading. They are effective only if the decoder support them.
If neither of them are provided, StreamReader defaults to single thread. If neither of them are provided, StreamReader defaults to single thread.
- ``"threads"``: The number of threads (in str) or the value ``"0"`` ``"threads"``: The number of threads (in str).
to let FFmpeg decides based on its heuristics. Providing the value ``"0"`` will let FFmpeg decides based on its heuristics.
- ``"thread_type"``: Which multithreading method to use.
The valid values are ``"frame"`` or ``"slice"``. ``"thread_type"``: Which multithreading method to use.
Note that sach decoder supports different set of methods. The valid values are ``"frame"`` or ``"slice"``.
If not provided, a default value is used. Note that each decoder supports different set of methods.
- ``"frame"``: Decode more than one frame at once. If not provided, a default value is used.
Each thread handles one frame.
This will increase decoding delay by one frame per thread - ``"frame"``: Decode more than one frame at once.
- ``"slice"``: Decode more than one part of a single frame at once. Each thread handles one frame.
This will increase decoding delay by one frame per thread
- ``"slice"``: Decode more than one part of a single frame at once.
|
""" """
...@@ -433,7 +516,7 @@ class StreamReader: ...@@ -433,7 +516,7 @@ class StreamReader:
""" """
return self._be.get_metadata() return self._be.get_metadata()
def get_src_stream_info(self, i: int) -> torchaudio.io.StreamReaderSourceStream: def get_src_stream_info(self, i: int) -> SourceStream:
"""Get the metadata of source stream """Get the metadata of source stream
Args: Args:
...@@ -443,7 +526,7 @@ class StreamReader: ...@@ -443,7 +526,7 @@ class StreamReader:
""" """
return _parse_si(self._be.get_src_stream_info(i)) return _parse_si(self._be.get_src_stream_info(i))
def get_out_stream_info(self, i: int) -> torchaudio.io.StreamReaderOutputStream: def get_out_stream_info(self, i: int) -> torchaudio.io.OutputStream:
"""Get the metadata of output stream """Get the metadata of output stream
Args: Args:
...@@ -748,15 +831,21 @@ class StreamReader: ...@@ -748,15 +831,21 @@ class StreamReader:
"""Returns true if all the output streams have at least one chunk filled.""" """Returns true if all the output streams have at least one chunk filled."""
return self._be.is_buffer_ready() return self._be.is_buffer_ready()
def pop_chunks(self) -> Tuple[Optional[torch.Tensor]]: def pop_chunks(self) -> Tuple[Optional[ChunkTensor]]:
"""Pop one chunk from all the output stream buffers. """Pop one chunk from all the output stream buffers.
Returns: Returns:
Tuple[Optional[Tensor]]: Tuple[Optional[ChunkTensor]]:
Buffer contents. Buffer contents.
If a buffer does not contain any frame, then `None` is returned instead. If a buffer does not contain any frame, then `None` is returned instead.
""" """
return self._be.pop_chunks() ret = []
for chunk in self._be.pop_chunks():
if chunk is None:
ret.append(None)
else:
ret.append(ChunkTensor(chunk[0], chunk[1]))
return ret
def fill_buffer(self, timeout: Optional[float] = None, backoff: float = 10.0) -> int: def fill_buffer(self, timeout: Optional[float] = None, backoff: float = 10.0) -> int:
"""Keep processing packets until all buffers have at least one chunk """Keep processing packets until all buffers have at least one chunk
...@@ -783,7 +872,7 @@ class StreamReader: ...@@ -783,7 +872,7 @@ class StreamReader:
def stream( def stream(
self, timeout: Optional[float] = None, backoff: float = 10.0 self, timeout: Optional[float] = None, backoff: float = 10.0
) -> Iterator[Tuple[Optional[torch.Tensor], ...]]: ) -> Iterator[Tuple[Optional[ChunkTensor], ...]]:
"""Return an iterator that generates output tensors """Return an iterator that generates output tensors
Arguments: Arguments:
...@@ -794,7 +883,7 @@ class StreamReader: ...@@ -794,7 +883,7 @@ class StreamReader:
:py:func:`~StreamReader.process_packet`. (Default: ``10.0``) :py:func:`~StreamReader.process_packet`. (Default: ``10.0``)
Returns: Returns:
Iterator[Tuple[Optional[torch.Tensor], ...]]: Iterator[Tuple[Optional[ChunkTensor], ...]]:
Iterator that yields a tuple of chunks that correspond to the output Iterator that yields a tuple of chunks that correspond to the output
streams defined by client code. streams defined by client code.
If an output stream is exhausted, then the chunk Tensor is substituted If an output stream is exhausted, then the chunk Tensor is substituted
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment