Commit 10d1bd89 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add metadata to source stream info (#2461)

Summary:
Add metadata, such as ID3 (https://github.com/pytorch/audio/commit/7d98db0567cb60fabcc173949b8c08e3a3487ac2)tag to `StreamReaderSourceAudioStream`.

Pull Request resolved: https://github.com/pytorch/audio/pull/2461

Reviewed By: hwangjeff

Differential Revision: D36985656

Pulled By: mthrok

fbshipit-source-id: e66f9e6e980eb57c378cc643a8979b6b7813dae7
parent 7d98db05
...@@ -89,6 +89,12 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -89,6 +89,12 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
s = StreamReader(self.get_src()) s = StreamReader(self.get_src())
assert s.num_src_streams == 6 assert s.num_src_streams == 6
metadata = {
"compatible_brands": "isomiso2avc1mp41",
"encoder": "Lavf58.76.100",
"major_brand": "isom",
"minor_version": "512",
}
expected = [ expected = [
StreamReaderSourceVideoStream( StreamReaderSourceVideoStream(
media_type="video", media_type="video",
...@@ -98,6 +104,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -98,6 +104,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
bit_rate=71925, bit_rate=71925,
num_frames=325, num_frames=325,
bits_per_sample=8, bits_per_sample=8,
metadata=metadata,
width=320, width=320,
height=180, height=180,
frame_rate=25.0, frame_rate=25.0,
...@@ -110,6 +117,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -110,6 +117,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
bit_rate=72093, bit_rate=72093,
num_frames=103, num_frames=103,
bits_per_sample=0, bits_per_sample=0,
metadata=metadata,
sample_rate=8000.0, sample_rate=8000.0,
num_channels=2, num_channels=2,
), ),
...@@ -121,6 +129,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -121,6 +129,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
bit_rate=None, bit_rate=None,
num_frames=None, num_frames=None,
bits_per_sample=None, bits_per_sample=None,
metadata=metadata,
), ),
StreamReaderSourceVideoStream( StreamReaderSourceVideoStream(
media_type="video", media_type="video",
...@@ -130,6 +139,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -130,6 +139,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
bit_rate=128783, bit_rate=128783,
num_frames=390, num_frames=390,
bits_per_sample=8, bits_per_sample=8,
metadata=metadata,
width=480, width=480,
height=270, height=270,
frame_rate=29.97002997002997, frame_rate=29.97002997002997,
...@@ -142,6 +152,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -142,6 +152,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
bit_rate=128837, bit_rate=128837,
num_frames=205, num_frames=205,
bits_per_sample=0, bits_per_sample=0,
metadata=metadata,
sample_rate=16000.0, sample_rate=16000.0,
num_channels=2, num_channels=2,
), ),
...@@ -153,11 +164,34 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -153,11 +164,34 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
bit_rate=None, bit_rate=None,
num_frames=None, num_frames=None,
bits_per_sample=None, bits_per_sample=None,
metadata=metadata,
), ),
] ]
output = [s.get_src_stream_info(i) for i in range(6)] output = [s.get_src_stream_info(i) for i in range(6)]
assert expected == output assert expected == output
def test_id3tag(self):
s = StreamReader(self.get_src("steam-train-whistle-daniel_simon.mp3"))
output = s.get_src_stream_info(s.default_audio_stream)
expected = StreamReaderSourceAudioStream(
media_type="audio",
codec="mp3",
codec_long_name="MP3 (MPEG audio layer 3)",
format="fltp",
bit_rate=210571,
num_frames=0,
bits_per_sample=0,
metadata={
"title": "SoundBible.com Must Credit",
"artist": "SoundBible.com Must Credit",
"date": "2017",
},
sample_rate=44100.0,
num_channels=2,
)
assert output == expected
def test_src_info_invalid_index(self): def test_src_info_invalid_index(self):
"""`get_src_stream_info` does not segfault but raise an exception when input is invalid""" """`get_src_stream_info` does not segfault but raise an exception when input is invalid"""
s = StreamReader(self.get_src()) s = StreamReader(self.get_src())
......
...@@ -22,7 +22,9 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) { ...@@ -22,7 +22,9 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def( .def(
"find_best_video_stream", "find_best_video_stream",
&StreamReaderFileObj::find_best_video_stream) &StreamReaderFileObj::find_best_video_stream)
.def("get_src_stream_info", &StreamReaderFileObj::get_src_stream_info) .def(
"get_src_stream_info",
&StreamReaderFileObj::get_src_stream_info_pybind)
.def("get_out_stream_info", &StreamReaderFileObj::get_out_stream_info) .def("get_out_stream_info", &StreamReaderFileObj::get_out_stream_info)
.def("seek", &StreamReaderFileObj::seek) .def("seek", &StreamReaderFileObj::seek)
.def("add_audio_stream", &StreamReaderFileObj::add_audio_stream) .def("add_audio_stream", &StreamReaderFileObj::add_audio_stream)
......
...@@ -71,6 +71,18 @@ int64_t StreamReader::num_src_streams() const { ...@@ -71,6 +71,18 @@ int64_t StreamReader::num_src_streams() const {
return pFormatContext->nb_streams; return pFormatContext->nb_streams;
} }
namespace {
c10::Dict<std::string, std::string> parse_metadata(
const AVDictionary* metadata) {
AVDictionaryEntry* tag = nullptr;
c10::Dict<std::string, std::string> ret;
while ((tag = av_dict_get(metadata, "", tag, AV_DICT_IGNORE_SUFFIX))) {
ret.insert(std::string(tag->key), std::string(tag->value));
}
return ret;
}
} // namespace
SrcStreamInfo StreamReader::get_src_stream_info(int i) const { SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
validate_src_stream_index(i); validate_src_stream_index(i);
AVStream* stream = pFormatContext->streams[i]; AVStream* stream = pFormatContext->streams[i];
...@@ -81,11 +93,13 @@ SrcStreamInfo StreamReader::get_src_stream_info(int i) const { ...@@ -81,11 +93,13 @@ SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
ret.bit_rate = codecpar->bit_rate; ret.bit_rate = codecpar->bit_rate;
ret.num_frames = stream->nb_frames; ret.num_frames = stream->nb_frames;
ret.bits_per_sample = codecpar->bits_per_raw_sample; ret.bits_per_sample = codecpar->bits_per_raw_sample;
ret.metadata = parse_metadata(pFormatContext->metadata);
const AVCodecDescriptor* desc = avcodec_descriptor_get(codecpar->codec_id); const AVCodecDescriptor* desc = avcodec_descriptor_get(codecpar->codec_id);
if (desc) { if (desc) {
ret.codec_name = desc->name; ret.codec_name = desc->name;
ret.codec_long_name = desc->long_name; ret.codec_long_name = desc->long_name;
} }
switch (codecpar->codec_type) { switch (codecpar->codec_type) {
case AVMEDIA_TYPE_AUDIO: { case AVMEDIA_TYPE_AUDIO: {
AVSampleFormat smp_fmt = static_cast<AVSampleFormat>(codecpar->format); AVSampleFormat smp_fmt = static_cast<AVSampleFormat>(codecpar->format);
......
...@@ -4,6 +4,17 @@ namespace torchaudio { ...@@ -4,6 +4,17 @@ namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
namespace { namespace {
// TODO:
// merge the implementation with the one from stream_reader_binding.cpp
std::map<std::string, std::string> convert_map(
const c10::Dict<std::string, std::string>& src) {
std::map<std::string, std::string> ret;
for (const auto& it : src) {
ret.insert({it.key(), it.value()});
}
return ret;
}
SrcInfo convert(SrcStreamInfo ssi) { SrcInfo convert(SrcStreamInfo ssi) {
return SrcInfo(std::forward_as_tuple( return SrcInfo(std::forward_as_tuple(
av_get_media_type_string(ssi.media_type), av_get_media_type_string(ssi.media_type),
...@@ -13,6 +24,24 @@ SrcInfo convert(SrcStreamInfo ssi) { ...@@ -13,6 +24,24 @@ SrcInfo convert(SrcStreamInfo ssi) {
ssi.bit_rate, ssi.bit_rate,
ssi.num_frames, ssi.num_frames,
ssi.bits_per_sample, ssi.bits_per_sample,
ssi.metadata,
ssi.sample_rate,
ssi.num_channels,
ssi.width,
ssi.height,
ssi.frame_rate));
}
SrcInfoPyBind convert_pybind(SrcStreamInfo ssi) {
return SrcInfoPyBind(std::forward_as_tuple(
av_get_media_type_string(ssi.media_type),
ssi.codec_name,
ssi.codec_long_name,
ssi.fmt_name,
ssi.bit_rate,
ssi.num_frames,
ssi.bits_per_sample,
convert_map(ssi.metadata),
ssi.sample_rate, ssi.sample_rate,
ssi.num_channels, ssi.num_channels,
ssi.width, ssi.width,
...@@ -33,6 +62,10 @@ SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) { ...@@ -33,6 +62,10 @@ SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) {
return convert(StreamReader::get_src_stream_info(static_cast<int>(i))); return convert(StreamReader::get_src_stream_info(static_cast<int>(i)));
} }
SrcInfoPyBind StreamReaderBinding::get_src_stream_info_pybind(int64_t i) {
return convert_pybind(StreamReader::get_src_stream_info(static_cast<int>(i)));
}
OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) { OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) {
return convert(StreamReader::get_out_stream_info(static_cast<int>(i))); return convert(StreamReader::get_out_stream_info(static_cast<int>(i)));
} }
......
...@@ -5,6 +5,14 @@ ...@@ -5,6 +5,14 @@
namespace torchaudio { namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
// Because TorchScript requires c10::Dict type to pass dict,
// while PyBind11 requires std::map type to pass dict,
// we duplicate the return tuple.
// Even though all the PyBind-based implementations are placed
// in `pybind` directory, because std::map does not require pybind11
// header, we define both of them here, for the sake of
// better locality/maintainability.
using SrcInfo = std::tuple< using SrcInfo = std::tuple<
std::string, // media_type std::string, // media_type
std::string, // codec name std::string, // codec name
...@@ -13,6 +21,25 @@ using SrcInfo = std::tuple< ...@@ -13,6 +21,25 @@ using SrcInfo = std::tuple<
int64_t, // bit_rate int64_t, // bit_rate
int64_t, // num_frames int64_t, // num_frames
int64_t, // bits_per_sample int64_t, // bits_per_sample
c10::Dict<std::string, std::string>, // metadata
// Audio
double, // sample_rate
int64_t, // num_channels
// Video
int64_t, // width
int64_t, // height
double // frame_rate
>;
using SrcInfoPyBind = std::tuple<
std::string, // media_type
std::string, // codec name
std::string, // codec long name
std::string, // format name
int64_t, // bit_rate
int64_t, // num_frames
int64_t, // bits_per_sample
std::map<std::string, std::string>, // metadata
// Audio // Audio
double, // sample_rate double, // sample_rate
int64_t, // num_channels int64_t, // num_channels
...@@ -33,6 +60,7 @@ struct StreamReaderBinding : public StreamReader, ...@@ -33,6 +60,7 @@ struct StreamReaderBinding : public StreamReader,
public torch::CustomClassHolder { public torch::CustomClassHolder {
explicit StreamReaderBinding(AVFormatContextPtr&& p); explicit StreamReaderBinding(AVFormatContextPtr&& p);
SrcInfo get_src_stream_info(int64_t i); SrcInfo get_src_stream_info(int64_t i);
SrcInfoPyBind get_src_stream_info_pybind(int64_t i);
OutInfo get_out_stream_info(int64_t i); OutInfo get_out_stream_info(int64_t i);
int64_t process_packet( int64_t process_packet(
......
...@@ -14,6 +14,7 @@ struct SrcStreamInfo { ...@@ -14,6 +14,7 @@ struct SrcStreamInfo {
int64_t bit_rate = 0; int64_t bit_rate = 0;
int64_t num_frames = 0; int64_t num_frames = 0;
int bits_per_sample = 0; int bits_per_sample = 0;
c10::Dict<std::string, std::string> metadata{};
// Audio // Audio
double sample_rate = 0; double sample_rate = 0;
int num_channels = 0; int num_channels = 0;
......
...@@ -12,9 +12,9 @@ def _info_audio( ...@@ -12,9 +12,9 @@ def _info_audio(
i = s.find_best_audio_stream() i = s.find_best_audio_stream()
sinfo = s.get_src_stream_info(i) sinfo = s.get_src_stream_info(i)
return AudioMetaData( return AudioMetaData(
int(sinfo[7]), int(sinfo[8]),
sinfo[5], sinfo[5],
sinfo[8], sinfo[9],
sinfo[6], sinfo[6],
sinfo[1].upper(), sinfo[1].upper(),
) )
...@@ -73,7 +73,7 @@ def _load_audio( ...@@ -73,7 +73,7 @@ def _load_audio(
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
i = s.find_best_audio_stream() i = s.find_best_audio_stream()
sinfo = s.get_src_stream_info(i) sinfo = s.get_src_stream_info(i)
sample_rate = int(sinfo[7]) sample_rate = int(sinfo[8])
option: Dict[str, str] = {} option: Dict[str, str] = {}
s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option) s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option)
s.process_all_packets() s.process_all_packets()
......
...@@ -61,6 +61,9 @@ class StreamReaderSourceStream: ...@@ -61,6 +61,9 @@ class StreamReaderSourceStream:
"""This is the number of valid bits in each output sample. """This is the number of valid bits in each output sample.
For compressed format, it can be 0. For compressed format, it can be 0.
""" """
metadata: Dict[str, str]
"""Metadata attached to the source media.
Note that metadata is common across the source streams."""
@dataclass @dataclass
...@@ -108,13 +111,14 @@ _FORMAT = 3 ...@@ -108,13 +111,14 @@ _FORMAT = 3
_BIT_RATE = 4 _BIT_RATE = 4
_NUM_FRAMES = 5 _NUM_FRAMES = 5
_BPS = 6 _BPS = 6
_METADATA = 7
# - AUDIO # - AUDIO
_SAMPLE_RATE = 7 _SAMPLE_RATE = 8
_NUM_CHANNELS = 8 _NUM_CHANNELS = 9
# - VIDEO # - VIDEO
_WIDTH = 9 _WIDTH = 10
_HEIGHT = 10 _HEIGHT = 11
_FRAME_RATE = 11 _FRAME_RATE = 12
def _parse_si(i): def _parse_si(i):
...@@ -125,6 +129,7 @@ def _parse_si(i): ...@@ -125,6 +129,7 @@ def _parse_si(i):
bit_rate = i[_BIT_RATE] bit_rate = i[_BIT_RATE]
num_frames = i[_NUM_FRAMES] num_frames = i[_NUM_FRAMES]
bps = i[_BPS] bps = i[_BPS]
metadata = i[_METADATA]
if media_type == "audio": if media_type == "audio":
return StreamReaderSourceAudioStream( return StreamReaderSourceAudioStream(
media_type=media_type, media_type=media_type,
...@@ -134,6 +139,7 @@ def _parse_si(i): ...@@ -134,6 +139,7 @@ def _parse_si(i):
bit_rate=bit_rate, bit_rate=bit_rate,
num_frames=num_frames, num_frames=num_frames,
bits_per_sample=bps, bits_per_sample=bps,
metadata=metadata,
sample_rate=i[_SAMPLE_RATE], sample_rate=i[_SAMPLE_RATE],
num_channels=i[_NUM_CHANNELS], num_channels=i[_NUM_CHANNELS],
) )
...@@ -146,6 +152,7 @@ def _parse_si(i): ...@@ -146,6 +152,7 @@ def _parse_si(i):
bit_rate=bit_rate, bit_rate=bit_rate,
num_frames=num_frames, num_frames=num_frames,
bits_per_sample=bps, bits_per_sample=bps,
metadata=metadata,
width=i[_WIDTH], width=i[_WIDTH],
height=i[_HEIGHT], height=i[_HEIGHT],
frame_rate=i[_FRAME_RATE], frame_rate=i[_FRAME_RATE],
...@@ -158,6 +165,7 @@ def _parse_si(i): ...@@ -158,6 +165,7 @@ def _parse_si(i):
bit_rate=None, bit_rate=None,
num_frames=None, num_frames=None,
bits_per_sample=None, bits_per_sample=None,
metadata=metadata,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment