Commit 8497ee91 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor compat (#3518)

Summary:
The I/O functions in _compat module was introduced there so that
everything related to FFmpeg is in torchaudio.io and FFmpeg library
initialization can be carried out in `torchaudio.io.__init__`.

Now that this constraint is removed, (all the initialization happens
at `torchaudio._extension.__init__`) and `_compat` is only used by
FFmpeg dispatcher backend, we move the module to `torchaudio._backend`
for better locality.

Pull Request resolved: https://github.com/pytorch/audio/pull/3518

Reviewed By: huangruizhe

Differential Revision: D47877412

Pulled By: mthrok

fbshipit-source-id: aa18c8cb6e5d5360950df5158c33c653e37c565f
parent 61cbf791
......@@ -42,13 +42,12 @@ attention due to its robustness against noise.
"""
import numpy as np
import sentencepiece as spm
import torch
import torchaudio
import torchvision
import numpy as np
import sentencepiece as spm
######################################################################
# Overview
# --------
......
......@@ -5,9 +5,9 @@ import tarfile
from functools import partial
from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_load_func
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.io._compat import _parse_save_args
from torchaudio_unittest.backend.dispatcher.sox.common import name_func
from torchaudio_unittest.common_utils import (
......
......@@ -7,8 +7,8 @@ from functools import partial
import torch
from parameterized import parameterized
from torchaudio._backend.ffmpeg import _parse_save_args
from torchaudio._backend.utils import get_save_func
from torchaudio.io._compat import _parse_save_args
from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import (
......
......@@ -9,9 +9,10 @@ from torchaudio.io import StreamWriter
if torchaudio._extension._FFMPEG_EXT is not None:
StreamReaderFileObj = torchaudio._extension._FFMPEG_EXT.StreamReaderFileObj
else:
StreamReaderFileObj = object
# Note: need to comply TorchScript syntax -- need annotation and no f-string nor global
def info_audio(
src: str,
format: Optional[str],
......@@ -241,7 +242,6 @@ def _parse_save_args(
return muxer, encoder, sample_fmt
# NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript.
def save_audio(
uri: Union[BinaryIO, str, os.PathLike],
src: torch.Tensor,
......
......@@ -5,12 +5,12 @@ from functools import lru_cache
from typing import BinaryIO, Dict, Optional, Tuple, Union
import torch
import torchaudio.backend.soundfile_backend as soundfile_backend
from torchaudio._extension import _FFMPEG_EXT, _SOX_INITIALIZED
from torchaudio.backend import soundfile_backend
from torchaudio.backend.common import AudioMetaData
if _FFMPEG_EXT is not None:
from torchaudio.io._compat import info_audio, info_audio_fileobj, load_audio, load_audio_fileobj, save_audio
from . import ffmpeg
class Backend(ABC):
......@@ -80,9 +80,9 @@ class FFmpegBackend(Backend):
@staticmethod
def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
if hasattr(uri, "read"):
metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size)
metadata = ffmpeg.info_audio_fileobj(uri, format, buffer_size=buffer_size)
else:
metadata = info_audio(os.path.normpath(uri), format)
metadata = ffmpeg.info_audio(os.path.normpath(uri), format)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding)
return metadata
......@@ -98,7 +98,7 @@ class FFmpegBackend(Backend):
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
if hasattr(uri, "read"):
return load_audio_fileobj(
return ffmpeg.load_audio_fileobj(
uri,
frame_offset,
num_frames,
......@@ -108,7 +108,7 @@ class FFmpegBackend(Backend):
buffer_size,
)
else:
return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)
return ffmpeg.load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)
@staticmethod
def save(
......@@ -121,7 +121,7 @@ class FFmpegBackend(Backend):
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
) -> None:
save_audio(
ffmpeg.save_audio(
uri,
src,
sample_rate,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment