Commit 8497ee91 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor compat (#3518)

Summary:
The I/O functions in _compat module was introduced there so that
everything related to FFmpeg is in torchaudio.io and FFmpeg library
initialization can be carried out in `torchaudio.io.__init__`.

Now that this constraint is removed, (all the initialization happens
at `torchaudio._extension.__init__`) and `_compat` is only used by
FFmpeg dispatcher backend, we move the module to `torchaudio._backend`
for better locality.

Pull Request resolved: https://github.com/pytorch/audio/pull/3518

Reviewed By: huangruizhe

Differential Revision: D47877412

Pulled By: mthrok

fbshipit-source-id: aa18c8cb6e5d5360950df5158c33c653e37c565f
parent 61cbf791
...@@ -42,13 +42,12 @@ attention due to its robustness against noise. ...@@ -42,13 +42,12 @@ attention due to its robustness against noise.
""" """
import numpy as np
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
import torchvision import torchvision
import numpy as np
import sentencepiece as spm
###################################################################### ######################################################################
# Overview # Overview
# -------- # --------
......
...@@ -5,9 +5,9 @@ import tarfile ...@@ -5,9 +5,9 @@ import tarfile
from functools import partial from functools import partial
from parameterized import parameterized from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_load_func from torchaudio._backend.utils import get_load_func
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio.io._compat import _parse_save_args
from torchaudio_unittest.backend.dispatcher.sox.common import name_func from torchaudio_unittest.backend.dispatcher.sox.common import name_func
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
......
...@@ -7,8 +7,8 @@ from functools import partial ...@@ -7,8 +7,8 @@ from functools import partial
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_save_func from torchaudio._backend.utils import get_save_func
from torchaudio.io._compat import _parse_save_args
from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
......
...@@ -9,9 +9,10 @@ from torchaudio.io import StreamWriter ...@@ -9,9 +9,10 @@ from torchaudio.io import StreamWriter
if torchaudio._extension._FFMPEG_EXT is not None: if torchaudio._extension._FFMPEG_EXT is not None:
StreamReaderFileObj = torchaudio._extension._FFMPEG_EXT.StreamReaderFileObj StreamReaderFileObj = torchaudio._extension._FFMPEG_EXT.StreamReaderFileObj
else:
StreamReaderFileObj = object
# Note: need to comply TorchScript syntax -- need annotation and no f-string nor global
def info_audio( def info_audio(
src: str, src: str,
format: Optional[str], format: Optional[str],
...@@ -241,7 +242,6 @@ def _parse_save_args( ...@@ -241,7 +242,6 @@ def _parse_save_args(
return muxer, encoder, sample_fmt return muxer, encoder, sample_fmt
# NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript.
def save_audio( def save_audio(
uri: Union[BinaryIO, str, os.PathLike], uri: Union[BinaryIO, str, os.PathLike],
src: torch.Tensor, src: torch.Tensor,
......
...@@ -5,12 +5,12 @@ from functools import lru_cache ...@@ -5,12 +5,12 @@ from functools import lru_cache
from typing import BinaryIO, Dict, Optional, Tuple, Union from typing import BinaryIO, Dict, Optional, Tuple, Union
import torch import torch
import torchaudio.backend.soundfile_backend as soundfile_backend
from torchaudio._extension import _FFMPEG_EXT, _SOX_INITIALIZED from torchaudio._extension import _FFMPEG_EXT, _SOX_INITIALIZED
from torchaudio.backend import soundfile_backend
from torchaudio.backend.common import AudioMetaData from torchaudio.backend.common import AudioMetaData
if _FFMPEG_EXT is not None: from . import ffmpeg
from torchaudio.io._compat import info_audio, info_audio_fileobj, load_audio, load_audio_fileobj, save_audio
class Backend(ABC): class Backend(ABC):
...@@ -80,9 +80,9 @@ class FFmpegBackend(Backend): ...@@ -80,9 +80,9 @@ class FFmpegBackend(Backend):
@staticmethod @staticmethod
def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData: def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
if hasattr(uri, "read"): if hasattr(uri, "read"):
metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size) metadata = ffmpeg.info_audio_fileobj(uri, format, buffer_size=buffer_size)
else: else:
metadata = info_audio(os.path.normpath(uri), format) metadata = ffmpeg.info_audio(os.path.normpath(uri), format)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample) metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding) metadata.encoding = _map_encoding(metadata.encoding)
return metadata return metadata
...@@ -98,7 +98,7 @@ class FFmpegBackend(Backend): ...@@ -98,7 +98,7 @@ class FFmpegBackend(Backend):
buffer_size: int = 4096, buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
if hasattr(uri, "read"): if hasattr(uri, "read"):
return load_audio_fileobj( return ffmpeg.load_audio_fileobj(
uri, uri,
frame_offset, frame_offset,
num_frames, num_frames,
...@@ -108,7 +108,7 @@ class FFmpegBackend(Backend): ...@@ -108,7 +108,7 @@ class FFmpegBackend(Backend):
buffer_size, buffer_size,
) )
else: else:
return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format) return ffmpeg.load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)
@staticmethod @staticmethod
def save( def save(
...@@ -121,7 +121,7 @@ class FFmpegBackend(Backend): ...@@ -121,7 +121,7 @@ class FFmpegBackend(Backend):
bits_per_sample: Optional[int] = None, bits_per_sample: Optional[int] = None,
buffer_size: int = 4096, buffer_size: int = 4096,
) -> None: ) -> None:
save_audio( ffmpeg.save_audio(
uri, uri,
src, src,
sample_rate, sample_rate,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment