Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
ffeba11a
Commit
ffeba11a
authored
Sep 02, 2024
by
mayp777
Browse files
UPDATE
parent
29deb085
Changes
337
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
2741 additions
and
165 deletions
+2741
-165
torchaudio/csrc/ffmpeg/filter_graph.h
torchaudio/csrc/ffmpeg/filter_graph.h
+36
-9
torchaudio/csrc/ffmpeg/hw_context.cpp
torchaudio/csrc/ffmpeg/hw_context.cpp
+40
-0
torchaudio/csrc/ffmpeg/hw_context.h
torchaudio/csrc/ffmpeg/hw_context.h
+11
-0
torchaudio/csrc/ffmpeg/pybind/pybind.cpp
torchaudio/csrc/ffmpeg/pybind/pybind.cpp
+320
-14
torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.cpp
...audio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.cpp
+129
-0
torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h
torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h
+33
-0
torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.cpp
...dio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.cpp
+33
-0
torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h
...audio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h
+23
-0
torchaudio/csrc/ffmpeg/stream_reader/conversion.cpp
torchaudio/csrc/ffmpeg/stream_reader/conversion.cpp
+628
-0
torchaudio/csrc/ffmpeg/stream_reader/conversion.h
torchaudio/csrc/ffmpeg/stream_reader/conversion.h
+129
-0
torchaudio/csrc/ffmpeg/stream_reader/packet_buffer.cpp
torchaudio/csrc/ffmpeg/stream_reader/packet_buffer.cpp
+20
-0
torchaudio/csrc/ffmpeg/stream_reader/packet_buffer.h
torchaudio/csrc/ffmpeg/stream_reader/packet_buffer.h
+16
-0
torchaudio/csrc/ffmpeg/stream_reader/post_process.cpp
torchaudio/csrc/ffmpeg/stream_reader/post_process.cpp
+620
-0
torchaudio/csrc/ffmpeg/stream_reader/post_process.h
torchaudio/csrc/ffmpeg/stream_reader/post_process.h
+34
-0
torchaudio/csrc/ffmpeg/stream_reader/stream_processor.cpp
torchaudio/csrc/ffmpeg/stream_reader/stream_processor.cpp
+317
-44
torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h
torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h
+35
-20
torchaudio/csrc/ffmpeg/stream_reader/stream_reader.cpp
torchaudio/csrc/ffmpeg/stream_reader/stream_reader.cpp
+317
-78
No files found.
Too many changes to show.
To preserve performance only
337 of 337+
files are displayed.
Plain diff
Email patch
torchaudio/csrc/ffmpeg/filter_graph.h
View file @
ffeba11a
...
...
@@ -2,12 +2,27 @@
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace
torchaudio
{
namespace
ffmpeg
{
namespace
io
{
class
FilterGraph
{
AVMediaType
media_type
;
/// Used to report the output formats of filter graph.
struct
FilterGraphOutputInfo
{
AVMediaType
type
=
AVMEDIA_TYPE_UNKNOWN
;
int
format
=
-
1
;
AVRational
time_base
=
{
1
,
1
};
// Audio
int
sample_rate
=
-
1
;
int
num_channels
=
-
1
;
AVFilterGraphPtr
pFilterGraph
;
// Video
AVRational
frame_rate
=
{
0
,
1
};
int
height
=
-
1
;
int
width
=
-
1
;
};
class
FilterGraph
{
AVFilterGraphPtr
graph
;
// AVFilterContext is freed as a part of AVFilterGraph
// so we do not manage the resource.
...
...
@@ -15,7 +30,7 @@ class FilterGraph {
AVFilterContext
*
buffersink_ctx
=
nullptr
;
public:
explicit
FilterGraph
(
AVMediaType
media_type
);
explicit
FilterGraph
();
// Custom destructor to release AVFilterGraph*
~
FilterGraph
()
=
default
;
// Non-copyable
...
...
@@ -37,17 +52,29 @@ class FilterGraph {
void
add_video_src
(
AVPixelFormat
format
,
AVRational
time_base
,
AVRational
frame_rate
,
int
width
,
int
height
,
AVRational
sample_aspect_ratio
);
void
add_
src
(
const
std
::
string
&
arg
);
void
add_
audio_sink
(
);
void
add_sink
();
void
add_
video_
sink
();
void
add_process
(
const
std
::
string
&
filter_description
);
void
create_filter
();
void
create_filter
(
AVBufferRef
*
hw_frames_ctx
=
nullptr
);
private:
void
add_src
(
const
AVFilter
*
buffersrc
,
const
std
::
string
&
arg
);
void
add_sink
(
const
AVFilter
*
buffersrc
);
//////////////////////////////////////////////////////////////////////////////
// Query methods
//////////////////////////////////////////////////////////////////////////////
public:
[[
nodiscard
]]
FilterGraphOutputInfo
get_output_info
()
const
;
//////////////////////////////////////////////////////////////////////////////
// Streaming process
...
...
@@ -57,5 +84,5 @@ class FilterGraph {
int
get_frame
(
AVFrame
*
pOutputFrame
);
};
}
// namespace
ffmpeg
}
// namespace
io
}
// namespace torchaudio
torchaudio/csrc/ffmpeg/hw_context.cpp
0 → 100644
View file @
ffeba11a
#include <torchaudio/csrc/ffmpeg/hw_context.h>

namespace torchaudio::io {
namespace {

// Guards CUDA_CONTEXT_CACHE; contexts may be requested from multiple threads.
static std::mutex MUTEX;

// One cached CUDA device context per device index. Entries live until
// clear_cuda_context_cache() is called.
static std::map<int, AVBufferRefPtr> CUDA_CONTEXT_CACHE;

} // namespace

/// Returns a (cached) CUDA hardware device context for the given device.
/// An index of -1 is treated as the default device (0). The returned pointer
/// is owned by the internal cache; callers must not free it.
AVBufferRef* get_cuda_context(int index) {
  std::lock_guard<std::mutex> lock(MUTEX);
  if (index == -1) {
    index = 0;
  }
  // Single lookup via find() instead of the original count() + at() pair,
  // which walked the map twice.
  auto it = CUDA_CONTEXT_CACHE.find(index);
  if (it == CUDA_CONTEXT_CACHE.end()) {
    AVBufferRef* p = nullptr;
    int ret = av_hwdevice_ctx_create(
        &p, AV_HWDEVICE_TYPE_CUDA, std::to_string(index).c_str(), nullptr, 0);
    TORCH_CHECK(
        ret >= 0,
        "Failed to create CUDA device context on device ",
        index,
        "(",
        av_err2string(ret),
        ")");
    assert(p);
    it = CUDA_CONTEXT_CACHE.emplace(index, p).first;
  }
  return it->second;
}

/// Releases all cached CUDA device contexts.
void clear_cuda_context_cache() {
  std::lock_guard<std::mutex> lock(MUTEX);
  CUDA_CONTEXT_CACHE.clear();
}

} // namespace torchaudio::io
torchaudio/csrc/ffmpeg/hw_context.h
0 → 100644
View file @
ffeba11a
#pragma once

#include <torchaudio/csrc/ffmpeg/ffmpeg.h>

namespace torchaudio::io {

/// Returns a cached CUDA hardware device context for the given device index.
/// An index of -1 selects device 0. The returned pointer is owned by an
/// internal cache; callers must not free it.
AVBufferRef* get_cuda_context(int index);

/// Releases all cached CUDA device contexts.
void clear_cuda_context_cache();

} // namespace torchaudio::io
torchaudio/csrc/ffmpeg/pybind/pybind.cpp
View file @
ffeba11a
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_writer.h>
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
namespace
torchaudio
{
namespace
ffmpeg
{
namespace
torchaudio
::
io
{
namespace
{
PYBIND11_MODULE
(
_torchaudio_ffmpeg
,
m
)
{
py
::
class_
<
StreamWriterFileObj
,
c10
::
intrusive_ptr
<
StreamWriterFileObj
>>
(
m
,
"StreamWriterFileObj"
)
// Collects the runtime (major, minor, micro) version of each linked FFmpeg
// library, keyed by library name (e.g. "libavutil").
std::map<std::string, std::tuple<int64_t, int64_t, int64_t>> get_versions() {
  std::map<std::string, std::tuple<int64_t, int64_t, int64_t>> ret;

// Expands to a block that queries NAME_version() and records it as "libNAME".
#define add_version(NAME)            \
  {                                  \
    int ver = NAME##_version();      \
    ret.emplace(                     \
        "lib" #NAME,                 \
        std::make_tuple<>(           \
            AV_VERSION_MAJOR(ver),   \
            AV_VERSION_MINOR(ver),   \
            AV_VERSION_MICRO(ver))); \
  }

  add_version(avutil);
  add_version(avcodec);
  add_version(avformat);
  add_version(avfilter);
  add_version(avdevice);
  return ret;

#undef add_version
}
// Enumerates FFmpeg demuxers (input formats), mapping short name -> long name.
// When req_device is true only input *devices* are returned; when false only
// regular (non-device) demuxers are returned.
std::map<std::string, std::string> get_demuxers(bool req_device) {
  std::map<std::string, std::string> ret;
  const AVInputFormat* fmt = nullptr;
  void* i = nullptr;
  while ((fmt = av_demuxer_iterate(&i))) {
    assert(fmt);
    // A demuxer counts as a "device" when its AVClass category says so.
    bool is_device = [&]() {
      const AVClass* avclass = fmt->priv_class;
      return avclass && AV_IS_INPUT_DEVICE(avclass->category);
    }();
    if (req_device == is_device) {
      ret.emplace(fmt->name, fmt->long_name);
    }
  }
  return ret;
}
// Enumerates FFmpeg muxers (output formats), mapping short name -> long name.
// When req_device is true only output *devices* are returned; otherwise only
// regular (non-device) muxers are returned.
std::map<std::string, std::string> get_muxers(bool req_device) {
  std::map<std::string, std::string> muxers;
  const AVOutputFormat* fmt = nullptr;
  void* opaque = nullptr;
  while ((fmt = av_muxer_iterate(&opaque))) {
    assert(fmt);
    // A muxer counts as a "device" when its AVClass category says so.
    const AVClass* avclass = fmt->priv_class;
    const bool is_device = avclass && AV_IS_OUTPUT_DEVICE(avclass->category);
    if (is_device == req_device) {
      muxers.emplace(fmt->name, fmt->long_name);
    }
  }
  return muxers;
}
// Enumerates codecs of the given media type.
// req_encoder selects encoders (true) or decoders (false).
// Returns short name -> long name ("" when FFmpeg provides no long name).
std::map<std::string, std::string> get_codecs(
    AVMediaType type,
    bool req_encoder) {
  const AVCodec* c = nullptr;
  void* i = nullptr;
  std::map<std::string, std::string> ret;
  while ((c = av_codec_iterate(&i))) {
    assert(c);
    if ((req_encoder && av_codec_is_encoder(c)) ||
        (!req_encoder && av_codec_is_decoder(c))) {
      if (c->type == type && c->name) {
        ret.emplace(c->name, c->long_name ? c->long_name : "");
      }
    }
  }
  return ret;
}
// Lists the names of I/O protocols FFmpeg supports.
// `output` selects output protocols (true) or input protocols (false).
std::vector<std::string> get_protocols(bool output) {
  std::vector<std::string> protocols;
  void* iter_state = nullptr;
  for (const char* name = avio_enum_protocols(&iter_state, output); name;
       name = avio_enum_protocols(&iter_state, output)) {
    assert(name);
    protocols.emplace_back(name);
  }
  return protocols;
}
std
::
string
get_build_config
()
{
return
avcodec_configuration
();
}
//////////////////////////////////////////////////////////////////////////////
// StreamReader/Writer FileObj
//////////////////////////////////////////////////////////////////////////////
// Bundles a Python file-like object with the buffer size used when exchanging
// data with FFmpeg's custom I/O callbacks (read_func / write_func below cap
// each transfer at buffer_size bytes).
struct FileObj {
  py::object fileobj; // Python object implementing read/write (optionally seek)
  int buffer_size; // upper bound on bytes moved per callback invocation
};
namespace
{
// AVIO read callback: pulls up to buf_size bytes from the Python object's
// read() method into `buf`. Loops because read() may return fewer bytes than
// requested. Returns the number of bytes read, or AVERROR_EOF when the
// Python object is exhausted.
static int read_func(void* opaque, uint8_t* buf, int buf_size) {
  FileObj* fileobj = static_cast<FileObj*>(opaque);
  buf_size = FFMIN(buf_size, fileobj->buffer_size);

  int num_read = 0;
  while (num_read < buf_size) {
    int request = buf_size - num_read;
    auto chunk = static_cast<std::string>(
        static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
    auto chunk_len = chunk.length();
    if (chunk_len == 0) {
      // Empty read() means EOF per the Python file protocol.
      break;
    }
    // Fix: original message said "confirm to read protocol" — typo for
    // "conform".
    TORCH_CHECK(
        chunk_len <= request,
        "Requested up to ",
        request,
        " bytes but, received ",
        chunk_len,
        " bytes. The given object does not conform to read protocol of file object.");
    memcpy(buf, chunk.data(), chunk_len);
    buf += chunk_len;
    num_read += static_cast<int>(chunk_len);
  }
  return num_read == 0 ? AVERROR_EOF : num_read;
}
// AVIO write callback: forwards a buffer from FFmpeg to the Python object's
// write() method. Returns the number of bytes handed over.
static int write_func(void* opaque, uint8_t* buf, int buf_size) {
  FileObj* fileobj = static_cast<FileObj*>(opaque);
  buf_size = FFMIN(buf_size, fileobj->buffer_size);

  py::bytes b(reinterpret_cast<const char*>(buf), buf_size);
  // TODO: check the return value
  fileobj->fileobj.attr("write")(b);
  return buf_size;
}
// AVIO seek callback: delegates to the Python object's seek() method and
// returns the resulting position. AVSEEK_SIZE queries (stream size) are not
// supported and report EIO.
static int64_t seek_func(void* opaque, int64_t offset, int whence) {
  // We do not know the file size.
  if (whence == AVSEEK_SIZE) {
    return AVERROR(EIO);
  }
  FileObj* fileobj = static_cast<FileObj*>(opaque);
  return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}
}
// namespace
// StreamReader variant that reads from a Python file-like object through the
// custom AVIO callbacks above. The seek callback is installed only when the
// Python object actually exposes a seek() method; otherwise the stream is
// treated as unseekable.
struct StreamReaderFileObj : private FileObj, public StreamReaderCustomIO {
  StreamReaderFileObj(
      py::object fileobj,
      const c10::optional<std::string>& format,
      const c10::optional<std::map<std::string, std::string>>& option,
      int buffer_size)
      : FileObj{fileobj, buffer_size},
        StreamReaderCustomIO(
            this,
            format,
            buffer_size,
            read_func,
            py::hasattr(fileobj, "seek") ? &seek_func : nullptr,
            option) {}
};
// StreamWriter variant that writes to a Python file-like object through the
// custom AVIO callbacks above. The seek callback is installed only when the
// Python object actually exposes a seek() method.
struct StreamWriterFileObj : private FileObj, public StreamWriterCustomIO {
  StreamWriterFileObj(
      py::object fileobj,
      const c10::optional<std::string>& format,
      int buffer_size)
      : FileObj{fileobj, buffer_size},
        StreamWriterCustomIO(
            this,
            format,
            buffer_size,
            write_func,
            py::hasattr(fileobj, "seek") ? &seek_func : nullptr) {}
};
#ifndef TORCHAUDIO_FFMPEG_EXT_NAME
#error TORCHAUDIO_FFMPEG_EXT_NAME must be defined.
#endif
PYBIND11_MODULE
(
TORCHAUDIO_FFMPEG_EXT_NAME
,
m
)
{
m
.
def
(
"init"
,
[]()
{
avdevice_register_all
();
});
m
.
def
(
"get_log_level"
,
[]()
{
return
av_log_get_level
();
});
m
.
def
(
"set_log_level"
,
[](
int
level
)
{
av_log_set_level
(
level
);
});
m
.
def
(
"get_versions"
,
&
get_versions
);
m
.
def
(
"get_muxers"
,
[]()
{
return
get_muxers
(
false
);
});
m
.
def
(
"get_demuxers"
,
[]()
{
return
get_demuxers
(
false
);
});
m
.
def
(
"get_input_devices"
,
[]()
{
return
get_demuxers
(
true
);
});
m
.
def
(
"get_build_config"
,
&
get_build_config
);
m
.
def
(
"get_output_devices"
,
[]()
{
return
get_muxers
(
true
);
});
m
.
def
(
"get_audio_decoders"
,
[]()
{
return
get_codecs
(
AVMEDIA_TYPE_AUDIO
,
false
);
});
m
.
def
(
"get_audio_encoders"
,
[]()
{
return
get_codecs
(
AVMEDIA_TYPE_AUDIO
,
true
);
});
m
.
def
(
"get_video_decoders"
,
[]()
{
return
get_codecs
(
AVMEDIA_TYPE_VIDEO
,
false
);
});
m
.
def
(
"get_video_encoders"
,
[]()
{
return
get_codecs
(
AVMEDIA_TYPE_VIDEO
,
true
);
});
m
.
def
(
"get_input_protocols"
,
[]()
{
return
get_protocols
(
false
);
});
m
.
def
(
"get_output_protocols"
,
[]()
{
return
get_protocols
(
true
);
});
m
.
def
(
"clear_cuda_context_cache"
,
&
clear_cuda_context_cache
);
py
::
class_
<
Chunk
>
(
m
,
"Chunk"
,
py
::
module_local
())
.
def_readwrite
(
"frames"
,
&
Chunk
::
frames
)
.
def_readwrite
(
"pts"
,
&
Chunk
::
pts
);
py
::
class_
<
CodecConfig
>
(
m
,
"CodecConfig"
,
py
::
module_local
())
.
def
(
py
::
init
<
int
,
int
,
const
c10
::
optional
<
int
>&
,
int
,
int
>
());
py
::
class_
<
StreamWriter
>
(
m
,
"StreamWriter"
,
py
::
module_local
())
.
def
(
py
::
init
<
const
std
::
string
&
,
const
c10
::
optional
<
std
::
string
>&>
())
.
def
(
"set_metadata"
,
&
StreamWriter
::
set_metadata
)
.
def
(
"add_audio_stream"
,
&
StreamWriter
::
add_audio_stream
)
.
def
(
"add_video_stream"
,
&
StreamWriter
::
add_video_stream
)
.
def
(
"dump_format"
,
&
StreamWriter
::
dump_format
)
.
def
(
"open"
,
&
StreamWriter
::
open
)
.
def
(
"write_audio_chunk"
,
&
StreamWriter
::
write_audio_chunk
)
.
def
(
"write_video_chunk"
,
&
StreamWriter
::
write_video_chunk
)
.
def
(
"flush"
,
&
StreamWriter
::
flush
)
.
def
(
"close"
,
&
StreamWriter
::
close
);
py
::
class_
<
StreamWriterFileObj
>
(
m
,
"StreamWriterFileObj"
,
py
::
module_local
())
.
def
(
py
::
init
<
py
::
object
,
const
c10
::
optional
<
std
::
string
>&
,
int64_t
>
())
.
def
(
"set_metadata"
,
&
StreamWriterFileObj
::
set_metadata
)
.
def
(
"add_audio_stream"
,
&
StreamWriterFileObj
::
add_audio_stream
)
...
...
@@ -20,12 +243,92 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.
def
(
"write_video_chunk"
,
&
StreamWriterFileObj
::
write_video_chunk
)
.
def
(
"flush"
,
&
StreamWriterFileObj
::
flush
)
.
def
(
"close"
,
&
StreamWriterFileObj
::
close
);
py
::
class_
<
StreamReaderFileObj
,
c10
::
intrusive_ptr
<
StreamReaderFileObj
>>
(
m
,
"StreamReaderFileObj"
)
py
::
class_
<
OutputStreamInfo
>
(
m
,
"OutputStreamInfo"
,
py
::
module_local
())
.
def_readonly
(
"source_index"
,
&
OutputStreamInfo
::
source_index
)
.
def_readonly
(
"filter_description"
,
&
OutputStreamInfo
::
filter_description
)
.
def_property_readonly
(
"media_type"
,
[](
const
OutputStreamInfo
&
o
)
->
std
::
string
{
return
av_get_media_type_string
(
o
.
media_type
);
})
.
def_property_readonly
(
"format"
,
[](
const
OutputStreamInfo
&
o
)
->
std
::
string
{
switch
(
o
.
media_type
)
{
case
AVMEDIA_TYPE_AUDIO
:
return
av_get_sample_fmt_name
((
AVSampleFormat
)(
o
.
format
));
case
AVMEDIA_TYPE_VIDEO
:
return
av_get_pix_fmt_name
((
AVPixelFormat
)(
o
.
format
));
default:
TORCH_INTERNAL_ASSERT
(
false
,
"FilterGraph is returning unexpected media type: "
,
av_get_media_type_string
(
o
.
media_type
));
}
})
.
def_readonly
(
"sample_rate"
,
&
OutputStreamInfo
::
sample_rate
)
.
def_readonly
(
"num_channels"
,
&
OutputStreamInfo
::
num_channels
)
.
def_readonly
(
"width"
,
&
OutputStreamInfo
::
width
)
.
def_readonly
(
"height"
,
&
OutputStreamInfo
::
height
)
.
def_property_readonly
(
"frame_rate"
,
[](
const
OutputStreamInfo
&
o
)
->
double
{
if
(
o
.
frame_rate
.
den
==
0
)
{
TORCH_WARN
(
"Invalid frame rate is found: "
,
o
.
frame_rate
.
num
,
"/"
,
o
.
frame_rate
.
den
);
return
-
1
;
}
return
static_cast
<
double
>
(
o
.
frame_rate
.
num
)
/
o
.
frame_rate
.
den
;
});
py
::
class_
<
SrcStreamInfo
>
(
m
,
"SourceStreamInfo"
,
py
::
module_local
())
.
def_property_readonly
(
"media_type"
,
[](
const
SrcStreamInfo
&
s
)
{
return
av_get_media_type_string
(
s
.
media_type
);
})
.
def_readonly
(
"codec_name"
,
&
SrcStreamInfo
::
codec_name
)
.
def_readonly
(
"codec_long_name"
,
&
SrcStreamInfo
::
codec_long_name
)
.
def_readonly
(
"format"
,
&
SrcStreamInfo
::
fmt_name
)
.
def_readonly
(
"bit_rate"
,
&
SrcStreamInfo
::
bit_rate
)
.
def_readonly
(
"num_frames"
,
&
SrcStreamInfo
::
num_frames
)
.
def_readonly
(
"bits_per_sample"
,
&
SrcStreamInfo
::
bits_per_sample
)
.
def_readonly
(
"metadata"
,
&
SrcStreamInfo
::
metadata
)
.
def_readonly
(
"sample_rate"
,
&
SrcStreamInfo
::
sample_rate
)
.
def_readonly
(
"num_channels"
,
&
SrcStreamInfo
::
num_channels
)
.
def_readonly
(
"width"
,
&
SrcStreamInfo
::
width
)
.
def_readonly
(
"height"
,
&
SrcStreamInfo
::
height
)
.
def_readonly
(
"frame_rate"
,
&
SrcStreamInfo
::
frame_rate
);
py
::
class_
<
StreamReader
>
(
m
,
"StreamReader"
,
py
::
module_local
())
.
def
(
py
::
init
<
const
std
::
string
&
,
const
c10
::
optional
<
std
::
string
>&
,
const
c10
::
optional
<
OptionDict
>&>
())
.
def
(
"num_src_streams"
,
&
StreamReader
::
num_src_streams
)
.
def
(
"num_out_streams"
,
&
StreamReader
::
num_out_streams
)
.
def
(
"find_best_audio_stream"
,
&
StreamReader
::
find_best_audio_stream
)
.
def
(
"find_best_video_stream"
,
&
StreamReader
::
find_best_video_stream
)
.
def
(
"get_metadata"
,
&
StreamReader
::
get_metadata
)
.
def
(
"get_src_stream_info"
,
&
StreamReader
::
get_src_stream_info
)
.
def
(
"get_out_stream_info"
,
&
StreamReader
::
get_out_stream_info
)
.
def
(
"seek"
,
&
StreamReader
::
seek
)
.
def
(
"add_audio_stream"
,
&
StreamReader
::
add_audio_stream
)
.
def
(
"add_video_stream"
,
&
StreamReader
::
add_video_stream
)
.
def
(
"remove_stream"
,
&
StreamReader
::
remove_stream
)
.
def
(
"process_packet"
,
py
::
overload_cast
<
const
c10
::
optional
<
double
>&
,
const
double
>
(
&
StreamReader
::
process_packet
))
.
def
(
"process_all_packets"
,
&
StreamReader
::
process_all_packets
)
.
def
(
"fill_buffer"
,
&
StreamReader
::
fill_buffer
)
.
def
(
"is_buffer_ready"
,
&
StreamReader
::
is_buffer_ready
)
.
def
(
"pop_chunks"
,
&
StreamReader
::
pop_chunks
);
py
::
class_
<
StreamReaderFileObj
>
(
m
,
"StreamReaderFileObj"
,
py
::
module_local
())
.
def
(
py
::
init
<
py
::
object
,
const
c10
::
optional
<
std
::
string
>&
,
const
c10
::
optional
<
Option
Map
>&
,
const
c10
::
optional
<
Option
Dict
>&
,
int64_t
>
())
.
def
(
"num_src_streams"
,
&
StreamReaderFileObj
::
num_src_streams
)
.
def
(
"num_out_streams"
,
&
StreamReaderFileObj
::
num_out_streams
)
...
...
@@ -42,12 +345,15 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.
def
(
"add_audio_stream"
,
&
StreamReaderFileObj
::
add_audio_stream
)
.
def
(
"add_video_stream"
,
&
StreamReaderFileObj
::
add_video_stream
)
.
def
(
"remove_stream"
,
&
StreamReaderFileObj
::
remove_stream
)
.
def
(
"process_packet"
,
&
StreamReaderFileObj
::
process_packet
)
.
def
(
"process_packet"
,
py
::
overload_cast
<
const
c10
::
optional
<
double
>&
,
const
double
>
(
&
StreamReader
::
process_packet
))
.
def
(
"process_all_packets"
,
&
StreamReaderFileObj
::
process_all_packets
)
.
def
(
"fill_buffer"
,
&
StreamReaderFileObj
::
fill_buffer
)
.
def
(
"is_buffer_ready"
,
&
StreamReaderFileObj
::
is_buffer_ready
)
.
def
(
"pop_chunks"
,
&
StreamReaderFileObj
::
pop_chunks
);
}
}
// namespace
}
// namespace ffmpeg
}
// namespace torchaudio
}
// namespace torchaudio::io
torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.cpp
0 → 100644
View file @
ffeba11a
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h>
namespace
torchaudio
::
io
::
detail
{
// Buffers decoded frames and serves them in fixed-size chunks.
// time_base: stream time base used to convert PTS values to seconds.
// frames_per_chunk_: number of frames returned per pop_chunk() call.
// num_chunks_: maximum number of chunks retained; when it is <= 0,
// push_frame never trims old chunks.
ChunkedBuffer::ChunkedBuffer(
    AVRational time_base,
    int frames_per_chunk_,
    int num_chunks_)
    : time_base(time_base),
      frames_per_chunk(frames_per_chunk_),
      num_chunks(num_chunks_){};
// Ready when at least one full chunk's worth of frames has been buffered.
bool ChunkedBuffer::is_ready() const {
  return num_buffered_frames >= frames_per_chunk;
}
// Appends the frames in `frame` (dim 0 = frames) to the buffer, slicing them
// into frames_per_chunk-sized chunks and recording the PTS of each chunk's
// first frame. When num_chunks > 0, the oldest chunk is dropped once the
// retention limit is exceeded.
void ChunkedBuffer::push_frame(torch::Tensor frame, int64_t pts_) {
  using namespace torch::indexing;
  // Note:
  // Audio tensors contain multiple frames while video tensors contain only
  // one frame. Video tensors can be regarded as special degenerated case of
  // audio, so in the following, we only consider audio processing.
  //
  // The incoming Tensor might contain more frames than the value of
  // `frames_per_chunk`.
  // If we push the input tensor to dequeue as-is, then, at the trimming stage,
  // the entire frames would be trimmed, this is not ideal. We want to keep
  // at most `frames_per_chunk * num_chunks` frames.
  // So we slice push the incoming Tensor.
  //
  // 1. Check if the last chunk is fully filled. If not, fill it.
  //
  //  <----- frames per chunk ----->^
  //  x x x x x x x x x x x x x x x |
  //  x x x x x x x + + + + + + - - | num_chunks
  //  - - - - - - - - - - - - - - - |
  //  <-- filled --><--- remain --->v
  //               <- append->
  //
  if (int64_t filled = num_buffered_frames % frames_per_chunk) {
    TORCH_INTERNAL_ASSERT(
        chunks.size() > 0,
        "There is supposed to be left over frames, but the buffer dequeue is empty.");
    int64_t num_frames = frame.size(0);
    int64_t remain = frames_per_chunk - filled;
    int64_t append = remain < num_frames ? remain : num_frames;
    torch::Tensor prev = chunks.back();
    // prev[filled:filled+append] = frame[:append]
    prev.index_put_(
        {Slice(filled, filled + append)}, frame.index({Slice(None, append)}));
    num_buffered_frames += append;
    // frame = frame[append:]
    frame = frame.index({Slice(append)});
    pts_ += append;
  }
  // 2. Return if the number of input frames are smaller than the empty buffer.
  // i.e. all the frames are pushed.
  if (frame.numel() == 0) {
    return;
  }
  // 3. Now the existing buffer chunks are fully filled, start adding new chunks
  //
  //  <----- frames per chunk ----->^
  //  x x x x x x x x x x x x x x x |
  //  x x x x x x x x x x x x x x x | num_chunks
  //  + + + + + + + + + + + + + + + |
  //  <---------- append ---------->v
  //
  int64_t num_frames = frame.size(0);
  int64_t num_splits =
      num_frames / frames_per_chunk + (num_frames % frames_per_chunk ? 1 : 0);
  for (int64_t i = 0; i < num_splits; ++i) {
    int64_t start = i * frames_per_chunk;
    // chunk = frame[i*frames_per_chunk:(i+1) * frames_per_chunk]
    auto chunk = frame.index({Slice(start, start + frames_per_chunk)});
    int64_t pts_val = pts_ + start;
    int64_t chunk_size = chunk.size(0);
    TORCH_INTERNAL_ASSERT(
        chunk_size <= frames_per_chunk,
        "Chunk size is larger than frames per chunk.");
    if (chunk_size < frames_per_chunk) {
      // Partial final chunk: pad up to frames_per_chunk with an
      // uninitialized tail. pop_chunk slices the valid region back out
      // using num_buffered_frames.
      auto shape = chunk.sizes().vec();
      shape[0] = frames_per_chunk;
      auto temp = torch::empty(shape, frame.options());
      temp.index_put_({Slice(None, chunk_size)}, chunk);
      chunk = temp;
    }
    chunks.push_back(chunk);
    pts.push_back(pts_val);
    num_buffered_frames += chunk_size;
    // Trim if num_chunks > 0
    if (num_chunks > 0 && chunks.size() > num_chunks) {
      TORCH_WARN_ONCE(
          "The number of buffered frames exceeded the buffer size. "
          "Dropping the old frames. "
          "To avoid this, you can set a higher buffer_chunk_size value.");
      chunks.pop_front();
      // Fix: `pts` must stay in lockstep with `chunks` (pop_chunk pops the
      // front of both). The original dropped only the chunk, leaving a stale
      // timestamp paired with the next chunk.
      pts.pop_front();
      num_buffered_frames -= frames_per_chunk;
    }
  }
}
// Returns the oldest chunk together with its presentation timestamp in
// seconds, or an empty optional when nothing is buffered. The final chunk may
// hold fewer than frames_per_chunk frames; its unused padded tail is sliced
// off before returning.
c10::optional<Chunk> ChunkedBuffer::pop_chunk() {
  using namespace torch::indexing;
  if (!num_buffered_frames) {
    return {};
  }
  torch::Tensor chunk = chunks.front();
  // Convert the stored PTS from time_base units to seconds.
  double pts_val = double(pts.front()) * time_base.num / time_base.den;
  chunks.pop_front();
  pts.pop_front();
  if (num_buffered_frames < frames_per_chunk) {
    // Partially-filled last chunk: keep only the valid frames.
    chunk = chunk.index({Slice(None, num_buffered_frames)});
  }
  num_buffered_frames -= chunk.size(0);
  return {Chunk{chunk, pts_val}};
}
void
ChunkedBuffer
::
flush
()
{
num_buffered_frames
=
0
;
chunks
.
clear
();
}
}
// namespace torchaudio::io::detail
torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h
0 → 100644
View file @
ffeba11a
#pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
namespace
torchaudio
::
io
::
detail
{
// Buffer that accumulates decoded frames and exposes them as fixed-size
// chunks of `frames_per_chunk` frames each.
class ChunkedBuffer {
  // Each AVFrame is converted to a Tensor and stored here.
  std::deque<torch::Tensor> chunks;
  // Time stamps corresponding to the first frame of each chunk,
  // in `time_base` units (converted to seconds by pop_chunk).
  std::deque<int64_t> pts;
  // Time base used to convert PTS values to seconds.
  AVRational time_base;
  // The number of frames to return as a chunk
  // If <0, then user wants to receive all the frames
  const int64_t frames_per_chunk;
  // The number of chunks to retain; push_frame drops the oldest chunk
  // once this limit is exceeded (only when > 0).
  const int64_t num_chunks;
  // The number of currently stored frames.
  // For video, one Tensor corresponds to one frame, but for audio,
  // one Tensor contains multiple samples, so we track here.
  int64_t num_buffered_frames = 0;

 public:
  ChunkedBuffer(AVRational time_base, int frames_per_chunk, int num_chunks);

  // True when at least one full chunk is available.
  bool is_ready() const;
  // Discards all buffered data.
  void flush();
  // Pops the oldest chunk (frames + PTS in seconds), or empty when none.
  c10::optional<Chunk> pop_chunk();
  // Appends frames (dim 0) with the PTS of the first appended frame.
  void push_frame(torch::Tensor frame, int64_t pts_);
};
}
// namespace torchaudio::io::detail
torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.cpp
0 → 100644
View file @
ffeba11a
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h>
namespace
torchaudio
::
io
::
detail
{
// Buffer that accumulates frames and hands back everything collected so far
// as one concatenated chunk.
UnchunkedBuffer::UnchunkedBuffer(AVRational time_base) : time_base(time_base) {}

// Ready as soon as any frame has been buffered.
bool UnchunkedBuffer::is_ready() const {
  return !chunks.empty();
}

// Stores a frame; the PTS of the first frame after a (re)start is remembered
// (converted from time_base units to seconds) and reported by pop_chunk.
void UnchunkedBuffer::push_frame(torch::Tensor frame, int64_t pts_) {
  if (chunks.empty()) {
    pts = static_cast<double>(pts_) * time_base.num / time_base.den;
  }
  chunks.push_back(frame);
}

// Concatenates all buffered frames along dim 0 and returns them with the
// recorded PTS; empty optional when nothing is buffered.
c10::optional<Chunk> UnchunkedBuffer::pop_chunk() {
  if (chunks.empty()) {
    return {};
  }
  std::vector<torch::Tensor> parts{chunks.begin(), chunks.end()};
  auto frames = torch::cat(parts, 0);
  chunks.clear();
  return {Chunk{frames, pts}};
}

// Discards all buffered frames.
void UnchunkedBuffer::flush() {
  chunks.clear();
}
}
// namespace torchaudio::io::detail
torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h
0 → 100644
View file @
ffeba11a
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
#include <deque>
namespace
torchaudio
::
io
::
detail
{
// Buffer that accumulates frames and returns everything collected so far as a
// single concatenated chunk.
class UnchunkedBuffer {
  // Each AVFrame is converted to a Tensor and stored here.
  std::deque<torch::Tensor> chunks;
  // PTS (in seconds) of the first buffered frame; -1 until a frame arrives.
  double pts = -1.;
  // Time base used to convert incoming PTS values to seconds.
  AVRational time_base;

 public:
  UnchunkedBuffer(AVRational time_base);
  // True when at least one frame is buffered.
  bool is_ready() const;
  // Stores a frame; records the PTS of the first frame after a (re)start.
  void push_frame(torch::Tensor frame, int64_t pts_);
  // Concatenates and returns all buffered frames; empty when none.
  c10::optional<Chunk> pop_chunk();
  // Discards all buffered frames.
  void flush();
};
}
// namespace torchaudio::io::detail
torchaudio/csrc/ffmpeg/stream_reader/conversion.cpp
0 → 100644
View file @
ffeba11a
#include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/conversion.h>
#ifdef USE_CUDA
#include <c10/cuda/CUDAStream.h>
#endif
namespace
torchaudio
::
io
{
////////////////////////////////////////////////////////////////////////////////
// Audio
////////////////////////////////////////////////////////////////////////////////

// Converts decoded audio AVFrames into Tensors of the instantiated dtype.
// `is_planar` selects between planar (one data plane per channel) and
// interleaved sample layouts.
template <c10::ScalarType dtype, bool is_planar>
AudioConverter<dtype, is_planar>::AudioConverter(int c) : num_channels(c) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(num_channels > 0);
}

// Allocates a fresh Tensor, fills it from `src`, and returns it shaped
// [num_frames, num_channels] (planar data is copied as [C, T] then permuted).
template <c10::ScalarType dtype, bool is_planar>
torch::Tensor AudioConverter<dtype, is_planar>::convert(const AVFrame* src) {
  if constexpr (is_planar) {
    torch::Tensor dst = torch::empty({num_channels, src->nb_samples}, dtype);
    convert(src, dst);
    return dst.permute({1, 0});
  } else {
    torch::Tensor dst = torch::empty({src->nb_samples, num_channels}, dtype);
    convert(src, dst);
    return dst;
  }
}

// Converts AVFrame* into pre-allocated Tensor.
// The shape must be [C, T] if is_planar otherwise [T, C]
template <c10::ScalarType dtype, bool is_planar>
void AudioConverter<dtype, is_planar>::convert(
    const AVFrame* src,
    torch::Tensor& dst) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(num_channels == src->channels);
  // Bytes per sample for the instantiated dtype, resolved at compile time.
  constexpr int bps = []() {
    switch (dtype) {
      case torch::kUInt8:
        return 1;
      case torch::kInt16:
        return 2;
      case torch::kInt32:
      case torch::kFloat32:
        return 4;
      case torch::kInt64:
      case torch::kFloat64:
        return 8;
    }
  }();

  // Note
  // FFMpeg's `nb_samples` represents the number of samples per channel,
  // whereas, in torchaudio, `num_samples` is used to represent the number of
  // samples across channels. torchaudio uses `num_frames` for per-channel
  // samples.
  if constexpr (is_planar) {
    // One memcpy per channel plane.
    int plane_size = bps * src->nb_samples;
    uint8_t* p_dst = static_cast<uint8_t*>(dst.data_ptr());
    for (int i = 0; i < num_channels; ++i) {
      memcpy(p_dst, src->extended_data[i], plane_size);
      p_dst += plane_size;
    }
  } else {
    // Interleaved data lives in a single plane; one memcpy suffices.
    int plane_size = bps * src->nb_samples * num_channels;
    memcpy(dst.data_ptr(), src->extended_data[0], plane_size);
  }
}

// Explicit instantiation
template class AudioConverter<torch::kUInt8, false>;
template class AudioConverter<torch::kUInt8, true>;
template class AudioConverter<torch::kInt16, false>;
template class AudioConverter<torch::kInt16, true>;
template class AudioConverter<torch::kInt32, false>;
template class AudioConverter<torch::kInt32, true>;
template class AudioConverter<torch::kInt64, false>;
template class AudioConverter<torch::kInt64, true>;
template class AudioConverter<torch::kFloat32, false>;
template class AudioConverter<torch::kFloat32, true>;
template class AudioConverter<torch::kFloat64, false>;
template class AudioConverter<torch::kFloat64, true>;
////////////////////////////////////////////////////////////////////////////////
// Image
////////////////////////////////////////////////////////////////////////////////
namespace {

// Allocates an uninitialized strided image buffer (CPU) of the given shape.
torch::Tensor get_image_buffer(
    at::IntArrayRef shape,
    const torch::Dtype dtype = torch::kUInt8) {
  return torch::empty(
      shape, torch::TensorOptions().dtype(dtype).layout(torch::kStrided));
}

// Allocates an uninitialized strided image buffer on the given device.
torch::Tensor get_image_buffer(
    at::IntArrayRef shape,
    torch::Device device,
    const torch::Dtype dtype = torch::kUInt8) {
  return torch::empty(
      shape,
      torch::TensorOptions()
          .dtype(dtype)
          .layout(torch::kStrided)
          .device(device));
}

} // namespace
// Common state for image converters: the expected frame geometry
// (height, width, channel count), validated in debug builds.
ImageConverterBase::ImageConverterBase(int h, int w, int c)
    : height(h), width(w), num_channels(c) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(height > 0);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(width > 0);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(num_channels > 0);
}
////////////////////////////////////////////////////////////////////////////////
// Interlaced Image
////////////////////////////////////////////////////////////////////////////////

// Copies a packed (interleaved-channel) 8-bit image into a pre-allocated
// NHWC tensor ([1, H, W, C]), row by row, dropping FFmpeg's per-row padding
// (linesize may exceed width * num_channels).
void InterlacedImageConverter::convert(const AVFrame* src, torch::Tensor& dst) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == height);
  int stride = width * num_channels;
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) * dst.size(3) == stride);
  auto p_dst = dst.data_ptr<uint8_t>();
  uint8_t* p_src = src->data[0];
  for (int i = 0; i < height; ++i) {
    // Copy only the pixel data of the row; skip the linesize padding.
    memcpy(p_dst, p_src, stride);
    p_src += src->linesize[0];
    p_dst += stride;
  }
}

// Allocates an NHWC buffer, fills it from `src`, and returns it permuted to
// NCHW ([1, C, H, W]).
torch::Tensor InterlacedImageConverter::convert(const AVFrame* src) {
  torch::Tensor buffer = get_image_buffer({1, height, width, num_channels});
  convert(src, buffer);
  return buffer.permute({0, 3, 1, 2});
}
////////////////////////////////////////////////////////////////////////////////
// Interlaced 16 Bit Image
////////////////////////////////////////////////////////////////////////////////
// Copies an interleaved 16-bit frame into a pre-allocated NHWC int16 tensor,
// then shifts the raw unsigned sample range into int16.
void Interlaced16BitImageConverter::convert(
    const AVFrame* src,
    torch::Tensor& dst) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == height);
  const int row_elems = width * num_channels;
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) * dst.size(3) == row_elems);
  auto* out = dst.data_ptr<int16_t>();
  const uint8_t* in = src->data[0];
  for (int row = 0; row < height; ++row) {
    // Each row carries row_elems 16-bit samples; linesize may add padding.
    memcpy(out, in, sizeof(int16_t) * row_elems);
    in += src->linesize[0];
    out += row_elems;
  }
  // correct for int16
  dst += 32768;
}
// Allocates an int16 buffer, fills it in NHWC order, and returns an NCHW view.
torch::Tensor Interlaced16BitImageConverter::convert(const AVFrame* src) {
  auto buf =
      get_image_buffer({1, height, width, num_channels}, torch::kInt16);
  convert(src, buf);
  // NHWC -> NCHW (a view; no copy).
  return buf.permute({0, 3, 1, 2});
}
////////////////////////////////////////////////////////////////////////////////
// Planar Image
////////////////////////////////////////////////////////////////////////////////
// Copies a planar 8-bit frame (one FFmpeg data plane per channel) into a
// pre-allocated NCHW uint8 tensor, stripping per-row padding.
void PlanarImageConverter::convert(const AVFrame* src, torch::Tensor& dst) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == num_channels);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
  for (int c = 0; c < num_channels; ++c) {
    torch::Tensor plane = dst.index({0, c});
    auto* out = plane.data_ptr<uint8_t>();
    const uint8_t* in = src->data[c];
    const int src_stride = src->linesize[c];
    for (int row = 0; row < height; ++row) {
      memcpy(out, in, width);
      in += src_stride;
      out += width;
    }
  }
}
// Allocates an NCHW uint8 buffer and fills it plane by plane.
torch::Tensor PlanarImageConverter::convert(const AVFrame* src) {
  auto out = get_image_buffer({1, num_channels, height, width});
  convert(src, out);
  return out;
}
////////////////////////////////////////////////////////////////////////////////
// YUV420P
////////////////////////////////////////////////////////////////////////////////
// YUV420P frames always yield 3 output channels (Y, U, V, each upsampled to
// full resolution); warn once that the output is effectively YUV444P.
YUV420PConverter::YUV420PConverter(int h, int w) : ImageConverterBase(h, w, 3) {
  TORCH_WARN_ONCE(
      "The output format YUV420P is selected. "
      "This will be implicitly converted to YUV444P, "
      "in which all the color components Y, U, V have the same dimension.");
}
// Converts a YUV420P frame into a pre-allocated uint8 NCHW [1, 3, H, W]
// tensor, upsampling the half-resolution U and V planes 2x in each
// direction (nearest-neighbor) so all channels share the same size.
void YUV420PConverter::convert(const AVFrame* src, torch::Tensor& dst) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      (AVPixelFormat)(src->format) == AV_PIX_FMT_YUV420P);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
  // Write Y plane directly (row by row; linesize[0] may include padding).
  {
    uint8_t* p_dst = dst.data_ptr<uint8_t>();
    uint8_t* p_src = src->data[0];
    for (int h = 0; h < height; ++h) {
      memcpy(p_dst, p_src, width);
      p_dst += width;
      p_src += src->linesize[0];
    }
  }
  // Chroma (U and V planes) are subsampled by 2 in both vertical and
  // horizontal directions.
  // https://en.wikipedia.org/wiki/Chroma_subsampling
  // Since we are returning data in Tensor, which has the same size for all
  // color planes, we need to upsample the UV planes. PyTorch has interpolate
  // function but it does not work for int16 type. So we manually copy them.
  //
  // block1 block2 block3 block4
  // ab -> aabb = a b * a b * *
  // cd aabb a b a b
  // ccdd c d c d
  // ccdd c d c d
  //
  // Four strided views of dst, each covering every other row/column; writing
  // the same chroma data into all four performs the nearest upsample.
  auto block00 = dst.slice(2, 0, {}, 2).slice(3, 0, {}, 2);
  auto block01 = dst.slice(2, 0, {}, 2).slice(3, 1, {}, 2);
  auto block10 = dst.slice(2, 1, {}, 2).slice(3, 0, {}, 2);
  auto block11 = dst.slice(2, 1, {}, 2).slice(3, 1, {}, 2);
  for (int i = 1; i < 3; ++i) {
    // borrow data — no-op deleter: the AVFrame keeps ownership of the plane.
    auto tmp = torch::from_blob(
        src->data[i],
        {height / 2, width / 2},
        {src->linesize[i], 1},
        [](void*) {},
        torch::TensorOptions().dtype(torch::kUInt8).layout(torch::kStrided));
    // Copy to each block
    block00.slice(1, i, i + 1).copy_(tmp);
    block01.slice(1, i, i + 1).copy_(tmp);
    block10.slice(1, i, i + 1).copy_(tmp);
    block11.slice(1, i, i + 1).copy_(tmp);
  }
}
// Allocates a uint8 NCHW buffer and performs the YUV420P -> YUV444P copy.
torch::Tensor YUV420PConverter::convert(const AVFrame* src) {
  auto out = get_image_buffer({1, num_channels, height, width});
  convert(src, out);
  return out;
}
////////////////////////////////////////////////////////////////////////////////
// YUV420P10LE
////////////////////////////////////////////////////////////////////////////////
// YUV420P10LE frames always yield 3 output channels (16-bit Y, U, V, each
// upsampled to full resolution); warn once about the implicit conversion.
YUV420P10LEConverter::YUV420P10LEConverter(int h, int w)
    : ImageConverterBase(h, w, 3) {
  // Fix: the warning previously misnamed the selected format as "YUV420PLE".
  TORCH_WARN_ONCE(
      "The output format YUV420P10LE is selected. "
      "This will be implicitly converted to YUV444P (16-bit), "
      "in which all the color components Y, U, V have the same dimension.");
}
// Converts a 10-bit little-endian YUV420P frame into a pre-allocated int16
// NCHW [1, 3, H, W] tensor, upsampling the half-resolution U and V planes
// 2x in each direction (nearest-neighbor).
void YUV420P10LEConverter::convert(const AVFrame* src, torch::Tensor& dst) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      (AVPixelFormat)(src->format) == AV_PIX_FMT_YUV420P10LE);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.dtype() == torch::kInt16);
  // Write Y plane directly. Samples are 2 bytes wide, hence width * 2 bytes
  // per row; linesize[0] may include padding.
  {
    int16_t* p_dst = dst.data_ptr<int16_t>();
    uint8_t* p_src = src->data[0];
    for (int h = 0; h < height; ++h) {
      memcpy(p_dst, p_src, (size_t)width * 2);
      p_dst += width;
      p_src += src->linesize[0];
    }
  }
  // Chroma (U and V planes) are subsampled by 2 in both vertical and
  // horizontal directions.
  // https://en.wikipedia.org/wiki/Chroma_subsampling
  // Since we are returning data in Tensor, which has the same size for all
  // color planes, we need to upsample the UV planes. PyTorch has interpolate
  // function but it does not work for int16 type. So we manually copy them.
  //
  // block1 block2 block3 block4
  // ab -> aabb = a b * a b * *
  // cd aabb a b a b
  // ccdd c d c d
  // ccdd c d c d
  //
  // Four strided views of dst, each covering every other row/column.
  auto block00 = dst.slice(2, 0, {}, 2).slice(3, 0, {}, 2);
  auto block01 = dst.slice(2, 0, {}, 2).slice(3, 1, {}, 2);
  auto block10 = dst.slice(2, 1, {}, 2).slice(3, 0, {}, 2);
  auto block11 = dst.slice(2, 1, {}, 2).slice(3, 1, {}, 2);
  for (int i = 1; i < 3; ++i) {
    // borrow data — no-op deleter: the AVFrame keeps ownership of the plane.
    // linesize is in bytes; divide by 2 to get the stride in int16 elements.
    auto tmp = torch::from_blob(
        src->data[i],
        {height / 2, width / 2},
        {src->linesize[i] / 2, 1},
        [](void*) {},
        torch::TensorOptions().dtype(torch::kInt16).layout(torch::kStrided));
    // Copy to each block
    block00.slice(1, i, i + 1).copy_(tmp);
    block01.slice(1, i, i + 1).copy_(tmp);
    block10.slice(1, i, i + 1).copy_(tmp);
    block11.slice(1, i, i + 1).copy_(tmp);
  }
}
// Allocates an int16 NCHW buffer and performs the 10-bit YUV420P conversion.
torch::Tensor YUV420P10LEConverter::convert(const AVFrame* src) {
  auto out =
      get_image_buffer({1, num_channels, height, width}, torch::kInt16);
  convert(src, out);
  return out;
}
////////////////////////////////////////////////////////////////////////////////
// NV12
////////////////////////////////////////////////////////////////////////////////
// NV12 frames always yield 3 output channels (Y plus deinterleaved,
// upsampled U and V); warn once that the output is effectively YUV444P.
NV12Converter::NV12Converter(int h, int w) : ImageConverterBase(h, w, 3) {
  TORCH_WARN_ONCE(
      "The output format NV12 is selected. "
      "This will be implicitly converted to YUV444P, "
      "in which all the color components Y, U, V have the same dimension.");
}
// Converts an NV12 frame (full-res Y plane + half-res interleaved UV plane)
// into a pre-allocated uint8 NCHW [1, 3, H, W] tensor, deinterleaving UV and
// upsampling 2x in each direction (nearest-neighbor).
void NV12Converter::convert(const AVFrame* src, torch::Tensor& dst) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      (AVPixelFormat)(src->format) == AV_PIX_FMT_NV12);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
  // Write Y plane directly (row by row; linesize[0] may include padding).
  {
    uint8_t* p_dst = dst.data_ptr<uint8_t>();
    uint8_t* p_src = src->data[0];
    for (int h = 0; h < height; ++h) {
      memcpy(p_dst, p_src, width);
      p_dst += width;
      p_src += src->linesize[0];
    }
  }
  // Write intermediate UV plane
  {
    // Borrow the interleaved UV plane (U0 V0 U1 V1 ...); no-op deleter since
    // the AVFrame keeps ownership of the memory.
    auto tmp = torch::from_blob(
        src->data[1],
        {height / 2, width},
        {src->linesize[1], 1},
        [](void*) {},
        torch::TensorOptions().dtype(torch::kUInt8).layout(torch::kStrided));
    // Deinterleave: [H/2, W] -> [1, H/2, W/2, 2] -> [1, 2, H/2, W/2].
    tmp = tmp.view({1, height / 2, width / 2, 2}).permute({0, 3, 1, 2});
    // Channels 1..2 of dst hold U and V; write each half-res chroma sample
    // into all four positions of its 2x2 output block (nearest upsample).
    auto dst_uv = dst.slice(1, 1, 3);
    dst_uv.slice(2, 0, {}, 2).slice(3, 0, {}, 2).copy_(tmp);
    dst_uv.slice(2, 0, {}, 2).slice(3, 1, {}, 2).copy_(tmp);
    dst_uv.slice(2, 1, {}, 2).slice(3, 0, {}, 2).copy_(tmp);
    dst_uv.slice(2, 1, {}, 2).slice(3, 1, {}, 2).copy_(tmp);
  }
}
// Allocates a uint8 NCHW buffer and performs the NV12 -> YUV444P copy.
torch::Tensor NV12Converter::convert(const AVFrame* src) {
  auto out = get_image_buffer({1, num_channels, height, width});
  convert(src, out);
  return out;
}
#ifdef USE_CUDA
// Stores only the target device; frame dimensions are discovered lazily from
// the first decoded frame (see header note on GPU decoder resolution).
CudaImageConverterBase::CudaImageConverterBase(const torch::Device& device)
    : device(device) {}
////////////////////////////////////////////////////////////////////////////////
// NV12 CUDA
////////////////////////////////////////////////////////////////////////////////
// CUDA NV12 converter; output is upsampled to YUV444P layout, so warn once.
NV12CudaConverter::NV12CudaConverter(const torch::Device& device)
    : CudaImageConverterBase(device) {
  TORCH_WARN_ONCE(
      "The output format NV12 is selected. "
      "This will be implicitly converted to YUV444P, "
      "in which all the color components Y, U, V have the same dimension.");
}
// Converts a CUDA-resident NV12 frame into a pre-allocated uint8 NCHW
// [1, 3, H, W] CUDA tensor: pitched device-to-device copies for Y and the
// interleaved UV plane, then nearest-neighbor upsampling of UV on the GPU.
// Assumes tmp_uv has already been allocated (done by the other overload).
void NV12CudaConverter::convert(const AVFrame* src, torch::Tensor& dst) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.dtype() == torch::kUInt8);
  auto fmt = (AVPixelFormat)(src->format);
  // The actual pixel layout of a hardware frame lives in sw_format.
  AVHWFramesContext* hwctx = (AVHWFramesContext*)src->hw_frames_ctx->data;
  AVPixelFormat sw_fmt = hwctx->sw_format;
  TORCH_INTERNAL_ASSERT(
      AV_PIX_FMT_CUDA == fmt,
      "Expected CUDA frame. Found: ",
      av_get_pix_fmt_name(fmt));
  TORCH_INTERNAL_ASSERT(
      AV_PIX_FMT_NV12 == sw_fmt,
      "Expected NV12 format. Found: ",
      av_get_pix_fmt_name(sw_fmt));
  // Write Y plane directly. Destination pitch is `width` (dense rows);
  // source pitch is the frame's linesize.
  auto status = cudaMemcpy2D(
      dst.data_ptr(),
      width,
      src->data[0],
      src->linesize[0],
      width,
      height,
      cudaMemcpyDeviceToDevice);
  TORCH_CHECK(cudaSuccess == status, "Failed to copy Y plane to Cuda tensor.");
  // Prepare intermediate UV planes: the interleaved UV plane has H/2 rows of
  // `width` bytes (W/2 U+V byte pairs per row), matching tmp_uv's layout
  // [1, H/2, W/2, 2].
  status = cudaMemcpy2D(
      tmp_uv.data_ptr(),
      width,
      src->data[1],
      src->linesize[1],
      width,
      height / 2,
      cudaMemcpyDeviceToDevice);
  TORCH_CHECK(cudaSuccess == status, "Failed to copy UV plane to Cuda tensor.");
  // Upsample width and height (nearest-neighbor) to full resolution.
  namespace F = torch::nn::functional;
  torch::Tensor uv = F::interpolate(
      tmp_uv.permute({0, 3, 1, 2}),
      F::InterpolateFuncOptions()
          .mode(torch::kNearest)
          .size(std::vector<int64_t>({height, width})));
  // Write to the UV plane
  // dst[:, 1:] = uv
  using namespace torch::indexing;
  dst.index_put_({Slice(), Slice(1)}, uv);
}
// Lazily captures the frame geometry on first use, then allocates the output
// CUDA tensor and delegates to the pre-allocated-destination overload.
torch::Tensor NV12CudaConverter::convert(const AVFrame* src) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  if (!init) {
    // GPU decoders may rescale; the real size is only known from the frame.
    height = src->height;
    width = src->width;
    tmp_uv =
        get_image_buffer({1, height / 2, width / 2, 2}, device, torch::kUInt8);
    init = true;
  }
  auto out = get_image_buffer({1, 3, height, width}, device);
  convert(src, out);
  return out;
}
////////////////////////////////////////////////////////////////////////////////
// P010 CUDA
////////////////////////////////////////////////////////////////////////////////
// CUDA P010 (10-bit NV12-style) converter; output is upsampled to YUV444P
// layout, so warn once.
P010CudaConverter::P010CudaConverter(const torch::Device& device)
    : CudaImageConverterBase{device} {
  TORCH_WARN_ONCE(
      "The output format P010 is selected. "
      "This will be implicitly converted to YUV444P, "
      "in which all the color components Y, U, V have the same dimension.");
}
// Converts a CUDA-resident P010 frame into a pre-allocated int16 NCHW
// [1, 3, H, W] CUDA tensor. Y and the interleaved UV plane are copied with
// pitched device-to-device copies; UV is upsampled via strided index_put_
// (interpolate does not support int16), then samples are shifted into int16.
// Assumes tmp_uv has already been allocated (done by the other overload).
void P010CudaConverter::convert(const AVFrame* src, torch::Tensor& dst) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.dtype() == torch::kInt16);
  auto fmt = (AVPixelFormat)(src->format);
  // The actual pixel layout of a hardware frame lives in sw_format.
  AVHWFramesContext* hwctx = (AVHWFramesContext*)src->hw_frames_ctx->data;
  AVPixelFormat sw_fmt = hwctx->sw_format;
  TORCH_INTERNAL_ASSERT(
      AV_PIX_FMT_CUDA == fmt,
      "Expected CUDA frame. Found: ",
      av_get_pix_fmt_name(fmt));
  TORCH_INTERNAL_ASSERT(
      AV_PIX_FMT_P010 == sw_fmt,
      "Expected P010 format. Found: ",
      av_get_pix_fmt_name(sw_fmt));
  // Write Y plane directly. Samples are 2 bytes wide, hence width * 2 bytes
  // per row; source pitch is the frame's linesize.
  auto status = cudaMemcpy2D(
      dst.data_ptr(),
      width * 2,
      src->data[0],
      src->linesize[0],
      width * 2,
      height,
      cudaMemcpyDeviceToDevice);
  TORCH_CHECK(cudaSuccess == status, "Failed to copy Y plane to CUDA tensor.");
  // Prepare intermediate UV planes: H/2 rows of W/2 interleaved 16-bit U+V
  // pairs (width * 2 bytes per row), matching tmp_uv's layout [1, H/2, W/2, 2].
  status = cudaMemcpy2D(
      tmp_uv.data_ptr(),
      width * 2,
      src->data[1],
      src->linesize[1],
      width * 2,
      height / 2,
      cudaMemcpyDeviceToDevice);
  TORCH_CHECK(cudaSuccess == status, "Failed to copy UV plane to CUDA tensor.");
  // Write to the UV plane
  torch::Tensor uv = tmp_uv.permute({0, 3, 1, 2});
  using namespace torch::indexing;
  // very simplistic upscale using indexing since interpolate doesn't support
  // shorts: write each half-res chroma sample into all four positions of its
  // 2x2 output block.
  dst.index_put_(
      {Slice(), Slice(1, 3), Slice(None, None, 2), Slice(None, None, 2)}, uv);
  dst.index_put_(
      {Slice(), Slice(1, 3), Slice(1, None, 2), Slice(None, None, 2)}, uv);
  dst.index_put_(
      {Slice(), Slice(1, 3), Slice(None, None, 2), Slice(1, None, 2)}, uv);
  dst.index_put_(
      {Slice(), Slice(1, 3), Slice(1, None, 2), Slice(1, None, 2)}, uv);
  // correct for int16
  dst += 32768;
}
// Lazily captures the frame geometry on first use, then allocates the int16
// output CUDA tensor and delegates to the pre-allocated-destination overload.
torch::Tensor P010CudaConverter::convert(const AVFrame* src) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  if (!init) {
    // GPU decoders may rescale; the real size is only known from the frame.
    height = src->height;
    width = src->width;
    tmp_uv =
        get_image_buffer({1, height / 2, width / 2, 2}, device, torch::kInt16);
    init = true;
  }
  auto out = get_image_buffer({1, 3, height, width}, device, torch::kInt16);
  convert(src, out);
  return out;
}
////////////////////////////////////////////////////////////////////////////////
// YUV444P CUDA
////////////////////////////////////////////////////////////////////////////////
// CUDA YUV444P converter; all planes are already full resolution, so no
// implicit-conversion warning is needed.
YUV444PCudaConverter::YUV444PCudaConverter(const torch::Device& device)
    : CudaImageConverterBase(device) {}
void
YUV444PCudaConverter
::
convert
(
const
AVFrame
*
src
,
torch
::
Tensor
&
dst
)
{
TORCH_INTERNAL_ASSERT_DEBUG_ONLY
(
src
);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY
(
src
->
height
==
height
);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY
(
src
->
width
==
width
);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY
(
dst
.
size
(
1
)
==
3
);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY
(
dst
.
size
(
2
)
==
height
);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY
(
dst
.
size
(
3
)
==
width
);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY
(
dst
.
dtype
()
==
torch
::
kUInt8
);
auto
fmt
=
(
AVPixelFormat
)(
src
->
format
);
AVHWFramesContext
*
hwctx
=
(
AVHWFramesContext
*
)
src
->
hw_frames_ctx
->
data
;
AVPixelFormat
sw_fmt
=
hwctx
->
sw_format
;
TORCH_INTERNAL_ASSERT
(
AV_PIX_FMT_CUDA
==
fmt
,
"Expected CUDA frame. Found: "
,
av_get_pix_fmt_name
(
fmt
));
TORCH_INTERNAL_ASSERT
(
AV_PIX_FMT_YUV444P
==
sw_fmt
,
"Expected YUV444P format. Found: "
,
av_get_pix_fmt_name
(
sw_fmt
));
// Write Y plane directly
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
auto
status
=
cudaMemcpy2D
(
dst
.
index
({
0
,
i
}).
data_ptr
(),
width
,
src
->
data
[
i
],
src
->
linesize
[
i
],
width
,
height
,
cudaMemcpyDeviceToDevice
);
TORCH_CHECK
(
cudaSuccess
==
status
,
"Failed to copy plane "
,
i
,
" to CUDA tensor."
);
}
}
// Lazily captures the frame geometry on first use, then allocates the output
// CUDA tensor and delegates to the pre-allocated-destination overload.
torch::Tensor YUV444PCudaConverter::convert(const AVFrame* src) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
  if (!init) {
    // GPU decoders may rescale; the real size is only known from the frame.
    height = src->height;
    width = src->width;
    init = true;
  }
  auto out = get_image_buffer({1, 3, height, width}, device);
  convert(src, out);
  return out;
}
#endif
}
// namespace torchaudio::io
torchaudio/csrc/ffmpeg/stream_reader/conversion.h
0 → 100644
View file @
ffeba11a
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace
torchaudio
::
io
{
////////////////////////////////////////////////////////////////////////////////
// Audio
////////////////////////////////////////////////////////////////////////////////
// Converts decoded audio AVFrames into Tensors of the given scalar type.
// `is_planar` selects between FFmpeg's planar (one data pointer per channel)
// and interleaved sample layouts.
template <c10::ScalarType dtype, bool is_planar>
class AudioConverter {
  // Number of channels expected in every incoming frame.
  const int num_channels;

 public:
  AudioConverter(int num_channels);

  // Converts AVFrame* into Tensor of [T, C]
  torch::Tensor convert(const AVFrame* src);

  // Converts AVFrame* into pre-allocated Tensor.
  // The shape must be [C, T] if is_planar otherwise [T, C]
  void convert(const AVFrame* src, torch::Tensor& dst);
};
////////////////////////////////////////////////////////////////////////////////
// Image
////////////////////////////////////////////////////////////////////////////////
struct
ImageConverterBase
{
const
int
height
;
const
int
width
;
const
int
num_channels
;
ImageConverterBase
(
int
h
,
int
w
,
int
c
);
};
////////////////////////////////////////////////////////////////////////////////
// Interlaced Images - NHWC
////////////////////////////////////////////////////////////////////////////////
// Converter for interleaved (packed) 8-bit pixel formats such as RGB24/RGBA.
struct InterlacedImageConverter : public ImageConverterBase {
  using ImageConverterBase::ImageConverterBase;
  // convert AVFrame* into Tensor of NCHW format
  torch::Tensor convert(const AVFrame* src);
  // convert AVFrame* into pre-allocated Tensor of NHWC format
  void convert(const AVFrame* src, torch::Tensor& dst);
};

// Converter for interleaved 16-bit pixel formats such as RGB48LE.
// Output tensors are int16; samples are shifted by 32768 into int16 range.
struct Interlaced16BitImageConverter : public ImageConverterBase {
  using ImageConverterBase::ImageConverterBase;
  // convert AVFrame* into Tensor of NCHW format
  torch::Tensor convert(const AVFrame* src);
  // convert AVFrame* into pre-allocated Tensor of NHWC format
  void convert(const AVFrame* src, torch::Tensor& dst);
};
////////////////////////////////////////////////////////////////////////////////
// Planar Images - NCHW
////////////////////////////////////////////////////////////////////////////////
// Converter for planar 8-bit pixel formats (one FFmpeg data plane per
// channel, e.g. YUV444P). Both overloads produce/fill NCHW tensors.
struct PlanarImageConverter : public ImageConverterBase {
  using ImageConverterBase::ImageConverterBase;
  void convert(const AVFrame* src, torch::Tensor& dst);
  torch::Tensor convert(const AVFrame* src);
};
////////////////////////////////////////////////////////////////////////////////
// Family of YUVs - NCHW
////////////////////////////////////////////////////////////////////////////////
// Converter for YUV420P frames. Half-resolution chroma is upsampled so the
// output is an NCHW uint8 tensor with equal-size Y, U, V channels.
class YUV420PConverter : public ImageConverterBase {
 public:
  YUV420PConverter(int height, int width);
  void convert(const AVFrame* src, torch::Tensor& dst);
  torch::Tensor convert(const AVFrame* src);
};

// Converter for 10-bit little-endian YUV420P frames; output is int16 NCHW
// with chroma upsampled to full resolution.
class YUV420P10LEConverter : public ImageConverterBase {
 public:
  YUV420P10LEConverter(int height, int width);
  void convert(const AVFrame* src, torch::Tensor& dst);
  torch::Tensor convert(const AVFrame* src);
};

// Converter for NV12 frames (full-res Y + half-res interleaved UV); output
// is uint8 NCHW with chroma deinterleaved and upsampled.
class NV12Converter : public ImageConverterBase {
 public:
  NV12Converter(int height, int width);
  void convert(const AVFrame* src, torch::Tensor& dst);
  torch::Tensor convert(const AVFrame* src);
};
#ifdef USE_CUDA
// Note:
// GPU decoders are tricky. They allow to change the resolution as part of
// decoder option, and the resulting resolution is (seemingly) not retrievable.
// Therefore, we adopt delayed frame size initialization.
// For that purpose, we do not inherit from ImageConverterBase.
// Note:
// GPU decoders are tricky. They allow to change the resolution as part of
// decoder option, and the resulting resolution is (seemingly) not retrievable.
// Therefore, we adopt delayed frame size initialization.
// For that purpose, we do not inherit from ImageConverterBase.
struct CudaImageConverterBase {
  // Target CUDA device for output tensors.
  const torch::Device device;
  // Set to true once height/width (and any scratch buffers) are initialized
  // from the first frame.
  bool init = false;
  int height = -1;
  int width = -1;
  explicit CudaImageConverterBase(const torch::Device& device);
};
// CUDA converter for NV12 hardware frames; produces uint8 NCHW tensors with
// chroma upsampled to full resolution (via nn::functional::interpolate).
class NV12CudaConverter : CudaImageConverterBase {
  // Scratch buffer holding the copied half-resolution interleaved UV plane.
  torch::Tensor tmp_uv{};

 public:
  explicit NV12CudaConverter(const torch::Device& device);
  void convert(const AVFrame* src, torch::Tensor& dst);
  torch::Tensor convert(const AVFrame* src);
};

// CUDA converter for P010 (10-bit) hardware frames; produces int16 NCHW
// tensors with chroma upsampled via strided writes (interpolate lacks int16).
class P010CudaConverter : CudaImageConverterBase {
  // Scratch buffer holding the copied half-resolution interleaved UV plane.
  torch::Tensor tmp_uv{};

 public:
  explicit P010CudaConverter(const torch::Device& device);
  void convert(const AVFrame* src, torch::Tensor& dst);
  torch::Tensor convert(const AVFrame* src);
};

// CUDA converter for YUV444P hardware frames; plane-by-plane copy, no
// upsampling needed.
class YUV444PCudaConverter : CudaImageConverterBase {
 public:
  explicit YUV444PCudaConverter(const torch::Device& device);
  void convert(const AVFrame* src, torch::Tensor& dst);
  torch::Tensor convert(const AVFrame* src);
};
#endif
}
// namespace torchaudio::io
torchaudio/csrc/ffmpeg/stream_reader/packet_buffer.cpp
0 → 100644
View file @
ffeba11a
#include <torchaudio/csrc/ffmpeg/stream_reader/packet_buffer.h>
namespace
torchaudio
::
io
{
// Clones the given packet and appends the clone to the buffer; the caller
// retains ownership of the original packet.
void PacketBuffer::push_packet(AVPacket* packet) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(packet, "Packet is null.");
  AVPacket* cloned = av_packet_clone(packet);
  TORCH_INTERNAL_ASSERT(cloned, "Failed to clone packet.");
  packets.emplace_back(cloned);
}
std
::
vector
<
AVPacketPtr
>
PacketBuffer
::
pop_packets
()
{
std
::
vector
<
AVPacketPtr
>
ret
{
std
::
make_move_iterator
(
packets
.
begin
()),
std
::
make_move_iterator
(
packets
.
end
())};
packets
.
clear
();
return
ret
;
}
// Returns true when at least one packet is buffered.
bool PacketBuffer::has_packets() {
  // Idiomatic emptiness check (was `size() > 0`).
  return !packets.empty();
}
}
// namespace torchaudio::io
torchaudio/csrc/ffmpeg/stream_reader/packet_buffer.h
0 → 100644
View file @
ffeba11a
#pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace
torchaudio
{
namespace
io
{
// FIFO buffer of demuxed AVPackets. Pushed packets are cloned, so the buffer
// owns its copies independently of the caller.
class PacketBuffer {
 public:
  // Clones `packet` and appends it; asserts on null input or clone failure.
  void push_packet(AVPacket* packet);
  // Moves all buffered packets out, leaving the buffer empty.
  std::vector<AVPacketPtr> pop_packets();
  // True when at least one packet is buffered.
  bool has_packets();

 private:
  std::deque<AVPacketPtr> packets;
};
}
// namespace io
}
// namespace torchaudio
torchaudio/csrc/ffmpeg/stream_reader/post_process.cpp
0 → 100644
View file @
ffeba11a
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/conversion.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/post_process.h>
namespace
torchaudio
::
io
{
namespace
detail
{
namespace
{
///////////////////////////////////////////////////////////////////////////////
// FilterGraphWrapper (FilterGraph + reset feature)
///////////////////////////////////////////////////////////////////////////////
using
FilterGraphFactory
=
std
::
function
<
FilterGraph
(
const
std
::
string
&
)
>
;
// Returns a factory that builds an audio FilterGraph for the given input
// stream parameters. Stream properties are captured by value so the factory
// remains usable after the codec context changes (used by flush/reset).
FilterGraphFactory get_audio_factory(
    AVRational time_base,
    AVCodecContext* codec_ctx) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(codec_ctx->codec_type == AVMEDIA_TYPE_AUDIO);
  return [fmt = codec_ctx->sample_fmt,
          time_base,
          rate = codec_ctx->sample_rate,
          channel_layout = codec_ctx->channel_layout](
             const std::string& filter_desc) -> FilterGraph {
    FilterGraph f;
    // Source -> sink must be configured before the filter description is
    // applied and the graph is finalized.
    f.add_audio_src(fmt, time_base, rate, channel_layout);
    f.add_audio_sink();
    f.add_process(filter_desc);
    f.create_filter();
    return f;
  };
}
// Returns a factory that builds a video FilterGraph for the given input
// stream parameters. Stream properties are captured by value so the factory
// remains usable after the codec context changes (used by flush/reset).
// NOTE(review): hw_frames_ctx is captured as a raw AVBufferRef*; this
// presumably relies on the codec context outliving the factory — confirm.
FilterGraphFactory get_video_factory(
    AVRational time_base,
    AVRational frame_rate,
    AVCodecContext* codec_ctx) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(codec_ctx->codec_type == AVMEDIA_TYPE_VIDEO);
  return [fmt = codec_ctx->pix_fmt,
          time_base,
          frame_rate,
          w = codec_ctx->width,
          h = codec_ctx->height,
          ratio = codec_ctx->sample_aspect_ratio,
          hw_frames_ctx = codec_ctx->hw_frames_ctx](
             const std::string& filter_desc) -> FilterGraph {
    FilterGraph f;
    f.add_video_src(fmt, time_base, frame_rate, w, h, ratio);
    f.add_video_sink();
    f.add_process(filter_desc);
    if (hw_frames_ctx) {
      // Hardware frames: the graph takes its own reference to the frames
      // context (av_buffer_ref increments the refcount).
      f.create_filter(av_buffer_ref(hw_frames_ctx));
    } else {
      f.create_filter();
    }
    return f;
  };
}
// Owns a FilterGraph together with the factory that produced it, so the
// graph can be rebuilt from scratch (reset) when the stream is flushed.
struct FilterGraphWrapper {
  // The filter description the graph was built from (e.g. "aresample=8000").
  const std::string desc;

 private:
  // Rebuilds a fresh FilterGraph from a description; captures stream params.
  FilterGraphFactory factory;

 public:
  FilterGraph filter;

  // Constructor for audio input
  FilterGraphWrapper(
      AVRational input_time_base,
      AVCodecContext* codec_ctx,
      const std::string& desc)
      : desc(desc),
        factory(get_audio_factory(input_time_base, codec_ctx)),
        filter(factory(desc)) {}

  // Constructor for video input
  FilterGraphWrapper(
      AVRational input_time_base,
      AVRational frame_rate,
      AVCodecContext* codec_ctx,
      const std::string& desc)
      : desc(desc),
        factory(get_video_factory(input_time_base, frame_rate, codec_ctx)),
        filter(factory(desc)) {}

  // Discards the current graph and rebuilds it with the same description.
  void reset() {
    filter = factory(desc);
  }
};
///////////////////////////////////////////////////////////////////////////////
// ProcessImpl
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// ProcessImpl
///////////////////////////////////////////////////////////////////////////////
// Glues together filtering, conversion and buffering for one stream:
// decoded frames go through the FilterGraph, each filtered frame is converted
// to a Tensor by `Converter` and accumulated in `Buffer`.
template <typename Converter, typename Buffer>
struct ProcessImpl : public IPostDecodeProcess {
 private:
  // Reusable frame that receives filter-graph output.
  AVFramePtr frame{alloc_avframe()};
  FilterGraphWrapper filter_wrapper;

 public:
  Converter converter;
  Buffer buffer;

  ProcessImpl(
      FilterGraphWrapper&& filter_wrapper,
      Converter&& converter,
      Buffer&& buffer)
      : filter_wrapper(std::move(filter_wrapper)),
        converter(std::move(converter)),
        buffer(std::move(buffer)) {}

  bool is_buffer_ready() const override {
    return buffer.is_ready();
  }

  const std::string& get_filter_desc() const override {
    return filter_wrapper.desc;
  };

  FilterGraphOutputInfo get_filter_output_info() const override {
    return filter_wrapper.filter.get_output_info();
  };

  // Rebuilds the filter graph and discards any buffered chunks.
  void flush() override {
    filter_wrapper.reset();
    buffer.flush();
  }

  // Feeds one decoded frame into the filter graph and drains all frames the
  // graph produces. Returns 0 on success (including "need more input"), or a
  // negative FFmpeg error code from add_frame/get_frame.
  int process_frame(AVFrame* in_frame) override {
    int ret = filter_wrapper.filter.add_frame(in_frame);
    while (ret >= 0) {
      ret = filter_wrapper.filter.get_frame(frame);
      // AVERROR(EAGAIN) means that new input data is required to return new
      // output.
      if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
        return 0;
      }
      if (ret >= 0) {
        // Convert to Tensor and stash along with the presentation timestamp.
        buffer.push_frame(converter.convert(frame), frame->pts);
      }
      // Release per-iteration frame references regardless of outcome.
      av_frame_unref(frame);
    }
    return ret;
  }

  c10::optional<Chunk> pop_chunk() override {
    return buffer.pop_chunk();
  }
};
///////////////////////////////////////////////////////////////////////////////
// Audio
///////////////////////////////////////////////////////////////////////////////
std
::
unique_ptr
<
IPostDecodeProcess
>
get_unchunked_audio_process
(
FilterGraphWrapper
&&
filter
)
{
auto
i
=
filter
.
filter
.
get_output_info
();
TORCH_INTERNAL_ASSERT
(
i
.
type
==
AVMEDIA_TYPE_AUDIO
,
"Unsupported media type found: "
,
av_get_media_type_string
(
i
.
type
));
using
B
=
UnchunkedBuffer
;
switch
(
auto
fmt
=
(
AVSampleFormat
)
i
.
format
;
fmt
)
{
case
AV_SAMPLE_FMT_U8
:
{
using
C
=
AudioConverter
<
torch
::
kUInt8
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_S16
:
{
using
C
=
AudioConverter
<
torch
::
kInt16
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_S32
:
{
using
C
=
AudioConverter
<
torch
::
kInt32
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_S64
:
{
using
C
=
AudioConverter
<
torch
::
kInt64
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_FLT
:
{
using
C
=
AudioConverter
<
torch
::
kFloat32
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_DBL
:
{
using
C
=
AudioConverter
<
torch
::
kFloat64
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_U8P
:
{
using
C
=
AudioConverter
<
torch
::
kUInt8
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_S16P
:
{
using
C
=
AudioConverter
<
torch
::
kInt16
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_S32P
:
{
using
C
=
AudioConverter
<
torch
::
kInt32
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_S64P
:
{
using
C
=
AudioConverter
<
torch
::
kInt64
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_FLTP
:
{
using
C
=
AudioConverter
<
torch
::
kFloat32
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
case
AV_SAMPLE_FMT_DBLP
:
{
using
C
=
AudioConverter
<
torch
::
kFloat64
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
B
{
i
.
time_base
});
}
default:
TORCH_INTERNAL_ASSERT
(
false
,
"Unexpected audio type:"
,
av_get_sample_fmt_name
(
fmt
));
}
}
std
::
unique_ptr
<
IPostDecodeProcess
>
get_chunked_audio_process
(
FilterGraphWrapper
&&
filter
,
int
frames_per_chunk
,
int
num_chunks
)
{
auto
i
=
filter
.
filter
.
get_output_info
();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY
(
i
.
type
==
AVMEDIA_TYPE_AUDIO
,
"Unsupported media type found: "
,
av_get_media_type_string
(
i
.
type
));
using
B
=
ChunkedBuffer
;
B
buffer
{
i
.
time_base
,
frames_per_chunk
,
num_chunks
};
switch
(
auto
fmt
=
(
AVSampleFormat
)
i
.
format
;
fmt
)
{
case
AV_SAMPLE_FMT_U8
:
{
using
C
=
AudioConverter
<
torch
::
kUInt8
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_S16
:
{
using
C
=
AudioConverter
<
torch
::
kInt16
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_S32
:
{
using
C
=
AudioConverter
<
torch
::
kInt32
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_S64
:
{
using
C
=
AudioConverter
<
torch
::
kInt64
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_FLT
:
{
using
C
=
AudioConverter
<
torch
::
kFloat32
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_DBL
:
{
using
C
=
AudioConverter
<
torch
::
kFloat64
,
false
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_U8P
:
{
using
C
=
AudioConverter
<
torch
::
kUInt8
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_S16P
:
{
using
C
=
AudioConverter
<
torch
::
kInt16
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_S32P
:
{
using
C
=
AudioConverter
<
torch
::
kInt32
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_S64P
:
{
using
C
=
AudioConverter
<
torch
::
kInt64
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_FLTP
:
{
using
C
=
AudioConverter
<
torch
::
kFloat32
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
case
AV_SAMPLE_FMT_DBLP
:
{
using
C
=
AudioConverter
<
torch
::
kFloat64
,
true
>
;
return
std
::
make_unique
<
ProcessImpl
<
C
,
B
>>
(
std
::
move
(
filter
),
C
{
i
.
num_channels
},
std
::
move
(
buffer
));
}
default:
TORCH_INTERNAL_ASSERT
(
false
,
"Unexpected audio type:"
,
av_get_sample_fmt_name
(
fmt
));
}
}
///////////////////////////////////////////////////////////////////////////////
// Video
///////////////////////////////////////////////////////////////////////////////
// Build the post-decode pipeline for a CPU video stream without chunking:
// all converted frames accumulate in a single UnchunkedBuffer.
std::unique_ptr<IPostDecodeProcess> get_unchunked_video_process(
    FilterGraphWrapper&& filter) {
  auto i = filter.filter.get_output_info();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      i.type == AVMEDIA_TYPE_VIDEO,
      "Unsupported media type found: ",
      av_get_media_type_string(i.type));
  auto h = i.height;
  auto w = i.width;
  auto tb = i.time_base;
  using B = UnchunkedBuffer;
  // Instantiate ProcessImpl for the concrete converter type; the converter
  // argument is taken by value so decltype yields the plain type.
  auto make = [&](auto conv) -> std::unique_ptr<IPostDecodeProcess> {
    using C = decltype(conv);
    return std::make_unique<ProcessImpl<C, B>>(
        std::move(filter), std::move(conv), B{tb});
  };
  switch (auto fmt = (AVPixelFormat)i.format; fmt) {
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
      return make(InterlacedImageConverter{h, w, 3});
    case AV_PIX_FMT_ARGB:
    case AV_PIX_FMT_RGBA:
    case AV_PIX_FMT_ABGR:
    case AV_PIX_FMT_BGRA:
      return make(InterlacedImageConverter{h, w, 4});
    case AV_PIX_FMT_GRAY8:
      return make(InterlacedImageConverter{h, w, 1});
    case AV_PIX_FMT_RGB48LE:
      return make(Interlaced16BitImageConverter{h, w, 3});
    case AV_PIX_FMT_YUV444P:
      return make(PlanarImageConverter{h, w, 3});
    case AV_PIX_FMT_YUV420P:
      return make(YUV420PConverter{h, w});
    case AV_PIX_FMT_YUV420P10LE:
      return make(YUV420P10LEConverter{h, w});
    case AV_PIX_FMT_NV12:
      return make(NV12Converter{h, w});
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Unexpected video format found: ", av_get_pix_fmt_name(fmt));
    }
  }
}
// Build the post-decode pipeline for a CUDA video stream without chunking.
// Only compiled meaningfully when USE_CUDA is defined; otherwise asserts.
std::unique_ptr<IPostDecodeProcess> get_unchunked_cuda_video_process(
    FilterGraphWrapper&& filter,
    const torch::Device& device) {
#ifndef USE_CUDA
  TORCH_INTERNAL_ASSERT(
      false, "USE_CUDA is not defined, but CUDA decoding process was requested.");
#else
  auto i = filter.filter.get_output_info();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      i.type == AVMEDIA_TYPE_VIDEO,
      "Unsupported media type found: ",
      av_get_media_type_string(i.type));
  using B = UnchunkedBuffer;
  // Instantiate ProcessImpl for the concrete CUDA converter type.
  auto make = [&](auto conv) -> std::unique_ptr<IPostDecodeProcess> {
    using C = decltype(conv);
    return std::make_unique<ProcessImpl<C, B>>(
        std::move(filter), std::move(conv), B{i.time_base});
  };
  switch (auto fmt = (AVPixelFormat)i.format; fmt) {
    case AV_PIX_FMT_NV12:
      return make(NV12CudaConverter{device});
    case AV_PIX_FMT_P010:
      return make(P010CudaConverter{device});
    case AV_PIX_FMT_YUV444P:
      return make(YUV444PCudaConverter{device});
    case AV_PIX_FMT_P016: {
      // Recognized but intentionally unsupported format.
      TORCH_CHECK(
          false,
          "Unsupported video format found in CUDA HW: ",
          av_get_pix_fmt_name(fmt));
    }
    default: {
      TORCH_CHECK(
          false,
          "Unexpected video format found in CUDA HW: ",
          av_get_pix_fmt_name(fmt));
    }
  }
#endif
}
// Build the post-decode pipeline for a CPU video stream with chunking:
// converted frames are grouped into chunks of `frames_per_chunk` frames,
// retained in a ChunkedBuffer configured with `num_chunks`.
std::unique_ptr<IPostDecodeProcess> get_chunked_video_process(
    FilterGraphWrapper&& filter,
    int frames_per_chunk,
    int num_chunks) {
  auto i = filter.filter.get_output_info();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      i.type == AVMEDIA_TYPE_VIDEO,
      "Unsupported media type found: ",
      av_get_media_type_string(i.type));
  auto h = i.height;
  auto w = i.width;
  auto tb = i.time_base;
  using B = ChunkedBuffer;
  // Instantiate ProcessImpl for the concrete converter type.
  auto make = [&](auto conv) -> std::unique_ptr<IPostDecodeProcess> {
    using C = decltype(conv);
    return std::make_unique<ProcessImpl<C, B>>(
        std::move(filter), std::move(conv), B{tb, frames_per_chunk, num_chunks});
  };
  switch (auto fmt = (AVPixelFormat)i.format; fmt) {
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
      return make(InterlacedImageConverter{h, w, 3});
    case AV_PIX_FMT_ARGB:
    case AV_PIX_FMT_RGBA:
    case AV_PIX_FMT_ABGR:
    case AV_PIX_FMT_BGRA:
      return make(InterlacedImageConverter{h, w, 4});
    case AV_PIX_FMT_GRAY8:
      return make(InterlacedImageConverter{h, w, 1});
    case AV_PIX_FMT_RGB48LE:
      return make(Interlaced16BitImageConverter{h, w, 3});
    case AV_PIX_FMT_YUV444P:
      return make(PlanarImageConverter{h, w, 3});
    case AV_PIX_FMT_YUV420P:
      return make(YUV420PConverter{h, w});
    case AV_PIX_FMT_YUV420P10LE:
      return make(YUV420P10LEConverter{h, w});
    case AV_PIX_FMT_NV12:
      return make(NV12Converter{h, w});
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Unexpected video format found: ", av_get_pix_fmt_name(fmt));
    }
  }
}
// Build the post-decode pipeline for a CUDA video stream with chunking.
// Only compiled meaningfully when USE_CUDA is defined; otherwise asserts.
std::unique_ptr<IPostDecodeProcess> get_chunked_cuda_video_process(
    FilterGraphWrapper&& filter,
    int frames_per_chunk,
    int num_chunks,
    const torch::Device& device) {
#ifndef USE_CUDA
  TORCH_INTERNAL_ASSERT(
      false, "USE_CUDA is not defined, but CUDA decoding process was requested.");
#else
  auto i = filter.filter.get_output_info();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      i.type == AVMEDIA_TYPE_VIDEO,
      "Unsupported media type found: ",
      av_get_media_type_string(i.type));
  using B = ChunkedBuffer;
  // Instantiate ProcessImpl for the concrete CUDA converter type.
  auto make = [&](auto conv) -> std::unique_ptr<IPostDecodeProcess> {
    using C = decltype(conv);
    return std::make_unique<ProcessImpl<C, B>>(
        std::move(filter),
        std::move(conv),
        B{i.time_base, frames_per_chunk, num_chunks});
  };
  switch (auto fmt = (AVPixelFormat)i.format; fmt) {
    case AV_PIX_FMT_NV12:
      return make(NV12CudaConverter{device});
    case AV_PIX_FMT_P010:
      return make(P010CudaConverter{device});
    case AV_PIX_FMT_YUV444P:
      return make(YUV444PCudaConverter{device});
    case AV_PIX_FMT_P016: {
      // Recognized but intentionally unsupported format.
      TORCH_CHECK(
          false,
          "Unsupported video format found in CUDA HW: ",
          av_get_pix_fmt_name(fmt));
    }
    default: {
      TORCH_CHECK(
          false,
          "Unexpected video format found in CUDA HW: ",
          av_get_pix_fmt_name(fmt));
    }
  }
#endif
}
}
// namespace
}
// namespace detail
// Public entry point: validate chunking parameters, build the audio filter
// graph, and return the matching (un)chunked audio process.
std::unique_ptr<IPostDecodeProcess> get_audio_process(
    AVRational input_time_base,
    AVCodecContext* codec_ctx,
    const std::string& desc,
    int frames_per_chunk,
    int num_chunks) {
  // Both chunking parameters must be positive, or -1 (sentinel).
  auto check_arg = [](int v, const char* name) {
    TORCH_CHECK(
        v > 0 || v == -1, "`", name, "` must be positive or -1. Found: ", v);
  };
  check_arg(frames_per_chunk, "frames_per_chunk");
  check_arg(num_chunks, "num_chunks");
  detail::FilterGraphWrapper filter{input_time_base, codec_ctx, desc};
  // -1 frames per chunk means "no chunking": accumulate everything.
  if (frames_per_chunk == -1) {
    return detail::get_unchunked_audio_process(std::move(filter));
  }
  return detail::get_chunked_audio_process(
      std::move(filter), frames_per_chunk, num_chunks);
}
// Public entry point: validate parameters, build the video filter graph, and
// dispatch to the CPU/CUDA x chunked/unchunked variant based on `device`
// and `frames_per_chunk`.
std::unique_ptr<IPostDecodeProcess> get_video_process(
    AVRational input_time_base,
    AVRational frame_rate,
    AVCodecContext* codec_ctx,
    const std::string& desc,
    int frames_per_chunk,
    int num_chunks,
    const torch::Device& device) {
  // Both chunking parameters must be positive, or -1 (sentinel).
  auto check_arg = [](int v, const char* name) {
    TORCH_CHECK(
        v > 0 || v == -1, "`", name, "` must be positive or -1. Found: ", v);
  };
  check_arg(frames_per_chunk, "frames_per_chunk");
  check_arg(num_chunks, "num_chunks");
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      device.is_cuda() || device.is_cpu(), "Unexpected device type: ", device);
  detail::FilterGraphWrapper filter{
      input_time_base, frame_rate, codec_ctx, desc};
  const bool use_cuda = device.is_cuda();
  if (frames_per_chunk == -1) {
    return use_cuda
        ? detail::get_unchunked_cuda_video_process(std::move(filter), device)
        : detail::get_unchunked_video_process(std::move(filter));
  }
  return use_cuda
      ? detail::get_chunked_cuda_video_process(
            std::move(filter), frames_per_chunk, num_chunks, device)
      : detail::get_chunked_video_process(
            std::move(filter), frames_per_chunk, num_chunks);
}
}
// namespace torchaudio::io
torchaudio/csrc/ffmpeg/stream_reader/post_process.h
0 → 100644
View file @
ffeba11a
#pragma once
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
namespace torchaudio::io {

// Interface of the per-output-stream post-decoding pipeline.
// Decoded AVFrames are pushed in via process_frame, and buffered results are
// retrieved as Chunk objects via pop_chunk.
struct IPostDecodeProcess {
  virtual ~IPostDecodeProcess() = default;

  // Push one decoded frame into the pipeline.
  // Returns an int status code (presumably 0 on success, negative on error,
  // matching the FFmpeg convention used elsewhere in this code -- confirm).
  virtual int process_frame(AVFrame* frame) = 0;
  // Pop the next buffered chunk, or an empty optional if none is available.
  virtual c10::optional<Chunk> pop_chunk() = 0;
  // Whether the internal buffer has data ready to be popped.
  virtual bool is_buffer_ready() const = 0;
  // The filter-graph description string this process was built with.
  virtual const std::string& get_filter_desc() const = 0;
  // Output media format info of the underlying filter graph.
  virtual FilterGraphOutputInfo get_filter_output_info() const = 0;
  // Flush internal state (e.g. at end of stream or on seek).
  virtual void flush() = 0;
};
// Build the post-decode pipeline for an audio stream: filter graph described
// by `desc`, sample-format-specific converter, and output buffer.
// `frames_per_chunk == -1` selects an unchunked buffer; otherwise output is
// grouped into chunks of `frames_per_chunk` frames (`num_chunks` configures
// the chunk buffer -- exact retention semantics defined by ChunkedBuffer).
std::unique_ptr<IPostDecodeProcess> get_audio_process(
    AVRational input_time_base,
    AVCodecContext* codec_ctx,
    const std::string& desc,
    int frames_per_chunk,
    int num_chunks);
// Build the post-decode pipeline for a video stream: filter graph described
// by `desc`, pixel-format-specific converter, and output buffer.
// `device` selects between CPU and CUDA converter implementations;
// `frames_per_chunk == -1` selects an unchunked buffer.
std::unique_ptr<IPostDecodeProcess> get_video_process(
    AVRational input_time_base,
    AVRational frame_rate,
    AVCodecContext* codec_ctx,
    const std::string& desc,
    int frames_per_chunk,
    int num_chunks,
    const torch::Device& device);
}
// namespace torchaudio::io
torchaudio/csrc/ffmpeg/stream_reader/stream_processor.cpp
View file @
ffeba11a
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h>
#include <stdexcept>
#include <string_view>
namespace
torchaudio
{
namespace
ffmpeg
{
namespace
torchaudio
::
io
{
using
KeyType
=
StreamProcessor
::
KeyType
;
namespace
{
// Find the decoder (either the one named explicitly, or the default decoder
// registered for `codec_id`) and allocate a codec context for it.
AVCodecContextPtr alloc_codec_context(
    enum AVCodecID codec_id,
    const c10::optional<std::string>& decoder_name) {
  const AVCodec* codec = nullptr;
  if (decoder_name) {
    codec = avcodec_find_decoder_by_name(decoder_name.value().c_str());
    TORCH_CHECK(codec, "Unsupported codec: ", decoder_name.value());
  } else {
    codec = avcodec_find_decoder(codec_id);
    TORCH_CHECK(codec, "Unsupported codec: ", avcodec_get_name(codec_id));
  }
  AVCodecContext* codec_ctx = avcodec_alloc_context3(codec);
  TORCH_CHECK(codec_ctx, "Failed to allocate CodecContext.");
  return AVCodecContextPtr(codec_ctx);
}
// Scan the codec's hardware configurations for one that supports CUDA via a
// HW device context. Fails loudly if the codec has no such configuration.
const AVCodecHWConfig* get_cuda_config(const AVCodec* codec) {
  int index = 0;
  while (const AVCodecHWConfig* config = avcodec_get_hw_config(codec, index++)) {
    if (config->device_type == AV_HWDEVICE_TYPE_CUDA &&
        (config->methods & AV_CODEC_HW_CONFIG_METHOD_HW_DEVICE_CTX)) {
      return config;
    }
  }
  TORCH_CHECK(
      false,
      "CUDA device was requested, but the codec \"",
      codec->name,
      "\" is not supported.");
}
// get_format callback installed on the codec context. Picks the pixel format
// matching the HW config stashed in `opaque` by configure_codec_context.
enum AVPixelFormat get_hw_format(
    AVCodecContext* codec_ctx,
    const enum AVPixelFormat* pix_fmts) {
  const AVCodecHWConfig* cfg =
      static_cast<AVCodecHWConfig*>(codec_ctx->opaque);
  for (const enum AVPixelFormat* p = pix_fmts; *p != -1; ++p) {
    if (*p != cfg->pix_fmt) {
      continue;
    }
    // Note
    // The HW decode example uses generic approach
    // https://ffmpeg.org/doxygen/4.1/hw__decode_8c_source.html#l00063
    // But this approach finalizes the codec configuration when the first
    // frame comes in.
    // We need to inspect the codec configuration right after the codec is
    // opened, so we add a short cut for known patterns:
    //   yuv420p (h264) -> nv12
    //   yuv420p10le (hevc/h265) -> p010le
    if (codec_ctx->pix_fmt == AV_PIX_FMT_YUV420P) {
      codec_ctx->pix_fmt = AV_PIX_FMT_CUDA;
      codec_ctx->sw_pix_fmt = AV_PIX_FMT_NV12;
    } else if (codec_ctx->pix_fmt == AV_PIX_FMT_YUV420P10LE) {
      codec_ctx->pix_fmt = AV_PIX_FMT_CUDA;
      codec_ctx->sw_pix_fmt = AV_PIX_FMT_P010LE;
    }
    return *p;
  }
  TORCH_WARN("Failed to get HW surface format.");
  return AV_PIX_FMT_NONE;
}
// Allocate and initialize a CUDA HW frames context matching the codec's
// current format and dimensions. Returns an owned AVBufferRef; on init
// failure the buffer is released before raising.
AVBufferRef* get_hw_frames_ctx(AVCodecContext* codec_ctx) {
  AVBufferRef* p = av_hwframe_ctx_alloc(codec_ctx->hw_device_ctx);
  TORCH_CHECK(
      p,
      "Failed to allocate CUDA frame context from device context at ",
      codec_ctx->hw_device_ctx);
  auto* frames_ctx = (AVHWFramesContext*)(p->data);
  frames_ctx->format = codec_ctx->pix_fmt;
  frames_ctx->sw_format = codec_ctx->sw_pix_fmt;
  frames_ctx->width = codec_ctx->width;
  frames_ctx->height = codec_ctx->height;
  frames_ctx->initial_pool_size = 5;
  int ret = av_hwframe_ctx_init(p);
  if (ret < 0) {
    // Release the buffer before reporting -- TORCH_CHECK does not return.
    av_buffer_unref(&p);
    TORCH_CHECK(
        false, "Failed to initialize CUDA frame context: ", av_err2string(ret));
  }
  return p;
}
// Copy stream parameters into the codec context, and when a CUDA device is
// requested, wire up the HW decoding hooks (get_format callback + device ctx).
void configure_codec_context(
    AVCodecContext* codec_ctx,
    const AVCodecParameters* params,
    const torch::Device& device) {
  int ret = avcodec_parameters_to_context(codec_ctx, params);
  TORCH_CHECK(
      ret >= 0, "Failed to set CodecContext parameter: ", av_err2string(ret));
  if (device.type() != c10::DeviceType::CUDA) {
    return;
  }
#ifndef USE_CUDA
  TORCH_CHECK(false, "torchaudio is not compiled with CUDA support.");
#else
  const AVCodecHWConfig* cfg = get_cuda_config(codec_ctx->codec);
  // https://www.ffmpeg.org/doxygen/trunk/hw__decode_8c_source.html#l00221
  // 1. Stash the HW config in the opaque pointer, then
  // 2. install the get_format callback which reads it back to select the HW
  //    pixel format.
  codec_ctx->opaque = static_cast<void*>(const_cast<AVCodecHWConfig*>(cfg));
  codec_ctx->get_format = get_hw_format;
  codec_ctx->hw_device_ctx = av_buffer_ref(get_cuda_context(device.index()));
  TORCH_INTERNAL_ASSERT(
      codec_ctx->hw_device_ctx, "Failed to reference HW device context.");
#endif
}
// Open the codec with the user-supplied options, defaulting to a single
// decoding thread and a default channel layout when none is set.
void open_codec(
    AVCodecContext* codec_ctx,
    const c10::optional<OptionDict>& decoder_option) {
  AVDictionary* options = get_option_dict(decoder_option);
  // Fill in a channel layout if the stream did not specify one.
  if (!codec_ctx->channel_layout) {
    codec_ctx->channel_layout =
        av_get_default_channel_layout(codec_ctx->channels);
  }
  // Default to single thread execution unless the caller overrides it.
  if (!av_dict_get(options, "threads", nullptr, 0)) {
    av_dict_set(&options, "threads", "1", 0);
  }
  int ret = avcodec_open2(codec_ctx, codec_ctx->codec, &options);
  clean_up_dict(options);
  TORCH_CHECK(
      ret >= 0, "Failed to initialize CodecContext: ", av_err2string(ret));
}
// True when `str` ends with `suffix`. (std::string_view gains ends_with only
// in C++20, hence this helper.)
bool ends_with(std::string_view str, std::string_view suffix) {
  if (str.size() < suffix.size()) {
    return false;
  }
  return str.substr(str.size() - suffix.size()) == suffix;
}
StreamProcessor
::
StreamProcessor
(
AVCodecParameters
*
codec
par
,
AVCodecContextPtr
get_codec_ctx
(
const
AVCodecParameters
*
par
ams
,
const
c10
::
optional
<
std
::
string
>&
decoder_name
,
const
c10
::
optional
<
OptionDict
>&
decoder_option
,
const
torch
::
Device
&
device
)
:
decoder
(
codecpar
,
decoder_name
,
decoder_option
,
device
)
{}
const
torch
::
Device
&
device
)
{
AVCodecContextPtr
codec_ctx
=
alloc_codec_context
(
params
->
codec_id
,
decoder_name
);
configure_codec_context
(
codec_ctx
,
params
,
device
);
open_codec
(
codec_ctx
,
decoder_option
);
if
(
codec_ctx
->
hw_device_ctx
)
{
codec_ctx
->
hw_frames_ctx
=
get_hw_frames_ctx
(
codec_ctx
);
}
if
(
ends_with
(
codec_ctx
->
codec
->
name
,
"_cuvid"
))
{
C10_LOG_API_USAGE_ONCE
(
"torchaudio.io.StreamReaderCUDA"
);
}
return
codec_ctx
;
}
}
// namespace
using
KeyType
=
StreamProcessor
::
KeyType
;
// Construct a processor for a stream with the given time base.
// The time base is kept as a member because it is not stored in the codec
// context (see the header's member comment) and is needed later, e.g. by
// set_discard_timestamp and when building post-processes.
StreamProcessor::StreamProcessor(const AVRational& time_base)
    : stream_time_base(time_base) {}
////////////////////////////////////////////////////////////////////////////////
// Configurations
////////////////////////////////////////////////////////////////////////////////
// Define an output stream on this processor: build the audio/video
// post-decode pipeline and register it under a fresh key, which is returned.
// NOTE(review): this span contained interleaved old/new diff residue (a stray
// `break;`, a dead `sinks.emplace` tail); only the new-side logic is kept.
// The signature below follows the new-side header -- confirm parameter list.
KeyType StreamProcessor::add_stream(
    int frames_per_chunk,
    int num_chunks,
    AVRational frame_rate,
    const std::string& filter_description,
    const torch::Device& device) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_decoder_set(), "Decoder hasn't been set.");
  // If device is provided, then check that codec_ctx has hw_device_ctx set.
  // In case, defining an output stream with HW accel on an input stream that
  // has decoder set without HW accel, it will cause seg fault.
  // i.e.
  // The following should be rejected here.
  // reader = StreamReader(...)
  // reader.add_video_stream(..., decoder="h264_cuvid")
  // reader.add_video_stream(..., decoder="h264_cuvid", hw_accel="cuda")
  // TODO:
  // One idea to work around this is to always define HW device context, and
  // if HW acceleration is not required, insert `hwdownload` filter.
  // This way it will be possible to handle both cases at the same time.
  switch (device.type()) {
    case torch::kCPU:
      TORCH_CHECK(
          !codec_ctx->hw_device_ctx,
          "Decoding without Hardware acceleration is requested, however, "
          "the decoder has been already defined with a HW acceleration. "
          "Decoding a stream with and without HW acceleration simultaneously "
          "is not supported.");
      break;
    case torch::kCUDA:
      TORCH_CHECK(
          codec_ctx->hw_device_ctx,
          "CUDA Hardware acceleration is requested, however, the decoder has "
          "been already defined without a HW acceleration. "
          "Decoding a stream with and without HW acceleration simultaneously "
          "is not supported.");
      break;
    default:;
  }
  switch (codec_ctx->codec_type) {
    case AVMEDIA_TYPE_AUDIO:
      post_processes.emplace(
          std::piecewise_construct,
          std::forward_as_tuple(current_key),
          std::forward_as_tuple(get_audio_process(
              stream_time_base,
              codec_ctx,
              filter_description,
              frames_per_chunk,
              num_chunks)));
      return current_key++;
    case AVMEDIA_TYPE_VIDEO:
      post_processes.emplace(
          std::piecewise_construct,
          std::forward_as_tuple(current_key),
          std::forward_as_tuple(get_video_process(
              stream_time_base,
              frame_rate,
              codec_ctx,
              filter_description,
              frames_per_chunk,
              num_chunks,
              device)));
      return current_key++;
    default:
      TORCH_CHECK(false, "Only Audio and Video are supported");
  }
}
// Remove the output stream registered under `key`.
// NOTE(review): the old-side `sinks.erase(key)` residue is dropped; the new
// code tracks output streams only in `post_processes`.
void StreamProcessor::remove_stream(KeyType key) {
  post_processes.erase(key);
}
// Set the timestamp (expressed in AV_TIME_BASE units, per the header's
// contract) before which decoded frames are discarded; used for precise seek.
void StreamProcessor::set_discard_timestamp(int64_t timestamp) {
  TORCH_CHECK(timestamp >= 0, "timestamp must be non-negative.");
  // Rescale from AV_TIME_BASE units into this stream's time base, the unit
  // in which frame PTS values are compared during decoding.
  discard_before_pts =
      av_rescale_q(timestamp, av_get_time_base_q(), stream_time_base);
}
// Create and store the decoder (codec context) for this stream.
// May be called at most once; a second call trips the debug assertion.
void StreamProcessor::set_decoder(
    const AVCodecParameters* codecpar,
    const c10::optional<std::string>& decoder_name,
    const c10::optional<OptionDict>& decoder_option,
    const torch::Device& device) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!codec_ctx, "Decoder has already been set.");
  codec_ctx = get_codec_ctx(codecpar, decoder_name, decoder_option, device);
}
////////////////////////////////////////////////////////////////////////////////
// Query methods
////////////////////////////////////////////////////////////////////////////////
// Return the filter-graph description string of the output stream `key`.
// NOTE(review): the old-side `sinks.at(key).get_filter_description()` return
// (unreachable diff residue) is dropped; the new code queries post_processes.
std::string StreamProcessor::get_filter_description(KeyType key) const {
  // `at` throws if `key` is not a registered output stream.
  return post_processes.at(key)->get_filter_desc();
}
// Return the output media format info of the filter graph attached to `key`.
FilterGraphOutputInfo StreamProcessor::get_filter_output_info(
    KeyType key) const {
  // `at` throws if `key` is not a registered output stream.
  return post_processes.at(key)->get_filter_output_info();
}
bool
StreamProcessor
::
is_buffer_ready
()
const
{
for
(
const
auto
&
it
:
sink
s
)
{
if
(
!
it
.
second
.
is_buffer_ready
())
{
for
(
const
auto
&
it
:
post_processe
s
)
{
if
(
!
it
.
second
->
is_buffer_ready
())
{
return
false
;
}
}
return
true
;
}
// Whether set_decoder has been called, i.e. codec_ctx is initialized.
bool StreamProcessor::is_decoder_set() const {
  return codec_ctx;
}
////////////////////////////////////////////////////////////////////////////////
// The streaming process
////////////////////////////////////////////////////////////////////////////////
// 0: some kind of success
// <0: Some error happened
// Feed one packet to the decoder, then drain all frames it produces,
// forwarding each (subject to the discard threshold) to the post-processes.
// NOTE(review): old-side residue (decoder.process_packet, pFrame1, duplicated
// returns) is dropped; only the new-side avcodec_* path remains.
int StreamProcessor::process_packet(AVPacket* packet) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      is_decoder_set(),
      "Decoder must have been set prior to calling this function.");
  int ret = avcodec_send_packet(codec_ctx, packet);
  while (ret >= 0) {
    ret = avcodec_receive_frame(codec_ctx, frame);
    // AVERROR(EAGAIN) means that new input data is required to return new
    // output.
    if (ret == AVERROR(EAGAIN))
      return 0;
    if (ret == AVERROR_EOF)
      return send_frame(nullptr);
    if (ret < 0)
      return ret;
    // If pts is undefined then overwrite with best effort estimate.
    // In this case, best_effort_timestamp is basically the number of frames
    // emit from decoder.
    //
    // We need valid pts because filter_graph does not fall back to
    // best_effort_timestamp.
    if (frame->pts == AV_NOPTS_VALUE) {
      if (frame->best_effort_timestamp == AV_NOPTS_VALUE) {
        // This happens in drain mode.
        // When the decoder enters drain mode, it starts flushing the
        // internally buffered frames, of which PTS cannot be estimated.
        //
        // This is because they might be intra-frames not in chronological
        // order. In this case, we use received frames as-is in the order they
        // are received.
        frame->pts = codec_ctx->frame_number + 1;
      } else {
        frame->pts = frame->best_effort_timestamp;
      }
    }
    // When the value of discard_before_pts is 0, we consider that the seek is
    // not performed and all the frames are passed to downstream
    // unconditionally.
    //
    // Two reasons for this behavior;
    // 1. When seek mode is not precise, we do not discard any frame.
    //    In this case discard_before_pts is set to zero.
    // 2. When users seek to zero, what they expect is to get to the beginning
    //    of the data.
    //
    // Note: discard_before_pts < 0 is UB.
    if (discard_before_pts <= 0 || frame->pts >= discard_before_pts) {
      send_frame(frame);
    }
    // else we can just unref the frame and continue
    av_frame_unref(frame);
  }
  return ret;
}
// Flush the decoder's internal buffers and every post-process pipeline.
// NOTE(review): old-side residue (decoder.flush_buffer, sinks loop) is
// dropped; only the new-side codec_ctx/post_processes path remains.
void StreamProcessor::flush() {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      is_decoder_set(),
      "Decoder must have been set prior to calling this function.");
  avcodec_flush_buffers(codec_ctx);
  for (auto& ite : post_processes) {
    ite.second->flush();
  }
}
// 0: some kind of success
// <0: Some error happened
int
StreamProcessor
::
send_frame
(
AVFrame
*
pF
rame
)
{
int
StreamProcessor
::
send_frame
(
AVFrame
*
f
rame
_
)
{
int
ret
=
0
;
for
(
auto
&
ite
:
sink
s
)
{
int
ret2
=
ite
.
second
.
process_frame
(
pF
rame
);
for
(
auto
&
ite
:
post_processe
s
)
{
int
ret2
=
ite
.
second
->
process_frame
(
f
rame
_
);
if
(
ret2
<
0
)
ret
=
ret2
;
}
...
...
@@ -110,9 +384,8 @@ int StreamProcessor::send_frame(AVFrame* pFrame) {
////////////////////////////////////////////////////////////////////////////////
// Retrieval
////////////////////////////////////////////////////////////////////////////////
// Pop the next available chunk from the output stream registered under `key`.
// NOTE(review): the old-side `sinks.at(key).buffer->pop_chunk()` overload
// returning optional<Tensor> (diff residue) is dropped; the new API returns
// optional<Chunk> from the post-process.
c10::optional<Chunk> StreamProcessor::pop_chunk(KeyType key) {
  // `at` throws if `key` is not a registered output stream.
  return post_processes.at(key)->pop_chunk();
}
}
// namespace ffmpeg
}
// namespace torchaudio
}
// namespace torchaudio::io
torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h
View file @
ffeba11a
#pragma once
#include <torch/t
orch
.h>
#include <torch/t
ypes
.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/
decoder
.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/
sink
.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/
post_process
.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/
typedefs
.h>
#include <map>
namespace
torchaudio
{
namespace
ffmpeg
{
namespace
io
{
class
StreamProcessor
{
public:
using
KeyType
=
int
;
private:
AVFramePtr
pFrame1
;
AV
FramePtr
pFrame2
;
// Stream time base which is not stored in AVCodecContextPtr
AV
Rational
stream_time_base
;
// Components for decoding source media
double
decoder_time_base
;
// for debug
Decoder
decoder
;
AVCodecContextPtr
codec_ctx
{
nullptr
};
AVFramePtr
frame
{
alloc_avframe
()}
;
KeyType
current_key
=
0
;
std
::
map
<
KeyType
,
Sink
>
sinks
;
std
::
map
<
KeyType
,
std
::
unique_ptr
<
IPostDecodeProcess
>>
post_processes
;
// Used for precise seek.
// 0: no discard
// Positive Values: decoded frames with PTS values less than this are
// discarded.
// Negative values: UB. Should not happen.
int64_t
discard_before_pts
=
0
;
public:
StreamProcessor
(
AVCodecParameters
*
codecpar
,
const
c10
::
optional
<
std
::
string
>&
decoder_name
,
const
c10
::
optional
<
OptionDict
>&
decoder_option
,
const
torch
::
Device
&
device
);
explicit
StreamProcessor
(
const
AVRational
&
time_base
);
~
StreamProcessor
()
=
default
;
// Non-copyable
StreamProcessor
(
const
StreamProcessor
&
)
=
delete
;
...
...
@@ -48,21 +51,33 @@ class StreamProcessor {
// 3. Configure a buffer.
// 4. Return filter ID.
KeyType
add_stream
(
AVRational
input_time_base
,
AVCodecParameters
*
codecpar
,
int
frames_per_chunk
,
int
num_chunks
,
const
c10
::
optional
<
std
::
string
>&
filter_description
,
AVRational
frame_rate
,
const
std
::
string
&
filter_description
,
const
torch
::
Device
&
device
);
// 1. Remove the stream
void
remove_stream
(
KeyType
key
);
// Set discard
// The input timestamp must be expressed in AV_TIME_BASE unit.
void
set_discard_timestamp
(
int64_t
timestamp
);
void
set_decoder
(
const
AVCodecParameters
*
codecpar
,
const
c10
::
optional
<
std
::
string
>&
decoder_name
,
const
c10
::
optional
<
OptionDict
>&
decoder_option
,
const
torch
::
Device
&
device
);
//////////////////////////////////////////////////////////////////////////////
// Query methods
//////////////////////////////////////////////////////////////////////////////
std
::
string
get_filter_description
(
KeyType
key
)
const
;
[[
nodiscard
]]
std
::
string
get_filter_description
(
KeyType
key
)
const
;
[[
nodiscard
]]
FilterGraphOutputInfo
get_filter_output_info
(
KeyType
key
)
const
;
bool
is_buffer_ready
()
const
;
[[
nodiscard
]]
bool
is_decoder_set
()
const
;
//////////////////////////////////////////////////////////////////////////////
// The streaming process
...
...
@@ -85,8 +100,8 @@ class StreamProcessor {
//////////////////////////////////////////////////////////////////////////////
public:
// Get the chunk from the given filter result
c10
::
optional
<
torch
::
Tensor
>
pop_chunk
(
KeyType
key
);
c10
::
optional
<
Chunk
>
pop_chunk
(
KeyType
key
);
};
}
// namespace
ffmpeg
}
// namespace
io
}
// namespace torchaudio
torchaudio/csrc/ffmpeg/stream_reader/stream_reader.cpp
View file @
ffeba11a
...
...
@@ -5,69 +5,119 @@
#include <stdexcept>
#include <thread>
namespace
torchaudio
{
namespace
ffmpeg
{
namespace
torchaudio
::
io
{
using
KeyType
=
StreamProcessor
::
KeyType
;
//////////////////////////////////////////////////////////////////////////////
//
Helper method
s
//
Initialization / resource allocation
s
//////////////////////////////////////////////////////////////////////////////
void
StreamReader
::
validate_open_stream
()
const
{
TORCH_CHECK
(
pFormatContext
,
"Stream is not open."
);
}
namespace {
// Allocate an AVFormatContext and open the input `src`.
// `io_ctx`, when non-null, supplies custom I/O callbacks (set on `pb`).
// `format` optionally forces a specific demuxer/device; `option` is passed
// through to avformat_open_input. Throws via TORCH_CHECK on failure.
// NOTE(review): old-side StreamReader::validate_* member definitions were
// interleaved here (diff residue); only the new-side function is kept.
AVFormatContext* get_input_format_context(
    const std::string& src,
    const c10::optional<std::string>& format,
    const c10::optional<OptionDict>& option,
    AVIOContext* io_ctx) {
  AVFormatContext* p = avformat_alloc_context();
  TORCH_CHECK(p, "Failed to allocate AVFormatContext.");
  if (io_ctx) {
    p->pb = io_ctx;
  }
  // Resolve the demuxer to use: the explicitly requested one, or nullptr to
  // let FFmpeg auto-detect.
  auto* pInputFormat = [&format]() -> AVFORMAT_CONST AVInputFormat* {
    if (format.has_value()) {
      std::string format_str = format.value();
      AVFORMAT_CONST AVInputFormat* pInput =
          av_find_input_format(format_str.c_str());
      TORCH_CHECK(pInput, "Unsupported device/format: \"", format_str, "\"");
      return pInput;
    }
    return nullptr;
  }();
  AVDictionary* opt = get_option_dict(option);
  int ret = avformat_open_input(&p, src.c_str(), pInputFormat, &opt);
  clean_up_dict(opt);
  TORCH_CHECK(
      ret >= 0,
      "Failed to open the input \"",
      src,
      "\" (",
      av_err2string(ret),
      ").");
  return p;
}
}
// namespace
//////////////////////////////////////////////////////////////////////////////
// Initialization / resource allocations
//////////////////////////////////////////////////////////////////////////////
// Take ownership of an opened format context, probe stream information, and
// pre-allocate one processor slot per source stream. Streams that are
// neither audio nor video are marked AVDISCARD_ALL so the demuxer skips them.
// NOTE(review): old-side pFormatContext residue is dropped; only the
// new-side format_ctx member is used.
StreamReader::StreamReader(AVFormatContext* p) : format_ctx(p) {
  C10_LOG_API_USAGE_ONCE("torchaudio.io.StreamReader");
  int ret = avformat_find_stream_info(format_ctx, nullptr);
  TORCH_CHECK(
      ret >= 0, "Failed to find stream information: ", av_err2string(ret));
  processors =
      std::vector<std::unique_ptr<StreamProcessor>>(format_ctx->nb_streams);
  // Cast avoids a signed/unsigned comparison (nb_streams is unsigned).
  for (int i = 0; i < static_cast<int>(format_ctx->nb_streams); ++i) {
    switch (format_ctx->streams[i]->codecpar->codec_type) {
      case AVMEDIA_TYPE_AUDIO:
      case AVMEDIA_TYPE_VIDEO:
        break;
      default:
        format_ctx->streams[i]->discard = AVDISCARD_ALL;
    }
  }
}
// Construct from a custom AVIOContext (user-provided I/O callbacks).
// The source name passed down is the fixed label "Custom Input Context",
// since there is no file path to report.
StreamReader::StreamReader(
    AVIOContext* io_ctx,
    const c10::optional<std::string>& format,
    const c10::optional<OptionDict>& option)
    : StreamReader(get_input_format_context(
          "Custom Input Context",
          format,
          option,
          io_ctx)) {}
// Construct from a source path/URL, delegating to the format-context
// constructor with no custom I/O context.
StreamReader::StreamReader(
    const std::string& src,
    const c10::optional<std::string>& format,
    const c10::optional<OptionDict>& option)
    : StreamReader(get_input_format_context(src, format, option, nullptr)) {}
//////////////////////////////////////////////////////////////////////////////
// Helper methods
//////////////////////////////////////////////////////////////////////////////
/// Assert that the format context handle is non-null, i.e. that a media
/// source has actually been opened before any query/configure call.
/// @throws c10::Error with "Stream is not open." when the handle is null.
void validate_open_stream(AVFormatContext* format_ctx) {
  TORCH_CHECK(format_ctx, "Stream is not open.");
}
/// Assert that `i` is a valid source-stream index for the open context.
/// Also verifies (via validate_open_stream) that the context itself is open.
/// @throws c10::Error when the context is null or the index is out of range.
void validate_src_stream_index(AVFormatContext* format_ctx, int i) {
  validate_open_stream(format_ctx);
  // nb_streams is unsigned; bring it into the signed domain once so the
  // range check reads naturally.
  const int num_streams = static_cast<int>(format_ctx->nb_streams);
  TORCH_CHECK(0 <= i && i < num_streams, "Source stream index out of range");
}
/// Assert that source stream `i` exists and carries media of the expected
/// type (audio or video). Index validity is checked first, so the codecpar
/// dereference below is safe.
/// @throws c10::Error when the index is invalid or the media type differs.
void validate_src_stream_type(
    AVFormatContext* format_ctx,
    int i,
    AVMediaType type) {
  validate_src_stream_index(format_ctx, i);
  const AVMediaType actual = format_ctx->streams[i]->codecpar->codec_type;
  TORCH_CHECK(
      actual == type,
      "Stream ",
      i,
      " is not ",
      av_get_media_type_string(type),
      " stream.");
}
////////////////////////////////////////////////////////////////////////////////
// Query methods
////////////////////////////////////////////////////////////////////////////////
int64_t
StreamReader
::
num_src_streams
()
const
{
return
pF
ormat
Context
->
nb_streams
;
return
f
ormat
_ctx
->
nb_streams
;
}
namespace
{
...
...
@@ -75,19 +125,20 @@ OptionDict parse_metadata(const AVDictionary* metadata) {
AVDictionaryEntry
*
tag
=
nullptr
;
OptionDict
ret
;
while
((
tag
=
av_dict_get
(
metadata
,
""
,
tag
,
AV_DICT_IGNORE_SUFFIX
)))
{
ret
.
insert
(
std
::
string
(
tag
->
key
),
std
::
string
(
tag
->
value
));
ret
.
emplace
(
std
::
string
(
tag
->
key
),
std
::
string
(
tag
->
value
));
}
return
ret
;
}
}
// namespace
/// Return a copy of the container-level metadata (title, artist, ...) of the
/// opened source as an OptionDict.
OptionDict StreamReader::get_metadata() const {
  // Residue carried a duplicate return via the old `pFormatContext` member;
  // only the `format_ctx` path is kept.
  return parse_metadata(format_ctx->metadata);
}
SrcStreamInfo
StreamReader
::
get_src_stream_info
(
int
i
)
const
{
validate_src_stream_index
(
i
);
AVStream
*
stream
=
pFormatContext
->
streams
[
i
];
validate_src_stream_index
(
format_ctx
,
i
);
AVStream
*
stream
=
format_ctx
->
streams
[
i
];
AVCodecParameters
*
codecpar
=
stream
->
codecpar
;
SrcStreamInfo
ret
;
...
...
@@ -127,34 +178,82 @@ SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
return
ret
;
}
namespace {
/// Allocate a fresh AVCodecParameters, aborting (via TORCH_CHECK) when the
/// allocation fails. Caller takes ownership of the returned pointer.
AVCodecParameters* get_codecpar() {
  AVCodecParameters* params = avcodec_parameters_alloc();
  TORCH_CHECK(params, "Failed to allocate resource.");
  return params;
}
} // namespace
/// Snapshot the codec parameters, time base, and index of source stream `i`.
/// The parameters are deep-copied into a freshly allocated, RAII-owned
/// AVCodecParameters so the result outlives the demuxer's own stream data.
///
/// @param i  Source stream index; validated before use.
/// @return StreamParams{owned codec parameters, stream time base, index}.
/// @throws c10::Error on an invalid index, allocation failure, or a failed
///         parameter copy.
StreamParams StreamReader::get_src_stream_params(int i) {
  validate_src_stream_index(format_ctx, i);
  AVStream* stream = format_ctx->streams[i];

  // get_codecpar() allocates; AVCodecParametersPtr owns and frees it.
  AVCodecParametersPtr codec_params(get_codecpar());
  int ret = avcodec_parameters_copy(codec_params, stream->codecpar);
  TORCH_CHECK(
      ret >= 0,
      "Failed to copy the stream's codec parameters. (",
      av_err2string(ret),
      ")");
  return {std::move(codec_params), stream->time_base, i};
}
/// Number of configured output streams (one per add_*_stream call that is
/// still active), i.e. the size of the stream_indices table.
int64_t StreamReader::num_out_streams() const {
  return static_cast<int64_t>(stream_indices.size());
}
/// Describe output stream `i`: which source stream feeds it, the filter
/// description applied, and the media-type-specific output properties
/// reported by the processor's filter graph.
///
/// @param i  Output stream index (position in stream_indices).
/// @return Populated OutputStreamInfo; audio fills sample_rate/num_channels,
///         video fills width/height/frame_rate, other types fill neither.
/// @throws c10::Error when the index is out of range.
OutputStreamInfo StreamReader::get_out_stream_info(int i) const {
  // Residue carried the superseded validate_output_stream_index(i) call and
  // an early duplicate `OutputStreamInfo ret;`; the inline bounds check is
  // the current form.
  TORCH_CHECK(
      i >= 0 && static_cast<size_t>(i) < stream_indices.size(),
      "Output stream index out of range");
  int i_src = stream_indices[i].first;
  KeyType key = stream_indices[i].second;
  FilterGraphOutputInfo info = processors[i_src]->get_filter_output_info(key);

  OutputStreamInfo ret;
  ret.source_index = i_src;
  ret.filter_description = processors[i_src]->get_filter_description(key);
  ret.media_type = info.type;
  ret.format = info.format;
  switch (info.type) {
    case AVMEDIA_TYPE_AUDIO:
      ret.sample_rate = info.sample_rate;
      ret.num_channels = info.num_channels;
      break;
    case AVMEDIA_TYPE_VIDEO:
      ret.width = info.width;
      ret.height = info.height;
      ret.frame_rate = info.frame_rate;
      break;
    default:;
  }
  return ret;
}
int64_t
StreamReader
::
find_best_audio_stream
()
const
{
return
av_find_best_stream
(
pF
ormat
Context
,
AVMEDIA_TYPE_AUDIO
,
-
1
,
-
1
,
nullptr
,
0
);
f
ormat
_ctx
,
AVMEDIA_TYPE_AUDIO
,
-
1
,
-
1
,
nullptr
,
0
);
}
int64_t
StreamReader
::
find_best_video_stream
()
const
{
return
av_find_best_stream
(
pF
ormat
Context
,
AVMEDIA_TYPE_VIDEO
,
-
1
,
-
1
,
nullptr
,
0
);
f
ormat
_ctx
,
AVMEDIA_TYPE_VIDEO
,
-
1
,
-
1
,
nullptr
,
0
);
}
/// Whether enough data has been buffered for the client to pop.
/// With no decoding output streams configured, readiness is defined by the
/// packet buffer having packets; otherwise every active processor must
/// report its own buffer ready.
bool StreamReader::is_buffer_ready() const {
  if (processors.empty()) {
    // If no decoding output streams exist, then determine overall readiness
    // from the readiness of packet buffer.
    return packet_buffer->has_packets();
  } else {
    // Otherwise, determine readiness solely from the readiness of the
    // decoding output streams.
    for (const auto& it : processors) {
      if (it && !it->is_buffer_ready()) {
        return false;
      }
    }
  }
  return true;
}
...
...
@@ -163,15 +262,42 @@ bool StreamReader::is_buffer_ready() const {
////////////////////////////////////////////////////////////////////////////////
// Configure methods
////////////////////////////////////////////////////////////////////////////////
void
StreamReader
::
seek
(
double
timestamp
)
{
TORCH_CHECK
(
timestamp
>=
0
,
"timestamp must be non-negative."
);
void
StreamReader
::
seek
(
double
timestamp_s
,
int64_t
mode
)
{
TORCH_CHECK
(
timestamp_s
>=
0
,
"timestamp must be non-negative."
);
TORCH_CHECK
(
format_ctx
->
nb_streams
>
0
,
"At least one stream must exist in this context"
);
int64_t
timestamp_av_tb
=
static_cast
<
int64_t
>
(
timestamp_s
*
AV_TIME_BASE
);
int64_t
ts
=
static_cast
<
int64_t
>
(
timestamp
*
AV_TIME_BASE
);
int
ret
=
avformat_seek_file
(
pFormatContext
,
-
1
,
INT64_MIN
,
ts
,
INT64_MAX
,
0
);
TORCH_CHECK
(
ret
>=
0
,
"Failed to seek. ("
+
av_err2string
(
ret
)
+
".)"
);
int
flag
=
AVSEEK_FLAG_BACKWARD
;
switch
(
mode
)
{
case
0
:
// reset seek_timestap as it is only used for precise seek
seek_timestamp
=
0
;
break
;
case
1
:
flag
|=
AVSEEK_FLAG_ANY
;
// reset seek_timestap as it is only used for precise seek
seek_timestamp
=
0
;
break
;
case
2
:
seek_timestamp
=
timestamp_av_tb
;
break
;
default:
TORCH_CHECK
(
false
,
"Invalid mode value: "
,
mode
);
}
int
ret
=
av_seek_frame
(
format_ctx
,
-
1
,
timestamp_av_tb
,
flag
);
if
(
ret
<
0
)
{
seek_timestamp
=
0
;
TORCH_CHECK
(
false
,
"Failed to seek. ("
+
av_err2string
(
ret
)
+
".)"
);
}
for
(
const
auto
&
it
:
processors
)
{
if
(
it
)
{
it
->
flush
();
it
->
set_discard_timestamp
(
seek_timestamp
);
}
}
}
...
...
@@ -188,7 +314,7 @@ void StreamReader::add_audio_stream(
AVMEDIA_TYPE_AUDIO
,
static_cast
<
int
>
(
frames_per_chunk
),
static_cast
<
int
>
(
num_chunks
),
filter_desc
,
filter_desc
.
value_or
(
"anull"
)
,
decoder
,
decoder_option
,
torch
::
Device
(
torch
::
DeviceType
::
CPU
));
...
...
@@ -209,9 +335,7 @@ void StreamReader::add_video_stream(
#ifdef USE_CUDA
torch
::
Device
d
{
hw_accel
.
value
()};
TORCH_CHECK
(
d
.
type
()
==
c10
::
DeviceType
::
CUDA
,
"Only CUDA is supported for hardware acceleration. Found: "
,
device
.
str
());
d
.
is_cuda
(),
"Only CUDA is supported for HW acceleration. Found: "
,
d
);
return
d
;
#else
TORCH_CHECK
(
...
...
@@ -225,47 +349,75 @@ void StreamReader::add_video_stream(
AVMEDIA_TYPE_VIDEO
,
static_cast
<
int
>
(
frames_per_chunk
),
static_cast
<
int
>
(
num_chunks
),
filter_desc
,
filter_desc
.
value_or
(
"null"
)
,
decoder
,
decoder_option
,
device
);
}
/// Register source stream `i` for raw packet passthrough: its demuxed
/// packets will be collected (undecoded) into the shared packet buffer and
/// retrievable via pop_packets().
///
/// @param i  Source stream index; validated before use.
void StreamReader::add_packet_stream(int i) {
  validate_src_stream_index(format_ctx, i);
  // Lazily create the shared packet buffer on the first packet stream.
  if (!packet_buffer) {
    packet_buffer = std::make_unique<PacketBuffer>();
  }
  packet_stream_indices.emplace(i);
}
/// Shared implementation behind add_audio_stream / add_video_stream.
/// Ensures a StreamProcessor exists for source stream `i`, configures its
/// decoder (once), and registers a new filtered output stream on it.
///
/// @param i                Source stream index.
/// @param media_type       AVMEDIA_TYPE_AUDIO or AVMEDIA_TYPE_VIDEO.
/// @param frames_per_chunk Frames per output chunk.
/// @param num_chunks       Number of chunks retained in the buffer.
/// @param filter_desc      FFmpeg filter description (already defaulted to
///                         "anull"/"null" by the callers).
/// @param decoder          Optional decoder-name override.
/// @param decoder_option   Optional decoder options.
/// @param device           Target device (CPU, or CUDA for HW decoding).
/// @throws c10::Error on type mismatch, undetected source format, or when
///         decoder options are supplied after the decoder is initialized.
void StreamReader::add_stream(
    int i,
    AVMediaType media_type,
    int frames_per_chunk,
    int num_chunks,
    const std::string& filter_desc,
    const c10::optional<std::string>& decoder,
    const c10::optional<OptionDict>& decoder_option,
    const torch::Device& device) {
  validate_src_stream_type(format_ctx, i, media_type);

  AVStream* stream = format_ctx->streams[i];
  // When media source is file-like object, it is possible that source codec
  // is not detected properly.
  TORCH_CHECK(
      stream->codecpar->format != -1,
      "Failed to detect the source stream format.");

  if (!processors[i]) {
    processors[i] = std::make_unique<StreamProcessor>(stream->time_base);
    // Propagate any pending precise-seek cutoff to the new processor.
    processors[i]->set_discard_timestamp(seek_timestamp);
  }
  if (!processors[i]->is_decoder_set()) {
    processors[i]->set_decoder(
        stream->codecpar, decoder, decoder_option, device);
  } else {
    // The residue dropped the terminating semicolon here; restored.
    TORCH_CHECK(
        !decoder && (!decoder_option || decoder_option.value().size() == 0),
        "Decoder options were provided, but the decoder has already been initialized.");
  }

  // The stream may have been marked discard in the constructor (or by a
  // previous remove); re-enable packet delivery for it.
  stream->discard = AVDISCARD_DEFAULT;

  auto frame_rate = [&]() -> AVRational {
    switch (media_type) {
      case AVMEDIA_TYPE_AUDIO:
        // Frame rate is meaningless for audio; use 0/1 as a placeholder.
        return AVRational{0, 1};
      case AVMEDIA_TYPE_VIDEO:
        return av_guess_frame_rate(format_ctx, stream, nullptr);
      default:
        TORCH_INTERNAL_ASSERT(
            false,
            "Unexpected media type is given: ",
            av_get_media_type_string(media_type));
    }
  }();

  int key = processors[i]->add_stream(
      stream->time_base,
      stream->codecpar,
      frames_per_chunk,
      num_chunks,
      frame_rate,
      filter_desc,
      device);
  stream_indices.push_back(std::make_pair<>(i, key));
}
void
StreamReader
::
remove_stream
(
int64_t
i
)
{
validate_output_stream_index
(
static_cast
<
int
>
(
i
));
TORCH_CHECK
(
i
>=
0
&&
static_cast
<
size_t
>
(
i
)
<
stream_indices
.
size
(),
"Output stream index out of range"
);
auto
it
=
stream_indices
.
begin
()
+
i
;
int
iP
=
it
->
first
;
processors
[
iP
]
->
remove_stream
(
it
->
second
);
...
...
@@ -293,7 +445,7 @@ void StreamReader::remove_stream(int64_t i) {
// 1: It's done, caller should stop calling
// <0: Some error happened
int
StreamReader
::
process_packet
()
{
int
ret
=
av_read_frame
(
pF
ormat
Context
,
p
P
acket
);
int
ret
=
av_read_frame
(
f
ormat
_ctx
,
packet
);
if
(
ret
==
AVERROR_EOF
)
{
ret
=
drain
();
return
(
ret
<
0
)
?
ret
:
1
;
...
...
@@ -301,12 +453,21 @@ int StreamReader::process_packet() {
if
(
ret
<
0
)
{
return
ret
;
}
AutoPacketUnref
packet
{
pPacket
};
auto
&
processor
=
processors
[
pPacket
->
stream_index
];
AutoPacketUnref
auto_unref
{
packet
};
int
stream_index
=
packet
->
stream_index
;
if
(
packet_stream_indices
.
count
(
stream_index
))
{
packet_buffer
->
push_packet
(
packet
);
}
auto
&
processor
=
processors
[
stream_index
];
if
(
!
processor
)
{
return
0
;
}
ret
=
processor
->
process_packet
(
packet
);
return
(
ret
<
0
)
?
ret
:
0
;
}
...
...
@@ -344,6 +505,39 @@ int StreamReader::process_packet_block(double timeout, double backoff) {
}
}
/// Drain the source: keep demuxing/decoding packets until process_packet()
/// reports completion (1) or an error (<0), i.e. until it returns non-zero.
void StreamReader::process_all_packets() {
  int64_t status = 0;
  while (true) {
    status = process_packet();
    if (status) {
      break;
    }
  }
}
/// Process one packet, optionally retrying with a timeout.
/// With a timeout, delegates to process_packet_block(timeout, backoff);
/// otherwise performs a single non-blocking process_packet() call.
///
/// @param timeout  Optional total wait budget (seconds).
/// @param backoff  Retry back-off passed to the blocking variant.
/// @return 0 to continue, 1 when the source is exhausted.
/// @throws c10::Error when the underlying call reports a negative code.
int StreamReader::process_packet(
    const c10::optional<double>& timeout,
    const double backoff) {
  int code;
  if (timeout.has_value()) {
    code = process_packet_block(timeout.value(), backoff);
  } else {
    code = process_packet();
  }
  TORCH_CHECK(
      code >= 0, "Failed to process a packet. (" + av_err2string(code) + "). ");
  return code;
}
/// Pump packets until the output buffers are ready (see is_buffer_ready()).
///
/// @param timeout  Optional per-packet wait budget, forwarded as-is.
/// @param backoff  Retry back-off, forwarded as-is.
/// @return 0 once buffers are ready, otherwise the first non-zero status
///         returned by process_packet (e.g. 1 at end of stream).
int StreamReader::fill_buffer(
    const c10::optional<double>& timeout,
    const double backoff) {
  for (;;) {
    if (is_buffer_ready()) {
      return 0;
    }
    int status = process_packet(timeout, backoff);
    if (status != 0) {
      return status;
    }
  }
}
// <0: Some error happened.
int
StreamReader
::
drain
()
{
int
ret
=
0
,
tmp
=
0
;
...
...
@@ -358,13 +552,58 @@ int StreamReader::drain() {
return
ret
;
}
std
::
vector
<
c10
::
optional
<
torch
::
Tensor
>>
StreamReader
::
pop_chunks
()
{
std
::
vector
<
c10
::
optional
<
torch
::
Tensor
>>
ret
;
std
::
vector
<
c10
::
optional
<
Chunk
>>
StreamReader
::
pop_chunks
()
{
std
::
vector
<
c10
::
optional
<
Chunk
>>
ret
;
ret
.
reserve
(
static_cast
<
size_t
>
(
num_out_streams
()));
for
(
auto
&
i
:
stream_indices
)
{
ret
.
push
_back
(
processors
[
i
.
first
]
->
pop_chunk
(
i
.
second
));
ret
.
emplace
_back
(
processors
[
i
.
first
]
->
pop_chunk
(
i
.
second
));
}
return
ret
;
}
}
// namespace ffmpeg
}
// namespace torchaudio
/// Take all raw packets accumulated for the registered packet streams,
/// emptying the packet buffer.
/// NOTE(review): dereferences packet_buffer unconditionally — assumes
/// add_packet_stream() was called first; confirm against callers.
std::vector<AVPacketPtr> StreamReader::pop_packets() {
  return packet_buffer->pop_packets();
}
//////////////////////////////////////////////////////////////////////////////
// StreamReaderCustomIO
//////////////////////////////////////////////////////////////////////////////
namespace
detail
{
namespace
{
/// Build a read-only AVIOContext over user callbacks.
/// Allocates the I/O buffer with av_malloc, then hands it to
/// avio_alloc_context (write_flag=0, write callback nullptr → read-only).
/// On success the returned context owns the buffer; on failure the buffer
/// is freed here before raising.
///
/// @param opaque       User pointer passed back to the callbacks.
/// @param buffer_size  Size of the internal I/O buffer in bytes.
/// @param read_packet  Read callback (required).
/// @param seek         Seek callback (may be null per FFmpeg convention —
///                     TODO confirm callers rely on that).
/// @throws c10::Error when either allocation fails.
AVIOContext* get_io_context(
    void* opaque,
    int buffer_size,
    int (*read_packet)(void* opaque, uint8_t* buf, int buf_size),
    int64_t (*seek)(void* opaque, int64_t offset, int whence)) {
  unsigned char* buffer =
      static_cast<unsigned char*>(av_malloc(buffer_size));
  TORCH_CHECK(buffer, "Failed to allocate buffer.");
  AVIOContext* io_ctx = avio_alloc_context(
      buffer, buffer_size, 0, opaque, read_packet, nullptr, seek);
  if (!io_ctx) {
    // avio_alloc_context did not take ownership; release the buffer before
    // raising so it does not leak.
    av_freep(&buffer);
    TORCH_CHECK(false, "Failed to allocate AVIOContext.");
  }
  return io_ctx;
}
}
// namespace
/// RAII holder for a custom AVIOContext; constructs the context via
/// get_io_context and stores it in the io_ctx member for the owning
/// StreamReader to use.
///
/// @param opaque       User pointer passed back to the callbacks.
/// @param buffer_size  Size of the internal I/O buffer in bytes.
/// @param read_packet  Read callback.
/// @param seek         Seek callback.
CustomInput::CustomInput(
    void* opaque,
    int buffer_size,
    int (*read_packet)(void* opaque, uint8_t* buf, int buf_size),
    int64_t (*seek)(void* opaque, int64_t offset, int whence))
    : io_ctx(get_io_context(opaque, buffer_size, read_packet, seek)) {}
}
// namespace detail
/// StreamReader variant that reads through caller-supplied I/O callbacks.
/// CustomInput is initialized first so io_ctx is valid when the
/// StreamReader base constructor opens the input through it (assumes
/// CustomInput precedes StreamReader in the base list — matches the
/// initializer order here; confirm against the class declaration).
///
/// @param opaque       User pointer passed back to the callbacks.
/// @param format       Optional demuxer name override.
/// @param buffer_size  Size of the internal I/O buffer in bytes.
/// @param read_packet  Read callback.
/// @param seek         Seek callback.
/// @param option       Optional demuxer options.
StreamReaderCustomIO::StreamReaderCustomIO(
    void* opaque,
    const c10::optional<std::string>& format,
    int buffer_size,
    int (*read_packet)(void* opaque, uint8_t* buf, int buf_size),
    int64_t (*seek)(void* opaque, int64_t offset, int whence),
    const c10::optional<OptionDict>& option)
    : CustomInput(opaque, buffer_size, read_packet, seek),
      StreamReader(io_ctx, format, option) {}
}
// namespace torchaudio::io
Prev
1
…
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment