Commit 146195d8 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Include format information after filter (#3155)

Summary:
This commit adds fields to OutputStream, which shows the result
of filters, such as width and height after filtering.

Before

```
OutputStream(
    source_index=0,
    filter_description='fps=3,scale=width=320:height=320,format=pix_fmts=gray')
```

After

```
OutputVideoStream(
    source_index=0,
    filter_description='fps=3,scale=width=320:height=320,format=pix_fmts=gray',
    media_type='video',
    format='gray',
    width=320,
    height=320,
    frame_rate=3.0)
```

Pull Request resolved: https://github.com/pytorch/audio/pull/3155

Reviewed By: nateanl

Differential Revision: D43882399

Pulled By: mthrok

fbshipit-source-id: 620676b1a06f293fdd56de8203a11120f228fa2d
parent 8d2f6f8d
...@@ -61,6 +61,8 @@ Support Structures ...@@ -61,6 +61,8 @@ Support Structures
"SourceAudioStream", "SourceAudioStream",
"SourceVideoStream", "SourceVideoStream",
"OutputStream", "OutputStream",
"OutputAudioStream",
"OutputVideoStream",
] %} ] %}
{{ item | underline("~") }} {{ item | underline("~") }}
......
...@@ -20,7 +20,14 @@ from torchaudio_unittest.common_utils import ( ...@@ -20,7 +20,14 @@ from torchaudio_unittest.common_utils import (
if is_ffmpeg_available(): if is_ffmpeg_available():
from torchaudio.io import StreamReader, StreamWriter from torchaudio.io import StreamReader, StreamWriter
from torchaudio.io._stream_reader import ChunkTensor, SourceAudioStream, SourceStream, SourceVideoStream from torchaudio.io._stream_reader import (
ChunkTensor,
OutputAudioStream,
OutputVideoStream,
SourceAudioStream,
SourceStream,
SourceVideoStream,
)
@skipIfNoFFmpeg @skipIfNoFFmpeg
...@@ -238,6 +245,81 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -238,6 +245,81 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
output = [s.get_src_stream_info(i) for i in range(6)] output = [s.get_src_stream_info(i) for i in range(6)]
assert expected == output assert expected == output
def test_output_info(self):
    """get_out_stream_info reports the post-filter format of each output stream"""
    s = StreamReader(self.get_src())
    # Three audio outputs, all attached to source stream 4.
    # No filter_desc defaults to the "anull" passthrough (see expected below).
    s.add_audio_stream(-1)
    s.add_audio_stream(-1, filter_desc="aresample=8000")
    s.add_audio_stream(-1, filter_desc="aformat=sample_fmts=s16p")
    # Four video outputs, all attached to source stream 3.
    # No filter_desc defaults to the "null" passthrough (see expected below).
    s.add_video_stream(-1)
    s.add_video_stream(-1, filter_desc="fps=10")
    s.add_video_stream(-1, filter_desc="format=rgb24")
    s.add_video_stream(-1, filter_desc="scale=w=160:h=90")
    expected = [
        # Passthrough: reports the source audio format unchanged
        # (fltp, 16 kHz, stereo — per the test asset; confirm against get_src()).
        OutputAudioStream(
            source_index=4,
            filter_description="anull",
            media_type="audio",
            format="fltp",
            sample_rate=16000.0,
            num_channels=2,
        ),
        # aresample changes only the sample rate.
        OutputAudioStream(
            source_index=4,
            filter_description="aresample=8000",
            media_type="audio",
            format="fltp",
            sample_rate=8000.0,
            num_channels=2,
        ),
        # aformat changes only the sample format.
        OutputAudioStream(
            source_index=4,
            filter_description="aformat=sample_fmts=s16p",
            media_type="audio",
            format="s16p",
            sample_rate=16000.0,
            num_channels=2,
        ),
        # Passthrough: reports the source video format unchanged
        # (480x270 yuv420p at NTSC 29.97 fps — per the test asset).
        OutputVideoStream(
            source_index=3,
            filter_description="null",
            media_type="video",
            format="yuv420p",
            width=480,
            height=270,
            frame_rate=30000 / 1001,
        ),
        # fps changes only the frame rate.
        OutputVideoStream(
            source_index=3,
            filter_description="fps=10",
            media_type="video",
            format="yuv420p",
            width=480,
            height=270,
            frame_rate=10,
        ),
        # format changes only the pixel format.
        OutputVideoStream(
            source_index=3,
            filter_description="format=rgb24",
            media_type="video",
            format="rgb24",
            width=480,
            height=270,
            frame_rate=30000 / 1001,
        ),
        # scale changes only the frame dimensions.
        OutputVideoStream(
            source_index=3,
            filter_description="scale=w=160:h=90",
            media_type="video",
            format="yuv420p",
            width=160,
            height=90,
            frame_rate=30000 / 1001,
        ),
    ]
    output = [s.get_out_stream_info(i) for i in range(s.num_out_streams)]
    assert expected == output
def test_id3tag(self): def test_id3tag(self):
"""get_metadata method can fetch id3tag properly""" """get_metadata method can fetch id3tag properly"""
s = StreamReader(self.get_src("steam-train-whistle-daniel_simon.mp3")) s = StreamReader(self.get_src("steam-train-whistle-daniel_simon.mp3"))
......
...@@ -41,6 +41,7 @@ std::string get_audio_src_args( ...@@ -41,6 +41,7 @@ std::string get_audio_src_args(
std::string get_video_src_args( std::string get_video_src_args(
AVPixelFormat format, AVPixelFormat format,
AVRational time_base, AVRational time_base,
AVRational frame_rate,
int width, int width,
int height, int height,
AVRational sample_aspect_ratio) { AVRational sample_aspect_ratio) {
...@@ -48,12 +49,14 @@ std::string get_video_src_args( ...@@ -48,12 +49,14 @@ std::string get_video_src_args(
std::snprintf( std::snprintf(
args, args,
sizeof(args), sizeof(args),
"video_size=%dx%d:pix_fmt=%s:time_base=%d/%d:pixel_aspect=%d/%d", "video_size=%dx%d:pix_fmt=%s:time_base=%d/%d:frame_rate=%d/%d:pixel_aspect=%d/%d",
width, width,
height, height,
av_get_pix_fmt_name(format), av_get_pix_fmt_name(format),
time_base.num, time_base.num,
time_base.den, time_base.den,
frame_rate.num,
frame_rate.den,
sample_aspect_ratio.num, sample_aspect_ratio.num,
sample_aspect_ratio.den); sample_aspect_ratio.den);
return std::string(args); return std::string(args);
...@@ -76,13 +79,14 @@ void FilterGraph::add_audio_src( ...@@ -76,13 +79,14 @@ void FilterGraph::add_audio_src(
void FilterGraph::add_video_src( void FilterGraph::add_video_src(
AVPixelFormat format, AVPixelFormat format,
AVRational time_base, AVRational time_base,
AVRational frame_rate,
int width, int width,
int height, int height,
AVRational sample_aspect_ratio) { AVRational sample_aspect_ratio) {
TORCH_CHECK( TORCH_CHECK(
media_type == AVMEDIA_TYPE_VIDEO, "The filter graph is not video type."); media_type == AVMEDIA_TYPE_VIDEO, "The filter graph is not video type.");
std::string args = std::string args = get_video_src_args(
get_video_src_args(format, time_base, width, height, sample_aspect_ratio); format, time_base, frame_rate, width, height, sample_aspect_ratio);
add_src(args); add_src(args);
} }
...@@ -164,7 +168,7 @@ void FilterGraph::add_process(const std::string& filter_description) { ...@@ -164,7 +168,7 @@ void FilterGraph::add_process(const std::string& filter_description) {
void FilterGraph::create_filter() { void FilterGraph::create_filter() {
int ret = avfilter_graph_config(pFilterGraph, nullptr); int ret = avfilter_graph_config(pFilterGraph, nullptr);
TORCH_CHECK(ret >= 0, "Failed to configure the graph: " + av_err2string(ret)); TORCH_CHECK(ret >= 0, "Failed to configure the graph: " + av_err2string(ret));
// char* desc = avfilter_graph_dump(pFilterGraph.get(), NULL); // char* desc = avfilter_graph_dump(pFilterGraph, NULL);
// std::cerr << "Filter created:\n" << desc << std::endl; // std::cerr << "Filter created:\n" << desc << std::endl;
// av_free(static_cast<void*>(desc)); // av_free(static_cast<void*>(desc));
} }
...@@ -177,22 +181,26 @@ AVRational FilterGraph::get_output_timebase() const { ...@@ -177,22 +181,26 @@ AVRational FilterGraph::get_output_timebase() const {
return buffersink_ctx->inputs[0]->time_base; return buffersink_ctx->inputs[0]->time_base;
} }
int FilterGraph::get_output_sample_rate() const { FilterGraphOutputInfo FilterGraph::get_output_info() const {
TORCH_INTERNAL_ASSERT(buffersink_ctx, "FilterGraph is not initialized."); TORCH_INTERNAL_ASSERT(buffersink_ctx, "FilterGraph is not initialized.");
return buffersink_ctx->inputs[0]->sample_rate; AVFilterLink* l = buffersink_ctx->inputs[0];
} FilterGraphOutputInfo ret{};
ret.type = l->type;
int FilterGraph::get_output_channels() const { ret.format = l->format;
TORCH_INTERNAL_ASSERT(buffersink_ctx, "FilterGraph is not initialized."); if (l->type == AVMEDIA_TYPE_AUDIO) {
// Since FFmpeg 5.1 ret.sample_rate = l->sample_rate;
// https://github.com/FFmpeg/FFmpeg/blob/release/5.1/doc/APIchanges#L45-L54
#if LIBAVFILTER_VERSION_MAJOR >= 8 && LIBAVFILTER_VERSION_MINOR >= 44 #if LIBAVFILTER_VERSION_MAJOR >= 8 && LIBAVFILTER_VERSION_MINOR >= 44
return buffersink_ctx->inputs[0]->ch_layout.nb_channels; ret.num_channels = l->ch_layout.nb_channels;
#else #else
// Before FFmpeg 5.1 // Before FFmpeg 5.1
return av_get_channel_layout_nb_channels( ret.num_channels = av_get_channel_layout_nb_channels(l->channel_layout);
buffersink_ctx->inputs[0]->channel_layout);
#endif #endif
} else {
ret.frame_rate = l->frame_rate;
ret.height = l->h;
ret.width = l->w;
}
return ret;
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
......
...@@ -4,6 +4,21 @@ ...@@ -4,6 +4,21 @@
namespace torchaudio { namespace torchaudio {
namespace io { namespace io {
/// Used to report the output formats of filter graph.
/// Only the members matching ``type`` are meaningful; the rest keep their
/// sentinel defaults.
struct FilterGraphOutputInfo {
  /// Media type of the filter graph output (audio or video).
  AVMediaType type = AVMEDIA_TYPE_UNKNOWN;
  /// Raw format value. Interpreted as AVSampleFormat when ``type`` is audio,
  /// and as AVPixelFormat when ``type`` is video. -1 means "none".
  int format = -1;
  // Audio
  int sample_rate = -1;
  int num_channels = -1;
  // Video
  AVRational frame_rate = {0, 1};
  int height = -1;
  int width = -1;
};
class FilterGraph { class FilterGraph {
AVMediaType media_type; AVMediaType media_type;
...@@ -37,6 +52,7 @@ class FilterGraph { ...@@ -37,6 +52,7 @@ class FilterGraph {
void add_video_src( void add_video_src(
AVPixelFormat format, AVPixelFormat format,
AVRational time_base, AVRational time_base,
AVRational frame_rate,
int width, int width,
int height, int height,
AVRational sample_aspect_ratio); AVRational sample_aspect_ratio);
...@@ -53,8 +69,7 @@ class FilterGraph { ...@@ -53,8 +69,7 @@ class FilterGraph {
// Query methods // Query methods
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
[[nodiscard]] AVRational get_output_timebase() const; [[nodiscard]] AVRational get_output_timebase() const;
[[nodiscard]] int get_output_sample_rate() const; [[nodiscard]] FilterGraphOutputInfo get_output_info() const;
[[nodiscard]] int get_output_channels() const;
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Streaming process // Streaming process
......
...@@ -57,8 +57,44 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) { ...@@ -57,8 +57,44 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def("close", &StreamWriterFileObj::close); .def("close", &StreamWriterFileObj::close);
py::class_<OutputStreamInfo>(m, "OutputStreamInfo", py::module_local()) py::class_<OutputStreamInfo>(m, "OutputStreamInfo", py::module_local())
.def_readonly("source_index", &OutputStreamInfo::source_index) .def_readonly("source_index", &OutputStreamInfo::source_index)
.def_readonly( .def_readonly("filter_description", &OutputStreamInfo::filter_description)
"filter_description", &OutputStreamInfo::filter_description); .def_property_readonly(
"media_type",
[](const OutputStreamInfo& o) -> std::string {
return av_get_media_type_string(o.media_type);
})
.def_property_readonly(
"format",
[](const OutputStreamInfo& o) -> std::string {
switch (o.media_type) {
case AVMEDIA_TYPE_AUDIO:
return av_get_sample_fmt_name((AVSampleFormat)(o.format));
case AVMEDIA_TYPE_VIDEO:
return av_get_pix_fmt_name((AVPixelFormat)(o.format));
default:
TORCH_INTERNAL_ASSERT(
false,
"FilterGraph is returning unexpected media type: ",
av_get_media_type_string(o.media_type));
}
})
.def_readonly("sample_rate", &OutputStreamInfo::sample_rate)
.def_readonly("num_channels", &OutputStreamInfo::num_channels)
.def_readonly("width", &OutputStreamInfo::width)
.def_readonly("height", &OutputStreamInfo::height)
.def_property_readonly(
"frame_rate", [](const OutputStreamInfo& o) -> double {
if (o.frame_rate.den == 0) {
TORCH_WARN(
o.frame_rate.den,
"Invalid frame rate is found: ",
o.frame_rate.num,
"/",
o.frame_rate.den);
return -1;
}
return static_cast<double>(o.frame_rate.num) / o.frame_rate.den;
});
py::class_<SrcStreamInfo>(m, "SourceStreamInfo", py::module_local()) py::class_<SrcStreamInfo>(m, "SourceStreamInfo", py::module_local())
.def_property_readonly( .def_property_readonly(
"media_type", "media_type",
......
...@@ -50,6 +50,7 @@ std::unique_ptr<Buffer> get_buffer( ...@@ -50,6 +50,7 @@ std::unique_ptr<Buffer> get_buffer(
std::unique_ptr<FilterGraph> get_filter_graph( std::unique_ptr<FilterGraph> get_filter_graph(
AVRational input_time_base, AVRational input_time_base,
AVCodecParameters* codecpar, AVCodecParameters* codecpar,
AVRational frame_rate,
const std::string& filter_description) { const std::string& filter_description) {
auto p = std::make_unique<FilterGraph>(codecpar->codec_type); auto p = std::make_unique<FilterGraph>(codecpar->codec_type);
...@@ -65,6 +66,7 @@ std::unique_ptr<FilterGraph> get_filter_graph( ...@@ -65,6 +66,7 @@ std::unique_ptr<FilterGraph> get_filter_graph(
p->add_video_src( p->add_video_src(
static_cast<AVPixelFormat>(codecpar->format), static_cast<AVPixelFormat>(codecpar->format),
input_time_base, input_time_base,
frame_rate,
codecpar->width, codecpar->width,
codecpar->height, codecpar->height,
codecpar->sample_aspect_ratio); codecpar->sample_aspect_ratio);
...@@ -85,13 +87,19 @@ Sink::Sink( ...@@ -85,13 +87,19 @@ Sink::Sink(
AVCodecParameters* codecpar_, AVCodecParameters* codecpar_,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
AVRational frame_rate_,
const c10::optional<std::string>& filter_description_, const c10::optional<std::string>& filter_description_,
const torch::Device& device) const torch::Device& device)
: input_time_base(input_time_base_), : input_time_base(input_time_base_),
codecpar(codecpar_), codecpar(codecpar_),
frame_rate(frame_rate_),
filter_description(filter_description_.value_or( filter_description(filter_description_.value_or(
codecpar->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")), codecpar->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")),
filter(get_filter_graph(input_time_base_, codecpar_, filter_description)), filter(get_filter_graph(
input_time_base_,
codecpar_,
frame_rate,
filter_description)),
output_time_base(filter->get_output_timebase()), output_time_base(filter->get_output_timebase()),
buffer(get_buffer( buffer(get_buffer(
codecpar_->codec_type, codecpar_->codec_type,
...@@ -125,8 +133,13 @@ std::string Sink::get_filter_description() const { ...@@ -125,8 +133,13 @@ std::string Sink::get_filter_description() const {
return filter_description; return filter_description;
} }
// Report the media format at the output of the filter graph, i.e. after
// `filter_description` has been applied to the source stream.
FilterGraphOutputInfo Sink::get_filter_output_info() const {
  return filter->get_output_info();
}
void Sink::flush() { void Sink::flush() {
filter = get_filter_graph(input_time_base, codecpar, filter_description); filter = get_filter_graph(
input_time_base, codecpar, frame_rate, filter_description);
buffer->flush(); buffer->flush();
} }
......
...@@ -13,6 +13,7 @@ class Sink { ...@@ -13,6 +13,7 @@ class Sink {
// Parameters for recreating FilterGraph // Parameters for recreating FilterGraph
AVRational input_time_base; AVRational input_time_base;
AVCodecParameters* codecpar; AVCodecParameters* codecpar;
AVRational frame_rate;
std::string filter_description; std::string filter_description;
std::unique_ptr<FilterGraph> filter; std::unique_ptr<FilterGraph> filter;
// time_base of filter graph output, used for PTS calc // time_base of filter graph output, used for PTS calc
...@@ -25,10 +26,13 @@ class Sink { ...@@ -25,10 +26,13 @@ class Sink {
AVCodecParameters* codecpar, AVCodecParameters* codecpar,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
AVRational frame_rate,
const c10::optional<std::string>& filter_description, const c10::optional<std::string>& filter_description,
const torch::Device& device); const torch::Device& device);
std::string get_filter_description() const; [[nodiscard]] std::string get_filter_description() const;
[[nodiscard]] FilterGraphOutputInfo get_filter_output_info() const;
int process_frame(AVFrame* frame); int process_frame(AVFrame* frame);
bool is_buffer_ready() const; bool is_buffer_ready() const;
......
...@@ -20,6 +20,7 @@ StreamProcessor::StreamProcessor( ...@@ -20,6 +20,7 @@ StreamProcessor::StreamProcessor(
KeyType StreamProcessor::add_stream( KeyType StreamProcessor::add_stream(
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
AVRational frame_rate,
const c10::optional<std::string>& filter_description, const c10::optional<std::string>& filter_description,
const torch::Device& device) { const torch::Device& device) {
switch (stream->codecpar->codec_type) { switch (stream->codecpar->codec_type) {
...@@ -38,6 +39,7 @@ KeyType StreamProcessor::add_stream( ...@@ -38,6 +39,7 @@ KeyType StreamProcessor::add_stream(
stream->codecpar, stream->codecpar,
frames_per_chunk, frames_per_chunk,
num_chunks, num_chunks,
frame_rate,
filter_description, filter_description,
device)); device));
return key; return key;
...@@ -60,6 +62,11 @@ std::string StreamProcessor::get_filter_description(KeyType key) const { ...@@ -60,6 +62,11 @@ std::string StreamProcessor::get_filter_description(KeyType key) const {
return sinks.at(key).get_filter_description(); return sinks.at(key).get_filter_description();
} }
// Report the post-filter media format of the output stream identified by
// `key`. Throws std::out_of_range when `key` does not refer to a
// registered sink (std::map::at semantics).
FilterGraphOutputInfo StreamProcessor::get_filter_output_info(
    KeyType key) const {
  return sinks.at(key).get_filter_output_info();
}
bool StreamProcessor::is_buffer_ready() const { bool StreamProcessor::is_buffer_ready() const {
for (const auto& it : sinks) { for (const auto& it : sinks) {
if (!it.second.buffer->is_ready()) { if (!it.second.buffer->is_ready()) {
......
...@@ -59,6 +59,7 @@ class StreamProcessor { ...@@ -59,6 +59,7 @@ class StreamProcessor {
KeyType add_stream( KeyType add_stream(
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
AVRational frame_rate,
const c10::optional<std::string>& filter_description, const c10::optional<std::string>& filter_description,
const torch::Device& device); const torch::Device& device);
...@@ -72,7 +73,9 @@ class StreamProcessor { ...@@ -72,7 +73,9 @@ class StreamProcessor {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Query methods // Query methods
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
std::string get_filter_description(KeyType key) const; [[nodiscard]] std::string get_filter_description(KeyType key) const;
[[nodiscard]] FilterGraphOutputInfo get_filter_output_info(KeyType key) const;
bool is_buffer_ready() const; bool is_buffer_ready() const;
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
......
...@@ -187,11 +187,27 @@ int64_t StreamReader::num_out_streams() const { ...@@ -187,11 +187,27 @@ int64_t StreamReader::num_out_streams() const {
OutputStreamInfo StreamReader::get_out_stream_info(int i) const { OutputStreamInfo StreamReader::get_out_stream_info(int i) const {
validate_output_stream_index(i); validate_output_stream_index(i);
OutputStreamInfo ret;
int i_src = stream_indices[i].first; int i_src = stream_indices[i].first;
KeyType key = stream_indices[i].second; KeyType key = stream_indices[i].second;
FilterGraphOutputInfo info = processors[i_src]->get_filter_output_info(key);
OutputStreamInfo ret;
ret.source_index = i_src; ret.source_index = i_src;
ret.filter_description = processors[i_src]->get_filter_description(key); ret.filter_description = processors[i_src]->get_filter_description(key);
ret.media_type = info.type;
ret.format = info.format;
switch (info.type) {
case AVMEDIA_TYPE_AUDIO:
ret.sample_rate = info.sample_rate;
ret.num_channels = info.num_channels;
break;
case AVMEDIA_TYPE_VIDEO:
ret.width = info.width;
ret.height = info.height;
ret.frame_rate = info.frame_rate;
break;
default:;
}
return ret; return ret;
} }
...@@ -336,8 +352,22 @@ void StreamReader::add_stream( ...@@ -336,8 +352,22 @@ void StreamReader::add_stream(
processors[i]->set_discard_timestamp(seek_timestamp); processors[i]->set_discard_timestamp(seek_timestamp);
} }
stream->discard = AVDISCARD_DEFAULT; stream->discard = AVDISCARD_DEFAULT;
auto frame_rate = [&]() -> AVRational {
switch (media_type) {
case AVMEDIA_TYPE_AUDIO:
return AVRational{0, 1};
case AVMEDIA_TYPE_VIDEO:
return av_guess_frame_rate(pFormatContext, stream, nullptr);
default:
TORCH_INTERNAL_ASSERT(
false,
"Unexpected media type is given: ",
av_get_media_type_string(media_type));
}
}();
int key = processors[i]->add_stream( int key = processors[i]->add_stream(
frames_per_chunk, num_chunks, filter_desc, device); frames_per_chunk, num_chunks, frame_rate, filter_desc, device);
stream_indices.push_back(std::make_pair<>(i, key)); stream_indices.push_back(std::make_pair<>(i, key));
} }
......
...@@ -101,9 +101,48 @@ struct SrcStreamInfo { ...@@ -101,9 +101,48 @@ struct SrcStreamInfo {
struct OutputStreamInfo { struct OutputStreamInfo {
/// The index of the input source stream /// The index of the input source stream
int source_index; int source_index;
///
/// The stream media type.
///
/// Please see refer to
/// [the FFmpeg
/// documentation](https://ffmpeg.org/doxygen/4.1/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48)
/// for the available values
///
/// @todo Introduce own enum and get rid of FFmpeg dependency
///
AVMediaType media_type = AVMEDIA_TYPE_UNKNOWN;
int format = -1;
/// Filter graph definition, such as /// Filter graph definition, such as
/// ``"aresample=16000,aformat=sample_fmts=fltp"``. /// ``"aresample=16000,aformat=sample_fmts=fltp"``.
std::string filter_description; std::string filter_description{};
/// @name AUDIO-SPECIFIC MEMBERS
///@{
/// Sample rate
double sample_rate = -1;
/// The number of channels
int num_channels = -1;
///@}
/// @name VIDEO-SPECIFIC MEMBERS
///@{
/// Width
int width = -1;
/// Height
int height = -1;
/// Frame rate
AVRational frame_rate{0, 1};
///@}
}; };
/// Stores decoded frames and metadata /// Stores decoded frames and metadata
......
...@@ -435,6 +435,7 @@ FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) { ...@@ -435,6 +435,7 @@ FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
p.add_video_src( p.add_video_src(
src_fmt, src_fmt,
codec_ctx->time_base, codec_ctx->time_base,
codec_ctx->framerate,
codec_ctx->width, codec_ctx->width,
codec_ctx->height, codec_ctx->height,
codec_ctx->sample_aspect_ratio); codec_ctx->sample_aspect_ratio);
......
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import BinaryIO, Dict, Iterator, Optional, Tuple, Union from typing import BinaryIO, Dict, Iterator, Optional, Tuple, TypeVar, Union
import torch import torch
import torchaudio import torchaudio
...@@ -154,6 +154,80 @@ class OutputStream: ...@@ -154,6 +154,80 @@ class OutputStream:
"""Index of the source stream that this output stream is connected.""" """Index of the source stream that this output stream is connected."""
filter_description: str filter_description: str
"""Description of filter graph applied to the source stream.""" """Description of filter graph applied to the source stream."""
media_type: str
"""The type of the stream. ``"audio"`` or ``"video"``."""
format: str
"""Media format. Such as ``"s16"`` and ``"yuv420p"``.
Commonly found audio values are;
- ``"u8"``, ``"u8p"``: Unsigned 8-bit integer.
- ``"s16"``, ``"s16p"``: 16-bit signed integer.
- ``"s32"``, ``"s32p"``: 32-bit signed integer.
- ``"flt"``, ``"fltp"``: 32-bit floating-point.
.. note::
`p` at the end indicates the format is `planar`.
Channels are grouped together instead of interspersed in memory."""
@dataclass
class OutputAudioStream(OutputStream):
    """Information about an audio output stream configured with
    :meth:`~torchaudio.io.StreamReader.add_audio_stream` or
    :meth:`~torchaudio.io.StreamReader.add_basic_audio_stream`.

    In addition to the attributes reported by :class:`OutputStream`,
    the following attributes are reported.
    """

    sample_rate: float
    """Sample rate of the audio after filtering."""
    num_channels: int
    """Number of channels after filtering."""
@dataclass
class OutputVideoStream(OutputStream):
    """Information about a video output stream configured with
    :meth:`~torchaudio.io.StreamReader.add_video_stream` or
    :meth:`~torchaudio.io.StreamReader.add_basic_video_stream`.

    In addition to the attributes reported by :class:`OutputStream`,
    the following attributes are reported.
    """

    width: int
    """Width of the video frame in pixel, after filtering."""
    height: int
    """Height of the video frame in pixel, after filtering."""
    frame_rate: float
    """Frame rate, in frames per second, after filtering."""
def _parse_oi(i):
    """Convert a backend ``OutputStreamInfo`` object into the corresponding
    Python dataclass (:class:`OutputAudioStream` or :class:`OutputVideoStream`).

    Raises:
        ValueError: If the reported ``media_type`` is neither audio nor video.
    """
    # Attributes shared by both audio and video output streams.
    common = {
        "source_index": i.source_index,
        "filter_description": i.filter_description,
        "media_type": i.media_type,
        "format": i.format,
    }
    if i.media_type == "audio":
        return OutputAudioStream(
            sample_rate=i.sample_rate,
            num_channels=i.num_channels,
            **common,
        )
    if i.media_type == "video":
        return OutputVideoStream(
            width=i.width,
            height=i.height,
            frame_rate=i.frame_rate,
            **common,
        )
    raise ValueError(f"Unexpected media_type: {i.media_type}({i})")
def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]): def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]):
...@@ -351,6 +425,10 @@ _format_video_args = _format_doc( ...@@ -351,6 +425,10 @@ _format_video_args = _format_doc(
) )
# The first argument of TypeVar must match the name the variable is bound to
# (required by the typing spec; type checkers reject a mismatch). The
# original used "InputStream"/"OutputStream" here, which did not match.
InputStreamTypes = TypeVar("InputStreamTypes", bound=SourceStream)
OutputStreamTypes = TypeVar("OutputStreamTypes", bound=OutputStream)
@torchaudio._extension.fail_if_no_ffmpeg @torchaudio._extension.fail_if_no_ffmpeg
class StreamReader: class StreamReader:
"""Fetch and decode audio/video streams chunk by chunk. """Fetch and decode audio/video streams chunk by chunk.
...@@ -481,29 +559,33 @@ class StreamReader: ...@@ -481,29 +559,33 @@ class StreamReader:
""" """
return self._be.get_metadata() return self._be.get_metadata()
def get_src_stream_info(self, i: int) -> Union[SourceStream, SourceAudioStream, SourceVideoStream]: def get_src_stream_info(self, i: int) -> InputStreamTypes:
"""Get the metadata of source stream """Get the metadata of source stream
Args: Args:
i (int): Stream index. i (int): Stream index.
Returns: Returns:
InputStreamTypes:
Information about the source stream. Information about the source stream.
If the source stream is audio type, then :class:`SourceAudioStream` returned. If the source stream is audio type, then :class:`~torchaudio.io._stream_reader.SourceAudioStream` returned.
If it is video type, then :class:`SourceVideoStream` is returned. If it is video type, then :class:`~torchaudio.io._stream_reader.SourceVideoStream` is returned.
Otherwise :class:`SourceStream` class is returned. Otherwise :class:`~torchaudio.io._stream_reader.SourceStream` class is returned.
""" """
return _parse_si(self._be.get_src_stream_info(i)) return _parse_si(self._be.get_src_stream_info(i))
def get_out_stream_info(self, i: int) -> OutputStream: def get_out_stream_info(self, i: int) -> OutputStreamTypes:
"""Get the metadata of output stream """Get the metadata of output stream
Args: Args:
i (int): Stream index. i (int): Stream index.
Returns: Returns:
OutputStream OutputStreamTypes
Information about the output stream.
If the output stream is audio type, then :class:`~torchaudio.io._stream_reader.OutputAudioStream` returned.
If it is video type, then :class:`~torchaudio.io._stream_reader.OutputVideoStream` is returned.
""" """
info = self._be.get_out_stream_info(i) info = self._be.get_out_stream_info(i)
return OutputStream(info.source_index, info.filter_description) return _parse_oi(info)
def seek(self, timestamp: float, mode: str = "precise"): def seek(self, timestamp: float, mode: str = "precise"):
"""Seek the stream to the given timestamp [second] """Seek the stream to the given timestamp [second]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment