Commit 98b3ac17 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Update the guard mechanism for FFmpeg-related features (#3028)

Summary:
Instead of raising an error when lazy import happens, this method allows to import features, and raises an error when the feature is being used.

This makes it easy to adopt the same error mechanism across different modules. It is how it's done for sox-related features.

Pull Request resolved: https://github.com/pytorch/audio/pull/3028

Reviewed By: xiaohui-zhang

Differential Revision: D42966976

Pulled By: mthrok

fbshipit-source-id: 423dfe0b8a3970cd07f20e841c794c7f2809f993
parent a0f8af4b
......@@ -4,7 +4,7 @@ import sys
from torchaudio._internal.module_utils import fail_with_message, is_module_available, no_op
from .utils import _check_cuda_version, _init_dll_path, _init_ffmpeg, _init_sox, _load_lib # noqa
from .utils import _check_cuda_version, _fail_since_no_ffmpeg, _init_dll_path, _init_ffmpeg, _init_sox, _load_lib
_LG = logging.getLogger(__name__)
......@@ -16,6 +16,7 @@ _LG = logging.getLogger(__name__)
__all__ = [
"fail_if_no_kaldi",
"fail_if_no_sox",
"fail_if_no_ffmpeg",
"_check_cuda_version",
"_IS_TORCHAUDIO_EXT_AVAILABLE",
"_IS_KALDI_AVAILABLE",
......@@ -85,3 +86,5 @@ fail_if_no_sox = (
"requires sox extension, but TorchAudio is not compiled with it. Please build TorchAudio with libsox support."
)
)
fail_if_no_ffmpeg = no_op if _FFMPEG_INITIALIZED else _fail_since_no_ffmpeg
"""Module to implement logics used for initializing extensions.
The implementations here should be stateless.
They should not depend on external state.
Anything that depends on external state should happen in __init__.py
"""
import os
from functools import wraps
from pathlib import Path
import torch
......@@ -113,3 +122,22 @@ def _check_cuda_version():
"Please install the TorchAudio version that matches your PyTorch version."
)
return version
def _fail_since_no_ffmpeg(func):
@wraps(func)
def wrapped(*_args, **_kwargs):
try:
# Note:
# We run _init_ffmpeg again just to show users the stacktrace.
# _init_ffmpeg would not succeed here.
_init_ffmpeg()
except Exception as err:
raise RuntimeError(
f"{func.__name__} requires FFmpeg extension which is not available. "
"Please refer to the stacktrace above for how to resolve this."
) from err
# This should not happen in normal execution, but just in case.
return func(*_args, **_kwargs)
return wrapped
import torchaudio
from ._stream_reader import StreamReader
from ._stream_writer import StreamWriter
_STREAM_READER = [
"StreamReader",
]
_STREAM_WRITER = [
__all__ = [
"StreamReader",
"StreamWriter",
]
_PLAYBACK = [
"play_audio",
]
_LAZILY_IMPORTED = _STREAM_READER + _STREAM_WRITER + _PLAYBACK
def __getattr__(name: str):
if name in _LAZILY_IMPORTED:
if not torchaudio._extension._FFMPEG_INITIALIZED:
torchaudio._extension._init_ffmpeg()
if name in _STREAM_READER:
from . import _stream_reader
item = getattr(_stream_reader, name)
elif name in _STREAM_WRITER:
from . import _stream_writer
item = getattr(_stream_writer, name)
elif name in _PLAYBACK:
from . import _playback
item = getattr(_playback, name)
globals()[name] = item
return item
raise AttributeError(f"module {__name__} has no attribute {name}")
def __dir__():
return sorted(__all__ + _LAZILY_IMPORTED)
__all__ = []
......@@ -16,6 +16,7 @@ dict_format = {
}
@torchaudio._extension.fail_if_no_ffmpeg
def play_audio(
waveform: torch.Tensor,
sample_rate: Optional[float],
......
......@@ -381,6 +381,7 @@ _format_video_args = _format_doc(
)
@torchaudio._extension.fail_if_no_ffmpeg
class StreamReader:
"""Fetch and decode audio/video streams chunk by chunk.
......@@ -516,17 +517,20 @@ class StreamReader:
"""
return self._be.get_metadata()
def get_src_stream_info(self, i: int) -> SourceStream:
def get_src_stream_info(self, i: int) -> Union[SourceStream, SourceAudioStream, SourceVideoStream]:
"""Get the metadata of source stream
Args:
i (int): Stream index.
Returns:
SourceStream
Information about the source stream.
If the source stream is audio type, then :class:`SourceAudioStream` returned.
If it is video type, then :class:`SourceVideoStream` is returned.
Otherwise :class:`SourceStream` class is returned.
"""
return _parse_si(self._be.get_src_stream_info(i))
def get_out_stream_info(self, i: int) -> torchaudio.io.OutputStream:
def get_out_stream_info(self, i: int) -> OutputStream:
"""Get the metadata of output stream
Args:
......
......@@ -48,6 +48,7 @@ _format_common_args = _format_doc(
)
@torchaudio._extension.fail_if_no_ffmpeg
class StreamWriter:
"""Encode and write audio/video streams chunk by chunk
......
......@@ -5,8 +5,10 @@ It affects functionalities in :py:mod:`torchaudio.io` (and indirectly :py:func:`
from typing import Dict, List, Tuple
import torch
import torchaudio
@torchaudio._extension.fail_if_no_ffmpeg
def get_versions() -> Dict[str, Tuple[int]]:
"""Get the versions of FFmpeg libraries
......@@ -17,6 +19,7 @@ def get_versions() -> Dict[str, Tuple[int]]:
return torch.ops.torchaudio.ffmpeg_get_versions()
@torchaudio._extension.fail_if_no_ffmpeg
def get_log_level() -> int:
"""Get the log level of FFmpeg.
......@@ -25,6 +28,7 @@ def get_log_level() -> int:
return torch.ops.torchaudio.ffmpeg_get_log_level()
@torchaudio._extension.fail_if_no_ffmpeg
def set_log_level(level: int):
"""Set the log level of FFmpeg (libavformat etc)
......@@ -61,6 +65,7 @@ def set_log_level(level: int):
torch.ops.torchaudio.ffmpeg_set_log_level(level)
@torchaudio._extension.fail_if_no_ffmpeg
def get_demuxers() -> Dict[str, str]:
"""Get the available demuxers.
......@@ -78,6 +83,7 @@ def get_demuxers() -> Dict[str, str]:
return torch.ops.torchaudio.ffmpeg_get_demuxers()
@torchaudio._extension.fail_if_no_ffmpeg
def get_muxers() -> Dict[str, str]:
"""Get the available muxers.
......@@ -96,6 +102,7 @@ def get_muxers() -> Dict[str, str]:
return torch.ops.torchaudio.ffmpeg_get_muxers()
@torchaudio._extension.fail_if_no_ffmpeg
def get_audio_decoders() -> Dict[str, str]:
"""Get the available audio decoders.
......@@ -114,6 +121,7 @@ def get_audio_decoders() -> Dict[str, str]:
return torch.ops.torchaudio.ffmpeg_get_audio_decoders()
@torchaudio._extension.fail_if_no_ffmpeg
def get_audio_encoders() -> Dict[str, str]:
"""Get the available audio encoders.
......@@ -133,6 +141,7 @@ def get_audio_encoders() -> Dict[str, str]:
return torch.ops.torchaudio.ffmpeg_get_audio_encoders()
@torchaudio._extension.fail_if_no_ffmpeg
def get_video_decoders() -> Dict[str, str]:
"""Get the available video decoders.
......@@ -152,6 +161,7 @@ def get_video_decoders() -> Dict[str, str]:
return torch.ops.torchaudio.ffmpeg_get_video_decoders()
@torchaudio._extension.fail_if_no_ffmpeg
def get_video_encoders() -> Dict[str, str]:
"""Get the available video encoders.
......@@ -172,6 +182,7 @@ def get_video_encoders() -> Dict[str, str]:
return torch.ops.torchaudio.ffmpeg_get_video_encoders()
@torchaudio._extension.fail_if_no_ffmpeg
def get_input_devices() -> Dict[str, str]:
"""Get the available input devices.
......@@ -187,6 +198,7 @@ def get_input_devices() -> Dict[str, str]:
return torch.ops.torchaudio.ffmpeg_get_input_devices()
@torchaudio._extension.fail_if_no_ffmpeg
def get_output_devices() -> Dict[str, str]:
"""Get the available output devices.
......@@ -201,6 +213,7 @@ def get_output_devices() -> Dict[str, str]:
return torch.ops.torchaudio.ffmpeg_get_output_devices()
@torchaudio._extension.fail_if_no_ffmpeg
def get_input_protocols() -> List[str]:
"""Get the supported input protocols.
......@@ -214,6 +227,7 @@ def get_input_protocols() -> List[str]:
return torch.ops.torchaudio.ffmpeg_get_input_protocols()
@torchaudio._extension.fail_if_no_ffmpeg
def get_output_protocols() -> List[str]:
"""Get the supported output protocols.
......@@ -227,6 +241,7 @@ def get_output_protocols() -> List[str]:
return torch.ops.torchaudio.ffmpeg_get_output_protocols()
@torchaudio._extension.fail_if_no_ffmpeg
def get_build_config() -> str:
"""Get the FFmpeg build configuration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment