Unverified commit 3383a2d9, authored by Tomás Osório and committed by GitHub
Browse files

Inline typing init (#526)

* inline typing first iteration

* fix issue with mypy

* change docstring

* add more typing

* fix to not break BC

* update docstrings
parent af88b925
import os.path
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import torch
from torch import Tensor
from torchaudio import (
compliance,
datasets,
kaldi_io,
sox_effects,
transforms,
transforms
)
from torchaudio._backend import (
check_input,
_audio_backend_guard,
......@@ -17,6 +18,7 @@ from torchaudio._backend import (
get_audio_backend,
set_audio_backend,
)
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
try:
from .version import __version__, git_version # noqa: F401
......@@ -24,27 +26,27 @@ except ImportError:
pass
def load(filepath,
out=None,
normalization=True,
channels_first=True,
num_frames=0,
offset=0,
signalinfo=None,
encodinginfo=None,
filetype=None):
def load(filepath: Union[str, Path],
out: Optional[Tensor] = None,
normalization: Union[bool, float, Callable] = True,
channels_first: bool = True,
num_frames: int = 0,
offset: int = 0,
signalinfo: Optional[SignalInfo] = None,
encodinginfo: Optional[EncodingInfo] = None,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
r"""Loads an audio file from disk into a tensor
Args:
filepath (str or pathlib.Path): Path to audio file
out (torch.Tensor, optional): An output tensor to use instead of creating one. (Default: ``None``)
normalization (bool, number, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
filepath (str or Path): Path to audio file
out (Tensor, optional): An output tensor to use instead of creating one. (Default: ``None``)
normalization (bool, float, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes signed 32-bit audio), and normalizes to `[-1, 1]`.
If `number`, then output is divided by that number
If `callable`, then the output is passed as a parameter
If `float`, then output is divided by that number
If `Callable`, then the output is passed as a parameter
to the given function, then the output is divided by
the result. (Default: ``True``)
channels_first (bool): Set channels first or length first in result. (Default: ``True``)
channels_first (bool, optional): Set channels first or length first in result. (Default: ``True``)
num_frames (int, optional): Number of frames to load. 0 to load everything after the offset.
(Default: ``0``)
offset (int, optional): Number of frames from the start of the file to begin data loading.
......@@ -57,7 +59,7 @@ def load(filepath,
automatically. (Default: ``None``)
Returns:
Tuple[torch.Tensor, int]: An output tensor of size `[C x L]` or `[L x C]` where L is the number
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file)
......@@ -86,15 +88,15 @@ def load(filepath,
)
def load_wav(filepath, **kwargs):
def load_wav(filepath: Union[str, Path], **kwargs: Any) -> Tuple[Tensor, int]:
r""" Loads a wave file. It assumes that the wav file uses 16 bit per sample that needs normalization by shifting
the input right by 16 bits.
Args:
filepath (str or pathlib.Path): Path to audio file
filepath (str or Path): Path to audio file
Returns:
Tuple[torch.Tensor, int]: An output tensor of size `[C x L]` or `[L x C]` where L is the number
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file)
"""
......@@ -102,17 +104,17 @@ def load_wav(filepath, **kwargs):
return load(filepath, **kwargs)
def save(filepath, src, sample_rate, precision=16, channels_first=True):
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""Convenience function for `save_encinfo`.
Args:
filepath (str): Path to audio file
src (torch.Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
src (Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels
sample_rate (int): An integer which is the sample rate of the
audio (as listed in the metadata of the file)
precision (int): Bit precision (Default: ``16``)
channels_first (bool): Set channels first or length first in result. (
precision (int, optional): Bit precision (Default: ``16``)
channels_first (bool, optional): Set channels first or length first in result. (
Default: ``True``)
"""
......@@ -122,23 +124,23 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
@_audio_backend_guard("sox")
def save_encinfo(filepath,
src,
channels_first=True,
signalinfo=None,
encodinginfo=None,
filetype=None):
def save_encinfo(filepath: str,
src: Tensor,
channels_first: bool = True,
signalinfo: Optional[SignalInfo] = None,
encodinginfo: Optional[EncodingInfo] = None,
filetype: Optional[str] = None) -> None:
r"""Saves a tensor of an audio signal to disk as a standard format like mp3, wav, etc.
Args:
filepath (str): Path to audio file
src (torch.Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
src (Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels
channels_first (bool): Set channels first or length first in result. (Default: ``True``)
signalinfo (sox_signalinfo_t): A sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined. (Default: ``None``)
channels_first (bool, optional): Set channels first or length first in result. (Default: ``True``)
signalinfo (sox_signalinfo_t, optional): A sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined (Default: ``None``).
encodinginfo (sox_encodinginfo_t, optional): A sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined. (Default: ``None``)
audio type cannot be automatically determined (Default: ``None``).
filetype (str, optional): A filetype or extension to be set if sox cannot determine it
automatically. (Default: ``None``)
......@@ -165,17 +167,18 @@ def save_encinfo(filepath,
"Expected format where C < 16, but found {}".format(src.size()))
# sox stores the sample rate as a float, though practically sample rates are almost always integers
# convert integers to floats
if not isinstance(signalinfo.rate, float):
if float(signalinfo.rate) == signalinfo.rate:
signalinfo.rate = float(signalinfo.rate)
else:
raise TypeError('Sample rate should be a float or int')
# check if the bit precision (i.e. bits per sample) is an integer
if not isinstance(signalinfo.precision, int):
if int(signalinfo.precision) == signalinfo.precision:
signalinfo.precision = int(signalinfo.precision)
else:
raise TypeError('Bit precision should be an integer')
if signalinfo:
if not isinstance(signalinfo.rate, float):
if float(signalinfo.rate) == signalinfo.rate:
signalinfo.rate = float(signalinfo.rate)
else:
raise TypeError('Sample rate should be a float or int')
# check if the bit precision (i.e. bits per sample) is an integer
if not isinstance(signalinfo.precision, int):
if int(signalinfo.precision) == signalinfo.precision:
signalinfo.precision = int(signalinfo.precision)
else:
raise TypeError('Bit precision should be an integer')
# programs such as librosa normalize the signal, unnormalize if detected
if src.min() >= -1.0 and src.max() <= 1.0:
src = src * (1 << 31)
......@@ -193,14 +196,14 @@ def save_encinfo(filepath,
_torch_sox.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)
def info(filepath):
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""Gets metadata from an audio file without loading the signal.
Args:
filepath (str): Path to audio file
Returns:
Tuple[sox_signalinfo_t, sox_encodinginfo_t]: A si (sox_signalinfo_t) signal
(sox_signalinfo_t, sox_encodinginfo_t): A si (sox_signalinfo_t) signal
info as a python object. An ei (sox_encodinginfo_t) encoding info
Example
......@@ -212,7 +215,7 @@ def info(filepath):
@_audio_backend_guard("sox")
def sox_signalinfo_t():
def sox_signalinfo_t() -> SignalInfo:
r"""Create a sox_signalinfo_t object. This object can be used to set the sample
rate, number of channels, length, bit precision and headroom multiplier
primarily for effects
......@@ -237,7 +240,7 @@ def sox_signalinfo_t():
@_audio_backend_guard("sox")
def sox_encodinginfo_t():
def sox_encodinginfo_t() -> EncodingInfo:
r"""Create a sox_encodinginfo_t object. This object can be used to set the encoding
type, bit precision, compression factor, reverse bytes, reverse nibbles,
reverse bits and endianness. This can be used in an effects chain to encode the
......@@ -277,7 +280,7 @@ def sox_encodinginfo_t():
@_audio_backend_guard("sox")
def get_sox_encoding_t(i=None):
def get_sox_encoding_t(i: int = None) -> EncodingInfo:
r"""Get enum of sox_encoding_t for sox encodings.
Args:
......@@ -297,7 +300,7 @@ def get_sox_encoding_t(i=None):
@_audio_backend_guard("sox")
def get_sox_option_t(i=2):
def get_sox_option_t(i: int = 2) -> Any:
r"""Get enum of sox_option_t for sox encodinginfo options.
Args:
......@@ -316,7 +319,7 @@ def get_sox_option_t(i=2):
@_audio_backend_guard("sox")
def get_sox_bool(i=0):
def get_sox_bool(i: int = 0) -> Any:
r"""Get enum of sox_bool for sox encodinginfo options.
Args:
......@@ -336,7 +339,7 @@ def get_sox_bool(i=0):
@_audio_backend_guard("sox")
def initialize_sox():
def initialize_sox() -> int:
"""Initialize sox for use with effects chains. This is not required for simple
loading. Importantly, only run `initialize_sox` once and do not shutdown
after each effect chain, but rather once you are finished with all effects chains.
......@@ -347,7 +350,7 @@ def initialize_sox():
@_audio_backend_guard("sox")
def shutdown_sox():
def shutdown_sox() -> int:
"""Shutdown sox for effects chain. Not required for simple loading. Importantly,
only call once. Attempting to re-initialize sox will result in seg faults.
"""
......@@ -356,7 +359,7 @@ def shutdown_sox():
return _torch_sox.shutdown_sox()
def _audio_normalization(signal, normalization):
def _audio_normalization(signal: Tensor, normalization: Union[bool, float, Callable]) -> None:
"""Audio normalization of a tensor in-place. The normalization can be a bool,
a number, or a callable that takes the audio tensor as an input. SoX uses
32-bit signed integers internally, thus bool normalizes based on that assumption.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment