Unverified Commit 3383a2d9 authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Inline typing init (#526)

* inline typing first iteration

* fix issue with mypy

* change docstring

* add more typing

* fix to not break BC

* update docstrings
parent af88b925
import os.path import os.path
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import torch import torch
from torch import Tensor
from torchaudio import ( from torchaudio import (
compliance, compliance,
datasets, datasets,
kaldi_io, kaldi_io,
sox_effects, sox_effects,
transforms, transforms
) )
from torchaudio._backend import ( from torchaudio._backend import (
check_input, check_input,
_audio_backend_guard, _audio_backend_guard,
...@@ -17,6 +18,7 @@ from torchaudio._backend import ( ...@@ -17,6 +18,7 @@ from torchaudio._backend import (
get_audio_backend, get_audio_backend,
set_audio_backend, set_audio_backend,
) )
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
try: try:
from .version import __version__, git_version # noqa: F401 from .version import __version__, git_version # noqa: F401
...@@ -24,27 +26,27 @@ except ImportError: ...@@ -24,27 +26,27 @@ except ImportError:
pass pass
def load(filepath, def load(filepath: Union[str, Path],
out=None, out: Optional[Tensor] = None,
normalization=True, normalization: Union[bool, float, Callable] = True,
channels_first=True, channels_first: bool = True,
num_frames=0, num_frames: int = 0,
offset=0, offset: int = 0,
signalinfo=None, signalinfo: Optional[SignalInfo] = None,
encodinginfo=None, encodinginfo: Optional[EncodingInfo] = None,
filetype=None): filetype: Optional[str] = None) -> Tuple[Tensor, int]:
r"""Loads an audio file from disk into a tensor r"""Loads an audio file from disk into a tensor
Args: Args:
filepath (str or pathlib.Path): Path to audio file filepath (str or Path): Path to audio file
out (torch.Tensor, optional): An output tensor to use instead of creating one. (Default: ``None``) out (Tensor, optional): An output tensor to use instead of creating one. (Default: ``None``)
normalization (bool, number, or callable, optional): If boolean `True`, then output is divided by `1 << 31` normalization (bool, float, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes signed 32-bit audio), and normalizes to `[-1, 1]`. (assumes signed 32-bit audio), and normalizes to `[-1, 1]`.
If `number`, then output is divided by that number If `float`, then output is divided by that number
If `callable`, then the output is passed as a parameter If `Callable`, then the output is passed as a parameter
to the given function, then the output is divided by to the given function, then the output is divided by
the result. (Default: ``True``) the result. (Default: ``True``)
channels_first (bool): Set channels first or length first in result. (Default: ``True``) channels_first (bool, optional): Set channels first or length first in result. (Default: ``True``)
num_frames (int, optional): Number of frames to load. 0 to load everything after the offset. num_frames (int, optional): Number of frames to load. 0 to load everything after the offset.
(Default: ``0``) (Default: ``0``)
offset (int, optional): Number of frames from the start of the file to begin data loading. offset (int, optional): Number of frames from the start of the file to begin data loading.
...@@ -57,7 +59,7 @@ def load(filepath, ...@@ -57,7 +59,7 @@ def load(filepath,
automatically. (Default: ``None``) automatically. (Default: ``None``)
Returns: Returns:
Tuple[torch.Tensor, int]: An output tensor of size `[C x L]` or `[L x C]` where L is the number (Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file) audio (as listed in the metadata of the file)
...@@ -86,15 +88,15 @@ def load(filepath, ...@@ -86,15 +88,15 @@ def load(filepath,
) )
def load_wav(filepath, **kwargs): def load_wav(filepath: Union[str, Path], **kwargs: Any) -> Tuple[Tensor, int]:
r""" Loads a wave file. It assumes that the wav file uses 16 bit per sample that needs normalization by shifting r""" Loads a wave file. It assumes that the wav file uses 16 bit per sample that needs normalization by shifting
the input right by 16 bits. the input right by 16 bits.
Args: Args:
filepath (str or pathlib.Path): Path to audio file filepath (str or Path): Path to audio file
Returns: Returns:
Tuple[torch.Tensor, int]: An output tensor of size `[C x L]` or `[L x C]` where L is the number (Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file) audio (as listed in the metadata of the file)
""" """
...@@ -102,17 +104,17 @@ def load_wav(filepath, **kwargs): ...@@ -102,17 +104,17 @@ def load_wav(filepath, **kwargs):
return load(filepath, **kwargs) return load(filepath, **kwargs)
def save(filepath, src, sample_rate, precision=16, channels_first=True): def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""Convenience function for `save_encinfo`. r"""Convenience function for `save_encinfo`.
Args: Args:
filepath (str): Path to audio file filepath (str): Path to audio file
src (torch.Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is src (Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels the number of audio frames, C is the number of channels
sample_rate (int): An integer which is the sample rate of the sample_rate (int): An integer which is the sample rate of the
audio (as listed in the metadata of the file) audio (as listed in the metadata of the file)
precision (int): Bit precision (Default: ``16``) precision (int, optional): Bit precision (Default: ``16``)
channels_first (bool): Set channels first or length first in result. ( channels_first (bool, optional): Set channels first or length first in result. (
Default: ``True``) Default: ``True``)
""" """
...@@ -122,23 +124,23 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True): ...@@ -122,23 +124,23 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
@_audio_backend_guard("sox") @_audio_backend_guard("sox")
def save_encinfo(filepath, def save_encinfo(filepath: str,
src, src: Tensor,
channels_first=True, channels_first: bool = True,
signalinfo=None, signalinfo: Optional[SignalInfo] = None,
encodinginfo=None, encodinginfo: Optional[EncodingInfo] = None,
filetype=None): filetype: Optional[str] = None) -> None:
r"""Saves a tensor of an audio signal to disk as a standard format like mp3, wav, etc. r"""Saves a tensor of an audio signal to disk as a standard format like mp3, wav, etc.
Args: Args:
filepath (str): Path to audio file filepath (str): Path to audio file
src (torch.Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is src (Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels the number of audio frames, C is the number of channels
channels_first (bool): Set channels first or length first in result. (Default: ``True``) channels_first (bool, optional): Set channels first or length first in result. (Default: ``True``)
signalinfo (sox_signalinfo_t): A sox_signalinfo_t type, which could be helpful if the signalinfo (sox_signalinfo_t, optional): A sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined. (Default: ``None``) audio type cannot be automatically determined (Default: ``None``).
encodinginfo (sox_encodinginfo_t, optional): A sox_encodinginfo_t type, which could be set if the encodinginfo (sox_encodinginfo_t, optional): A sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined. (Default: ``None``) audio type cannot be automatically determined (Default: ``None``).
filetype (str, optional): A filetype or extension to be set if sox cannot determine it filetype (str, optional): A filetype or extension to be set if sox cannot determine it
automatically. (Default: ``None``) automatically. (Default: ``None``)
...@@ -165,17 +167,18 @@ def save_encinfo(filepath, ...@@ -165,17 +167,18 @@ def save_encinfo(filepath,
"Expected format where C < 16, but found {}".format(src.size())) "Expected format where C < 16, but found {}".format(src.size()))
# sox stores the sample rate as a float, though practically sample rates are almost always integers # sox stores the sample rate as a float, though practically sample rates are almost always integers
# convert integers to floats # convert integers to floats
if not isinstance(signalinfo.rate, float): if signalinfo:
if float(signalinfo.rate) == signalinfo.rate: if not isinstance(signalinfo.rate, float):
signalinfo.rate = float(signalinfo.rate) if float(signalinfo.rate) == signalinfo.rate:
else: signalinfo.rate = float(signalinfo.rate)
raise TypeError('Sample rate should be a float or int') else:
# check if the bit precision (i.e. bits per sample) is an integer raise TypeError('Sample rate should be a float or int')
if not isinstance(signalinfo.precision, int): # check if the bit precision (i.e. bits per sample) is an integer
if int(signalinfo.precision) == signalinfo.precision: if not isinstance(signalinfo.precision, int):
signalinfo.precision = int(signalinfo.precision) if int(signalinfo.precision) == signalinfo.precision:
else: signalinfo.precision = int(signalinfo.precision)
raise TypeError('Bit precision should be an integer') else:
raise TypeError('Bit precision should be an integer')
# programs such as librosa normalize the signal, unnormalize if detected # programs such as librosa normalize the signal, unnormalize if detected
if src.min() >= -1.0 and src.max() <= 1.0: if src.min() >= -1.0 and src.max() <= 1.0:
src = src * (1 << 31) src = src * (1 << 31)
...@@ -193,14 +196,14 @@ def save_encinfo(filepath, ...@@ -193,14 +196,14 @@ def save_encinfo(filepath,
_torch_sox.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype) _torch_sox.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)
def info(filepath): def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""Gets metadata from an audio file without loading the signal. r"""Gets metadata from an audio file without loading the signal.
Args: Args:
filepath (str): Path to audio file filepath (str): Path to audio file
Returns: Returns:
Tuple[sox_signalinfo_t, sox_encodinginfo_t]: A si (sox_signalinfo_t) signal (sox_signalinfo_t, sox_encodinginfo_t): A si (sox_signalinfo_t) signal
info as a python object. An ei (sox_encodinginfo_t) encoding info info as a python object. An ei (sox_encodinginfo_t) encoding info
Example Example
...@@ -212,7 +215,7 @@ def info(filepath): ...@@ -212,7 +215,7 @@ def info(filepath):
@_audio_backend_guard("sox") @_audio_backend_guard("sox")
def sox_signalinfo_t(): def sox_signalinfo_t() -> SignalInfo:
r"""Create a sox_signalinfo_t object. This object can be used to set the sample r"""Create a sox_signalinfo_t object. This object can be used to set the sample
rate, number of channels, length, bit precision and headroom multiplier rate, number of channels, length, bit precision and headroom multiplier
primarily for effects primarily for effects
...@@ -237,7 +240,7 @@ def sox_signalinfo_t(): ...@@ -237,7 +240,7 @@ def sox_signalinfo_t():
@_audio_backend_guard("sox") @_audio_backend_guard("sox")
def sox_encodinginfo_t(): def sox_encodinginfo_t() -> EncodingInfo:
r"""Create a sox_encodinginfo_t object. This object can be used to set the encoding r"""Create a sox_encodinginfo_t object. This object can be used to set the encoding
type, bit precision, compression factor, reverse bytes, reverse nibbles, type, bit precision, compression factor, reverse bytes, reverse nibbles,
reverse bits and endianness. This can be used in an effects chain to encode the reverse bits and endianness. This can be used in an effects chain to encode the
...@@ -277,7 +280,7 @@ def sox_encodinginfo_t(): ...@@ -277,7 +280,7 @@ def sox_encodinginfo_t():
@_audio_backend_guard("sox") @_audio_backend_guard("sox")
def get_sox_encoding_t(i=None): def get_sox_encoding_t(i: int = None) -> EncodingInfo:
r"""Get enum of sox_encoding_t for sox encodings. r"""Get enum of sox_encoding_t for sox encodings.
Args: Args:
...@@ -297,7 +300,7 @@ def get_sox_encoding_t(i=None): ...@@ -297,7 +300,7 @@ def get_sox_encoding_t(i=None):
@_audio_backend_guard("sox") @_audio_backend_guard("sox")
def get_sox_option_t(i=2): def get_sox_option_t(i: int = 2) -> Any:
r"""Get enum of sox_option_t for sox encodinginfo options. r"""Get enum of sox_option_t for sox encodinginfo options.
Args: Args:
...@@ -316,7 +319,7 @@ def get_sox_option_t(i=2): ...@@ -316,7 +319,7 @@ def get_sox_option_t(i=2):
@_audio_backend_guard("sox") @_audio_backend_guard("sox")
def get_sox_bool(i=0): def get_sox_bool(i: int = 0) -> Any:
r"""Get enum of sox_bool for sox encodinginfo options. r"""Get enum of sox_bool for sox encodinginfo options.
Args: Args:
...@@ -336,7 +339,7 @@ def get_sox_bool(i=0): ...@@ -336,7 +339,7 @@ def get_sox_bool(i=0):
@_audio_backend_guard("sox") @_audio_backend_guard("sox")
def initialize_sox(): def initialize_sox() -> int:
"""Initialize sox for use with effects chains. This is not required for simple """Initialize sox for use with effects chains. This is not required for simple
loading. Importantly, only run `initialize_sox` once and do not shutdown loading. Importantly, only run `initialize_sox` once and do not shutdown
after each effect chain, but rather once you are finished with all effects chains. after each effect chain, but rather once you are finished with all effects chains.
...@@ -347,7 +350,7 @@ def initialize_sox(): ...@@ -347,7 +350,7 @@ def initialize_sox():
@_audio_backend_guard("sox") @_audio_backend_guard("sox")
def shutdown_sox(): def shutdown_sox() -> int:
"""Showdown sox for effects chain. Not required for simple loading. Importantly, """Showdown sox for effects chain. Not required for simple loading. Importantly,
only call once. Attempting to re-initialize sox will result in seg faults. only call once. Attempting to re-initialize sox will result in seg faults.
""" """
...@@ -356,7 +359,7 @@ def shutdown_sox(): ...@@ -356,7 +359,7 @@ def shutdown_sox():
return _torch_sox.shutdown_sox() return _torch_sox.shutdown_sox()
def _audio_normalization(signal, normalization): def _audio_normalization(signal: Tensor, normalization: Union[bool, float, Callable]) -> None:
"""Audio normalization of a tensor in-place. The normalization can be a bool, """Audio normalization of a tensor in-place. The normalization can be a bool,
a number, or a callable that takes the audio tensor as an input. SoX uses a number, or a callable that takes the audio tensor as an input. SoX uses
32-bit signed integers internally, thus bool normalizes based on that assumption. 32-bit signed integers internally, thus bool normalizes based on that assumption.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment