Commit 146195d8 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Include format information after filter (#3155)

Summary:
This commit adds fields to OutputStream, which show the result
of filters, such as width and height after filtering.

Before

```
OutputStream(
    source_index=0,
    filter_description='fps=3,scale=width=320:height=320,format=pix_fmts=gray')
```

After

```
OutputVideoStream(
    source_index=0,
    filter_description='fps=3,scale=width=320:height=320,format=pix_fmts=gray',
    media_type='video',
    format='gray',
    width=320,
    height=320,
    frame_rate=3.0)
```

Pull Request resolved: https://github.com/pytorch/audio/pull/3155

Reviewed By: nateanl

Differential Revision: D43882399

Pulled By: mthrok

fbshipit-source-id: 620676b1a06f293fdd56de8203a11120f228fa2d
parent 8d2f6f8d
......@@ -61,6 +61,8 @@ Support Structures
"SourceAudioStream",
"SourceVideoStream",
"OutputStream",
"OutputAudioStream",
"OutputVideoStream",
] %}
{{ item | underline("~") }}
......
......@@ -20,7 +20,14 @@ from torchaudio_unittest.common_utils import (
# These imports require the FFmpeg extension; guard them so that the test
# module can still be collected when FFmpeg is unavailable.
if is_ffmpeg_available():
    from torchaudio.io import StreamReader, StreamWriter
    from torchaudio.io._stream_reader import (
        ChunkTensor,
        OutputAudioStream,
        OutputVideoStream,
        SourceAudioStream,
        SourceStream,
        SourceVideoStream,
    )
@skipIfNoFFmpeg
......@@ -238,6 +245,81 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
output = [s.get_src_stream_info(i) for i in range(6)]
assert expected == output
def test_output_info(self):
    """`get_out_stream_info` reports the stream format AFTER filtering.

    NOTE(review): the expected indices/formats assume the default test
    asset (audio at source index 4, video at source index 3, 480x270,
    16 kHz stereo fltp) -- confirm against ``self.get_src()``.
    """
    s = StreamReader(self.get_src())
    # Audio outputs: passthrough, resampled, and sample-format converted.
    s.add_audio_stream(-1)
    s.add_audio_stream(-1, filter_desc="aresample=8000")
    s.add_audio_stream(-1, filter_desc="aformat=sample_fmts=s16p")
    # Video outputs: passthrough, then frame-rate, pixel-format and size changes.
    s.add_video_stream(-1)
    s.add_video_stream(-1, filter_desc="fps=10")
    s.add_video_stream(-1, filter_desc="format=rgb24")
    s.add_video_stream(-1, filter_desc="scale=w=160:h=90")
    expected = [
        # When no filter is given, the null filter ("anull"/"null") is used
        # and the source format is reported unchanged.
        OutputAudioStream(
            source_index=4,
            filter_description="anull",
            media_type="audio",
            format="fltp",
            sample_rate=16000.0,
            num_channels=2,
        ),
        # `aresample` changes only the sample rate.
        OutputAudioStream(
            source_index=4,
            filter_description="aresample=8000",
            media_type="audio",
            format="fltp",
            sample_rate=8000.0,
            num_channels=2,
        ),
        # `aformat` changes only the sample format.
        OutputAudioStream(
            source_index=4,
            filter_description="aformat=sample_fmts=s16p",
            media_type="audio",
            format="s16p",
            sample_rate=16000.0,
            num_channels=2,
        ),
        OutputVideoStream(
            source_index=3,
            filter_description="null",
            media_type="video",
            format="yuv420p",
            width=480,
            height=270,
            frame_rate=30000 / 1001,
        ),
        # `fps` changes only the frame rate.
        OutputVideoStream(
            source_index=3,
            filter_description="fps=10",
            media_type="video",
            format="yuv420p",
            width=480,
            height=270,
            frame_rate=10,
        ),
        # `format` changes only the pixel format.
        OutputVideoStream(
            source_index=3,
            filter_description="format=rgb24",
            media_type="video",
            format="rgb24",
            width=480,
            height=270,
            frame_rate=30000 / 1001,
        ),
        # `scale` changes only the frame dimensions.
        OutputVideoStream(
            source_index=3,
            filter_description="scale=w=160:h=90",
            media_type="video",
            format="yuv420p",
            width=160,
            height=90,
            frame_rate=30000 / 1001,
        ),
    ]
    output = [s.get_out_stream_info(i) for i in range(s.num_out_streams)]
    assert expected == output
def test_id3tag(self):
"""get_metadata method can fetch id3tag properly"""
s = StreamReader(self.get_src("steam-train-whistle-daniel_simon.mp3"))
......
......@@ -41,6 +41,7 @@ std::string get_audio_src_args(
std::string get_video_src_args(
AVPixelFormat format,
AVRational time_base,
AVRational frame_rate,
int width,
int height,
AVRational sample_aspect_ratio) {
......@@ -48,12 +49,14 @@ std::string get_video_src_args(
std::snprintf(
args,
sizeof(args),
"video_size=%dx%d:pix_fmt=%s:time_base=%d/%d:pixel_aspect=%d/%d",
"video_size=%dx%d:pix_fmt=%s:time_base=%d/%d:frame_rate=%d/%d:pixel_aspect=%d/%d",
width,
height,
av_get_pix_fmt_name(format),
time_base.num,
time_base.den,
frame_rate.num,
frame_rate.den,
sample_aspect_ratio.num,
sample_aspect_ratio.den);
return std::string(args);
......@@ -76,13 +79,14 @@ void FilterGraph::add_audio_src(
// Add a video input (buffer source) to the filter graph.
//
// The parameters describe the incoming decoded frames and are serialized
// into the buffer-source argument string. Throws if this graph was not
// constructed as a video graph.
void FilterGraph::add_video_src(
    AVPixelFormat format,
    AVRational time_base,
    AVRational frame_rate,
    int width,
    int height,
    AVRational sample_aspect_ratio) {
  TORCH_CHECK(
      media_type == AVMEDIA_TYPE_VIDEO, "The filter graph is not video type.");
  std::string args = get_video_src_args(
      format, time_base, frame_rate, width, height, sample_aspect_ratio);
  add_src(args);
}
......@@ -164,7 +168,7 @@ void FilterGraph::add_process(const std::string& filter_description) {
// Finalize the assembled filter graph; must be called after all sources,
// sinks and processing filters have been added.
void FilterGraph::create_filter() {
  int ret = avfilter_graph_config(pFilterGraph, nullptr);
  TORCH_CHECK(ret >= 0, "Failed to configure the graph: " + av_err2string(ret));
  // Debugging aid: uncomment to dump the configured graph topology.
  // char* desc = avfilter_graph_dump(pFilterGraph, NULL);
  // std::cerr << "Filter created:\n" << desc << std::endl;
  // av_free(static_cast<void*>(desc));
}
......@@ -177,22 +181,26 @@ AVRational FilterGraph::get_output_timebase() const {
return buffersink_ctx->inputs[0]->time_base;
}
int FilterGraph::get_output_sample_rate() const {
FilterGraphOutputInfo FilterGraph::get_output_info() const {
TORCH_INTERNAL_ASSERT(buffersink_ctx, "FilterGraph is not initialized.");
return buffersink_ctx->inputs[0]->sample_rate;
}
int FilterGraph::get_output_channels() const {
TORCH_INTERNAL_ASSERT(buffersink_ctx, "FilterGraph is not initialized.");
// Since FFmpeg 5.1
// https://github.com/FFmpeg/FFmpeg/blob/release/5.1/doc/APIchanges#L45-L54
AVFilterLink* l = buffersink_ctx->inputs[0];
FilterGraphOutputInfo ret{};
ret.type = l->type;
ret.format = l->format;
if (l->type == AVMEDIA_TYPE_AUDIO) {
ret.sample_rate = l->sample_rate;
#if LIBAVFILTER_VERSION_MAJOR >= 8 && LIBAVFILTER_VERSION_MINOR >= 44
return buffersink_ctx->inputs[0]->ch_layout.nb_channels;
ret.num_channels = l->ch_layout.nb_channels;
#else
// Before FFmpeg 5.1
return av_get_channel_layout_nb_channels(
buffersink_ctx->inputs[0]->channel_layout);
ret.num_channels = av_get_channel_layout_nb_channels(l->channel_layout);
#endif
} else {
ret.frame_rate = l->frame_rate;
ret.height = l->h;
ret.width = l->w;
}
return ret;
}
////////////////////////////////////////////////////////////////////////////////
......
......@@ -4,6 +4,21 @@
namespace torchaudio {
namespace io {
/// Used to report the output formats of filter graph.
struct FilterGraphOutputInfo {
  /// Media type of the filter graph output.
  AVMediaType type = AVMEDIA_TYPE_UNKNOWN;
  /// Format identifier; AVSampleFormat for audio, AVPixelFormat for video.
  /// -1 means not set.
  int format = -1;
  // Audio
  /// Sample rate (audio only; -1 otherwise).
  int sample_rate = -1;
  /// Number of channels (audio only; -1 otherwise).
  int num_channels = -1;
  // Video
  /// Frame rate (video only; {0, 1} otherwise).
  AVRational frame_rate = {0, 1};
  /// Frame height in pixels (video only; -1 otherwise).
  int height = -1;
  /// Frame width in pixels (video only; -1 otherwise).
  int width = -1;
};
class FilterGraph {
AVMediaType media_type;
......@@ -37,6 +52,7 @@ class FilterGraph {
void add_video_src(
AVPixelFormat format,
AVRational time_base,
AVRational frame_rate,
int width,
int height,
AVRational sample_aspect_ratio);
......@@ -53,8 +69,7 @@ class FilterGraph {
// Query methods
//////////////////////////////////////////////////////////////////////////////
[[nodiscard]] AVRational get_output_timebase() const;
[[nodiscard]] int get_output_sample_rate() const;
[[nodiscard]] int get_output_channels() const;
[[nodiscard]] FilterGraphOutputInfo get_output_info() const;
//////////////////////////////////////////////////////////////////////////////
// Streaming process
......
......@@ -57,8 +57,44 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def("close", &StreamWriterFileObj::close);
// Expose OutputStreamInfo to Python. `media_type` and `format` are
// converted from FFmpeg enum values to human-readable strings.
py::class_<OutputStreamInfo>(m, "OutputStreamInfo", py::module_local())
    .def_readonly("source_index", &OutputStreamInfo::source_index)
    .def_readonly("filter_description", &OutputStreamInfo::filter_description)
    .def_property_readonly(
        "media_type",
        [](const OutputStreamInfo& o) -> std::string {
          return av_get_media_type_string(o.media_type);
        })
    .def_property_readonly(
        "format",
        [](const OutputStreamInfo& o) -> std::string {
          // Resolve the integer format id to the FFmpeg format name.
          switch (o.media_type) {
            case AVMEDIA_TYPE_AUDIO:
              return av_get_sample_fmt_name((AVSampleFormat)(o.format));
            case AVMEDIA_TYPE_VIDEO:
              return av_get_pix_fmt_name((AVPixelFormat)(o.format));
            default:
              TORCH_INTERNAL_ASSERT(
                  false,
                  "FilterGraph is returning unexpected media type: ",
                  av_get_media_type_string(o.media_type));
          }
        })
    .def_readonly("sample_rate", &OutputStreamInfo::sample_rate)
    .def_readonly("num_channels", &OutputStreamInfo::num_channels)
    .def_readonly("width", &OutputStreamInfo::width)
    .def_readonly("height", &OutputStreamInfo::height)
    .def_property_readonly(
        "frame_rate", [](const OutputStreamInfo& o) -> double {
          // A zero denominator means the frame rate is unknown;
          // report -1 rather than dividing by zero.
          if (o.frame_rate.den == 0) {
            TORCH_WARN(
                "Invalid frame rate is found: ",
                o.frame_rate.num,
                "/",
                o.frame_rate.den);
            return -1;
          }
          return static_cast<double>(o.frame_rate.num) / o.frame_rate.den;
        });
py::class_<SrcStreamInfo>(m, "SourceStreamInfo", py::module_local())
.def_property_readonly(
"media_type",
......
......@@ -50,6 +50,7 @@ std::unique_ptr<Buffer> get_buffer(
std::unique_ptr<FilterGraph> get_filter_graph(
AVRational input_time_base,
AVCodecParameters* codecpar,
AVRational frame_rate,
const std::string& filter_description) {
auto p = std::make_unique<FilterGraph>(codecpar->codec_type);
......@@ -65,6 +66,7 @@ std::unique_ptr<FilterGraph> get_filter_graph(
p->add_video_src(
static_cast<AVPixelFormat>(codecpar->format),
input_time_base,
frame_rate,
codecpar->width,
codecpar->height,
codecpar->sample_aspect_ratio);
......@@ -85,13 +87,19 @@ Sink::Sink(
AVCodecParameters* codecpar_,
int frames_per_chunk,
int num_chunks,
AVRational frame_rate_,
const c10::optional<std::string>& filter_description_,
const torch::Device& device)
: input_time_base(input_time_base_),
codecpar(codecpar_),
frame_rate(frame_rate_),
filter_description(filter_description_.value_or(
codecpar->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")),
filter(get_filter_graph(input_time_base_, codecpar_, filter_description)),
filter(get_filter_graph(
input_time_base_,
codecpar_,
frame_rate,
filter_description)),
output_time_base(filter->get_output_timebase()),
buffer(get_buffer(
codecpar_->codec_type,
......@@ -125,8 +133,13 @@ std::string Sink::get_filter_description() const {
return filter_description;
}
// Expose the output format information (media type, format, and the
// audio/video specific fields) of the underlying filter graph.
FilterGraphOutputInfo Sink::get_filter_output_info() const {
  return filter->get_output_info();
}
// Discard buffered data and rebuild the filter graph so that subsequent
// frames are processed from a clean state.
void Sink::flush() {
  // Recreate the graph with the same parameters used at construction.
  filter = get_filter_graph(
      input_time_base, codecpar, frame_rate, filter_description);
  buffer->flush();
}
......
......@@ -13,6 +13,7 @@ class Sink {
// Parameters for recreating FilterGraph
AVRational input_time_base;
AVCodecParameters* codecpar;
AVRational frame_rate;
std::string filter_description;
std::unique_ptr<FilterGraph> filter;
// time_base of filter graph output, used for PTS calc
......@@ -25,10 +26,13 @@ class Sink {
AVCodecParameters* codecpar,
int frames_per_chunk,
int num_chunks,
AVRational frame_rate,
const c10::optional<std::string>& filter_description,
const torch::Device& device);
std::string get_filter_description() const;
[[nodiscard]] std::string get_filter_description() const;
[[nodiscard]] FilterGraphOutputInfo get_filter_output_info() const;
int process_frame(AVFrame* frame);
bool is_buffer_ready() const;
......
......@@ -20,6 +20,7 @@ StreamProcessor::StreamProcessor(
KeyType StreamProcessor::add_stream(
int frames_per_chunk,
int num_chunks,
AVRational frame_rate,
const c10::optional<std::string>& filter_description,
const torch::Device& device) {
switch (stream->codecpar->codec_type) {
......@@ -38,6 +39,7 @@ KeyType StreamProcessor::add_stream(
stream->codecpar,
frames_per_chunk,
num_chunks,
frame_rate,
filter_description,
device));
return key;
......@@ -60,6 +62,11 @@ std::string StreamProcessor::get_filter_description(KeyType key) const {
return sinks.at(key).get_filter_description();
}
// Fetch the filter-graph output information of the sink registered under
// `key`. Throws std::out_of_range when `key` is unknown (std::map::at).
FilterGraphOutputInfo StreamProcessor::get_filter_output_info(
    KeyType key) const {
  return sinks.at(key).get_filter_output_info();
}
bool StreamProcessor::is_buffer_ready() const {
for (const auto& it : sinks) {
if (!it.second.buffer->is_ready()) {
......
......@@ -59,6 +59,7 @@ class StreamProcessor {
KeyType add_stream(
int frames_per_chunk,
int num_chunks,
AVRational frame_rate,
const c10::optional<std::string>& filter_description,
const torch::Device& device);
......@@ -72,7 +73,9 @@ class StreamProcessor {
//////////////////////////////////////////////////////////////////////////////
// Query methods
//////////////////////////////////////////////////////////////////////////////
std::string get_filter_description(KeyType key) const;
[[nodiscard]] std::string get_filter_description(KeyType key) const;
[[nodiscard]] FilterGraphOutputInfo get_filter_output_info(KeyType key) const;
bool is_buffer_ready() const;
//////////////////////////////////////////////////////////////////////////////
......
......@@ -187,11 +187,27 @@ int64_t StreamReader::num_out_streams() const {
// Collect the information of the i-th output stream.
//
// The static fields (source index, filter description) come from the
// stream processor; the format fields reflect the filter graph output.
OutputStreamInfo StreamReader::get_out_stream_info(int i) const {
  validate_output_stream_index(i);
  int i_src = stream_indices[i].first;
  KeyType key = stream_indices[i].second;
  FilterGraphOutputInfo info = processors[i_src]->get_filter_output_info(key);
  OutputStreamInfo ret;
  ret.source_index = i_src;
  ret.filter_description = processors[i_src]->get_filter_description(key);
  ret.media_type = info.type;
  ret.format = info.format;
  switch (info.type) {
    case AVMEDIA_TYPE_AUDIO:
      ret.sample_rate = info.sample_rate;
      ret.num_channels = info.num_channels;
      break;
    case AVMEDIA_TYPE_VIDEO:
      ret.width = info.width;
      ret.height = info.height;
      ret.frame_rate = info.frame_rate;
      break;
    default:;
  }
  return ret;
}
......@@ -336,8 +352,22 @@ void StreamReader::add_stream(
processors[i]->set_discard_timestamp(seek_timestamp);
}
stream->discard = AVDISCARD_DEFAULT;
auto frame_rate = [&]() -> AVRational {
switch (media_type) {
case AVMEDIA_TYPE_AUDIO:
return AVRational{0, 1};
case AVMEDIA_TYPE_VIDEO:
return av_guess_frame_rate(pFormatContext, stream, nullptr);
default:
TORCH_INTERNAL_ASSERT(
false,
"Unexpected media type is given: ",
av_get_media_type_string(media_type));
}
}();
int key = processors[i]->add_stream(
frames_per_chunk, num_chunks, filter_desc, device);
frames_per_chunk, num_chunks, frame_rate, filter_desc, device);
stream_indices.push_back(std::make_pair<>(i, key));
}
......
......@@ -101,9 +101,48 @@ struct SrcStreamInfo {
struct OutputStreamInfo {
  /// The index of the input source stream
  int source_index;

  ///
  /// The stream media type.
  ///
  /// Please refer to
  /// [the FFmpeg
  /// documentation](https://ffmpeg.org/doxygen/4.1/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48)
  /// for the available values
  ///
  /// @todo Introduce own enum and get rid of FFmpeg dependency
  ///
  AVMediaType media_type = AVMEDIA_TYPE_UNKNOWN;

  /// Media format identifier. AVSampleFormat for audio,
  /// AVPixelFormat for video. -1 means not set.
  int format = -1;

  /// Filter graph definition, such as
  /// ``"aresample=16000,aformat=sample_fmts=fltp"``.
  std::string filter_description{};

  /// @name AUDIO-SPECIFIC MEMBERS
  ///@{

  /// Sample rate
  double sample_rate = -1;

  /// The number of channels
  int num_channels = -1;

  ///@}

  /// @name VIDEO-SPECIFIC MEMBERS
  ///@{

  /// Width
  int width = -1;

  /// Height
  int height = -1;

  /// Frame rate
  AVRational frame_rate{0, 1};

  ///@}
};
/// Stores decoded frames and metadata
......
......@@ -435,6 +435,7 @@ FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
p.add_video_src(
src_fmt,
codec_ctx->time_base,
codec_ctx->framerate,
codec_ctx->width,
codec_ctx->height,
codec_ctx->sample_aspect_ratio);
......
from __future__ import annotations
from dataclasses import dataclass
from typing import BinaryIO, Dict, Iterator, Optional, Tuple, Union
from typing import BinaryIO, Dict, Iterator, Optional, Tuple, TypeVar, Union
import torch
import torchaudio
......@@ -154,6 +154,80 @@ class OutputStream:
"""Index of the source stream that this output stream is connected."""
filter_description: str
"""Description of filter graph applied to the source stream."""
media_type: str
"""The type of the stream. ``"audio"`` or ``"video"``."""
format: str
"""Media format. Such as ``"s16"`` and ``"yuv420p"``.
Commonly found audio values are;
- ``"u8"``, ``"u8p"``: 8-bit unsigned integer.
- ``"s16"``, ``"s16p"``: 16-bit signed integer.
- ``"s32"``, ``"s32p"``: 32-bit signed integer.
- ``"flt"``, ``"fltp"``: 32-bit floating-point.
.. note::
`p` at the end indicates the format is `planar`.
Channels are grouped together instead of interspersed in memory."""
@dataclass
class OutputAudioStream(OutputStream):
    """Information about an audio output stream configured with
    :meth:`~torchaudio.io.StreamReader.add_audio_stream` or
    :meth:`~torchaudio.io.StreamReader.add_basic_audio_stream`.

    In addition to the attributes reported by :class:`OutputStream`,
    the following attributes are reported.
    """

    sample_rate: float
    """Sample rate of the audio."""
    num_channels: int
    """Number of channels."""
@dataclass
class OutputVideoStream(OutputStream):
    """Information about a video output stream configured with
    :meth:`~torchaudio.io.StreamReader.add_video_stream` or
    :meth:`~torchaudio.io.StreamReader.add_basic_video_stream`.

    In addition to the attributes reported by :class:`OutputStream`,
    the following attributes are reported.
    """

    width: int
    """Width of the video frame in pixel."""
    height: int
    """Height of the video frame in pixel."""
    frame_rate: float
    """Frame rate."""
def _parse_oi(i):
    """Convert a backend ``OutputStreamInfo`` object into the matching
    Python dataclass (:class:`OutputAudioStream` or :class:`OutputVideoStream`).

    Raises:
        ValueError: If the media type is neither audio nor video.
    """
    # Fields shared by both audio and video output streams.
    common = {
        "source_index": i.source_index,
        "filter_description": i.filter_description,
        "media_type": i.media_type,
        "format": i.format,
    }
    if i.media_type == "audio":
        return OutputAudioStream(
            sample_rate=i.sample_rate,
            num_channels=i.num_channels,
            **common,
        )
    if i.media_type == "video":
        return OutputVideoStream(
            width=i.width,
            height=i.height,
            frame_rate=i.frame_rate,
            **common,
        )
    raise ValueError(f"Unexpected media_type: {i.media_type}({i})")
def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]):
......@@ -351,6 +425,10 @@ _format_video_args = _format_doc(
)
# Type variables used to annotate the precise return types of
# get_src_stream_info / get_out_stream_info.
# NOTE: per PEP 484, the name passed to TypeVar must match the name of the
# variable it is assigned to (type checkers flag a mismatch).
InputStreamTypes = TypeVar("InputStreamTypes", bound=SourceStream)
OutputStreamTypes = TypeVar("OutputStreamTypes", bound=OutputStream)
@torchaudio._extension.fail_if_no_ffmpeg
class StreamReader:
"""Fetch and decode audio/video streams chunk by chunk.
......@@ -481,29 +559,33 @@ class StreamReader:
"""
return self._be.get_metadata()
def get_src_stream_info(self, i: int) -> InputStreamTypes:
    """Get the metadata of source stream

    Args:
        i (int): Stream index.

    Returns:
        InputStreamTypes:
            Information about the source stream.
            If the source stream is audio type, then
            :class:`~torchaudio.io._stream_reader.SourceAudioStream` is returned.
            If it is video type, then
            :class:`~torchaudio.io._stream_reader.SourceVideoStream` is returned.
            Otherwise :class:`~torchaudio.io._stream_reader.SourceStream` class is returned.
    """
    # Delegate to the backend and convert the result into the
    # corresponding Python dataclass.
    return _parse_si(self._be.get_src_stream_info(i))
def get_out_stream_info(self, i: int) -> OutputStreamTypes:
    """Get the metadata of output stream

    Args:
        i (int): Stream index.

    Returns:
        OutputStreamTypes:
            Information about the output stream.
            If the output stream is audio type, then
            :class:`~torchaudio.io._stream_reader.OutputAudioStream` is returned.
            If it is video type, then
            :class:`~torchaudio.io._stream_reader.OutputVideoStream` is returned.
    """
    # Delegate to the backend and convert the result into the
    # corresponding Python dataclass.
    info = self._be.get_out_stream_info(i)
    return _parse_oi(info)
def seek(self, timestamp: float, mode: str = "precise"):
"""Seek the stream to the given timestamp [second]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment