Unverified Commit c29598d5 authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Inline typing _backend (#527)

* add inline typing

* fix error

* minor change

* minor fix
parent 98fe8b46
from functools import wraps
from typing import Any, List, Union
import platform
import torch
from torch import Tensor
from . import _soundfile_backend, _sox_backend
......@@ -10,11 +12,11 @@ _audio_backend = "soundfile" if platform.system() == "Windows" else "sox"
_audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend}
def set_audio_backend(backend):
def set_audio_backend(backend: str) -> None:
"""
Specifies the package used to load.
Args:
backend (string): Name of the backend. One of {}.
backend (str): Name of the backend. One of {}.
""".format(_audio_backends.keys())
global _audio_backend
if backend not in _audio_backends:
......@@ -24,14 +26,14 @@ def set_audio_backend(backend):
_audio_backend = backend
def get_audio_backend():
def get_audio_backend() -> str:
"""
Gets the name of the package used to load.
"""
return _audio_backend
def _get_audio_backend_module():
def _get_audio_backend_module() -> Any:
"""
Gets the module backend to load.
"""
......@@ -39,7 +41,7 @@ def _get_audio_backend_module():
return _audio_backends[backend]
def _audio_backend_guard(backends):
def _audio_backend_guard(backends: Union[str, List[str]]) -> Any:
if isinstance(backends, str):
backends = [backends]
......@@ -55,7 +57,7 @@ def _audio_backend_guard(backends):
return decorator
def check_input(src):
def check_input(src: Tensor) -> None:
if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src))
if src.is_cuda:
......
import os
from typing import Any, Optional, Tuple, Union
import torch
from torch import Tensor
_subtype_to_precision = {
'PCM_S8': 8,
......@@ -12,7 +14,11 @@ _subtype_to_precision = {
class SignalInfo:
def __init__(self, channels=None, rate=None, precision=None, length=None):
def __init__(self,
channels: Optional[int] = None,
rate: Optional[float] = None,
precision: Optional[int] = None,
length: Optional[int] = None) -> None:
self.channels = channels
self.rate = rate
self.precision = precision
......@@ -20,16 +26,14 @@ class SignalInfo:
class EncodingInfo:
def __init__(
self,
encoding=None,
bits_per_sample=None,
compression=None,
reverse_bytes=None,
reverse_nibbles=None,
reverse_bits=None,
opposite_endian=None
):
def __init__(self,
encoding: Any = None,
bits_per_sample: Optional[int] = None,
compression: Optional[float] = None,
reverse_bytes: Any = None,
reverse_nibbles: Any = None,
reverse_bits: Any = None,
opposite_endian: Optional[bool] = None) -> None:
self.encoding = encoding
self.bits_per_sample = bits_per_sample
self.compression = compression
......@@ -39,24 +43,22 @@ class EncodingInfo:
self.opposite_endian = opposite_endian
def check_input(src):
def check_input(src: Tensor) -> None:
if not torch.is_tensor(src):
raise TypeError("Expected a tensor, got %s" % type(src))
if src.is_cuda:
raise TypeError("Expected a CPU based tensor, got %s" % type(src))
def load(
filepath,
out=None,
normalization=True,
channels_first=True,
num_frames=0,
offset=0,
signalinfo=None,
encodinginfo=None,
filetype=None,
):
def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
channels_first: Optional[bool] = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: SignalInfo = None,
encodinginfo: EncodingInfo = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
r"""See torchaudio.load"""
assert out is None
......@@ -96,7 +98,7 @@ def load(
return out, sample_rate
def save(filepath, src, sample_rate, precision=16, channels_first=True):
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""
ch_idx, len_idx = (0, 1) if channels_first else (1, 0)
......@@ -129,7 +131,7 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
return soundfile.write(filepath, src, sample_rate, precision)
def info(filepath):
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info"""
import soundfile
......
import os.path
from typing import Any, Optional, Tuple, Union
import torch
from torch import Tensor
import torchaudio
def load(
filepath,
out=None,
normalization=True,
channels_first=True,
num_frames=0,
offset=0,
signalinfo=None,
encodinginfo=None,
filetype=None,
):
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
channels_first: Optional[bool] = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: SignalInfo = None,
encodinginfo: EncodingInfo = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
r"""See torchaudio.load"""
# stringify if `pathlib.Path` (noop if already `str`)
......@@ -53,7 +53,7 @@ def load(
return out, sample_rate
def save(filepath, src, sample_rate, precision=16, channels_first=True):
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""
si = torchaudio.sox_signalinfo_t()
......@@ -65,7 +65,7 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
return torchaudio.save_encinfo(filepath, src, channels_first, si)
def info(filepath):
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info"""
import _torch_sox
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment