Unverified Commit c29598d5 authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Inline typing _backend (#527)

* add inline typing

* fix error

* minor change

* minor fix
parent 98fe8b46
from functools import wraps from functools import wraps
from typing import Any, List, Union
import platform import platform
import torch import torch
from torch import Tensor
from . import _soundfile_backend, _sox_backend from . import _soundfile_backend, _sox_backend
...@@ -10,11 +12,11 @@ _audio_backend = "soundfile" if platform.system() == "Windows" else "sox" ...@@ -10,11 +12,11 @@ _audio_backend = "soundfile" if platform.system() == "Windows" else "sox"
_audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend} _audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend}
def set_audio_backend(backend): def set_audio_backend(backend: str) -> None:
""" """
Specifies the package used to load. Specifies the package used to load.
Args: Args:
backend (string): Name of the backend. One of {}. backend (str): Name of the backend. One of {}.
""".format(_audio_backends.keys()) """.format(_audio_backends.keys())
global _audio_backend global _audio_backend
if backend not in _audio_backends: if backend not in _audio_backends:
...@@ -24,14 +26,14 @@ def set_audio_backend(backend): ...@@ -24,14 +26,14 @@ def set_audio_backend(backend):
_audio_backend = backend _audio_backend = backend
def get_audio_backend(): def get_audio_backend() -> str:
""" """
Gets the name of the package used to load. Gets the name of the package used to load.
""" """
return _audio_backend return _audio_backend
def _get_audio_backend_module(): def _get_audio_backend_module() -> Any:
""" """
Gets the module backend to load. Gets the module backend to load.
""" """
...@@ -39,7 +41,7 @@ def _get_audio_backend_module(): ...@@ -39,7 +41,7 @@ def _get_audio_backend_module():
return _audio_backends[backend] return _audio_backends[backend]
def _audio_backend_guard(backends): def _audio_backend_guard(backends: Union[str, List[str]]) -> Any:
if isinstance(backends, str): if isinstance(backends, str):
backends = [backends] backends = [backends]
...@@ -55,7 +57,7 @@ def _audio_backend_guard(backends): ...@@ -55,7 +57,7 @@ def _audio_backend_guard(backends):
return decorator return decorator
def check_input(src): def check_input(src: Tensor) -> None:
if not torch.is_tensor(src): if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src)) raise TypeError('Expected a tensor, got %s' % type(src))
if src.is_cuda: if src.is_cuda:
......
import os import os
from typing import Any, Optional, Tuple, Union
import torch import torch
from torch import Tensor
_subtype_to_precision = { _subtype_to_precision = {
'PCM_S8': 8, 'PCM_S8': 8,
...@@ -12,7 +14,11 @@ _subtype_to_precision = { ...@@ -12,7 +14,11 @@ _subtype_to_precision = {
class SignalInfo: class SignalInfo:
def __init__(self, channels=None, rate=None, precision=None, length=None): def __init__(self,
channels: Optional[int] = None,
rate: Optional[float] = None,
precision: Optional[int] = None,
length: Optional[int] = None) -> None:
self.channels = channels self.channels = channels
self.rate = rate self.rate = rate
self.precision = precision self.precision = precision
...@@ -20,16 +26,14 @@ class SignalInfo: ...@@ -20,16 +26,14 @@ class SignalInfo:
class EncodingInfo: class EncodingInfo:
def __init__( def __init__(self,
self, encoding: Any = None,
encoding=None, bits_per_sample: Optional[int] = None,
bits_per_sample=None, compression: Optional[float] = None,
compression=None, reverse_bytes: Any = None,
reverse_bytes=None, reverse_nibbles: Any = None,
reverse_nibbles=None, reverse_bits: Any = None,
reverse_bits=None, opposite_endian: Optional[bool] = None) -> None:
opposite_endian=None
):
self.encoding = encoding self.encoding = encoding
self.bits_per_sample = bits_per_sample self.bits_per_sample = bits_per_sample
self.compression = compression self.compression = compression
...@@ -39,24 +43,22 @@ class EncodingInfo: ...@@ -39,24 +43,22 @@ class EncodingInfo:
self.opposite_endian = opposite_endian self.opposite_endian = opposite_endian
def check_input(src): def check_input(src: Tensor) -> None:
if not torch.is_tensor(src): if not torch.is_tensor(src):
raise TypeError("Expected a tensor, got %s" % type(src)) raise TypeError("Expected a tensor, got %s" % type(src))
if src.is_cuda: if src.is_cuda:
raise TypeError("Expected a CPU based tensor, got %s" % type(src)) raise TypeError("Expected a CPU based tensor, got %s" % type(src))
def load( def load(filepath: str,
filepath, out: Optional[Tensor] = None,
out=None, normalization: Optional[bool] = True,
normalization=True, channels_first: Optional[bool] = True,
channels_first=True, num_frames: int = 0,
num_frames=0, offset: int = 0,
offset=0, signalinfo: SignalInfo = None,
signalinfo=None, encodinginfo: EncodingInfo = None,
encodinginfo=None, filetype: Optional[str] = None) -> Tuple[Tensor, int]:
filetype=None,
):
r"""See torchaudio.load""" r"""See torchaudio.load"""
assert out is None assert out is None
...@@ -96,7 +98,7 @@ def load( ...@@ -96,7 +98,7 @@ def load(
return out, sample_rate return out, sample_rate
def save(filepath, src, sample_rate, precision=16, channels_first=True): def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save""" r"""See torchaudio.save"""
ch_idx, len_idx = (0, 1) if channels_first else (1, 0) ch_idx, len_idx = (0, 1) if channels_first else (1, 0)
...@@ -129,7 +131,7 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True): ...@@ -129,7 +131,7 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
return soundfile.write(filepath, src, sample_rate, precision) return soundfile.write(filepath, src, sample_rate, precision)
def info(filepath): def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info""" r"""See torchaudio.info"""
import soundfile import soundfile
......
import os.path import os.path
from typing import Any, Optional, Tuple, Union
import torch import torch
from torch import Tensor
import torchaudio import torchaudio
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
def load(
filepath, def load(filepath: str,
out=None, out: Optional[Tensor] = None,
normalization=True, normalization: Optional[bool] = True,
channels_first=True, channels_first: Optional[bool] = True,
num_frames=0, num_frames: int = 0,
offset=0, offset: int = 0,
signalinfo=None, signalinfo: SignalInfo = None,
encodinginfo=None, encodinginfo: EncodingInfo = None,
filetype=None, filetype: Optional[str] = None) -> Tuple[Tensor, int]:
):
r"""See torchaudio.load""" r"""See torchaudio.load"""
# stringify if `pathlib.Path` (noop if already `str`) # stringify if `pathlib.Path` (noop if already `str`)
...@@ -53,7 +53,7 @@ def load( ...@@ -53,7 +53,7 @@ def load(
return out, sample_rate return out, sample_rate
def save(filepath, src, sample_rate, precision=16, channels_first=True): def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save""" r"""See torchaudio.save"""
si = torchaudio.sox_signalinfo_t() si = torchaudio.sox_signalinfo_t()
...@@ -65,7 +65,7 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True): ...@@ -65,7 +65,7 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
return torchaudio.save_encinfo(filepath, src, channels_first, si) return torchaudio.save_encinfo(filepath, src, channels_first, si)
def info(filepath): def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info""" r"""See torchaudio.info"""
import _torch_sox import _torch_sox
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment