Commit 9dcc7a15 authored by flyingdown's avatar flyingdown
Browse files

init v0.10.0

parent db2b0b79
Pipeline #254 failed with stages
in 0 seconds
import warnings
import importlib.util
from typing import Optional
from functools import wraps
import torch
def is_module_available(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without**
importing it. This is generally safer than try-catch block around a
`import X`. It avoids third party libraries breaking assumptions of some of
our tests, e.g., setting multiprocessing start method when imported
(see librosa/#747, torchvision/#544).
"""
return all(importlib.util.find_spec(m) is not None for m in modules)
def requires_module(*modules: str):
"""Decorate function to give error message if invoked without required optional modules.
This decorator is to give better error message to users rather
than raising ``NameError: name 'module' is not defined`` at random places.
"""
missing = [m for m in modules if not is_module_available(m)]
if not missing:
# fall through. If all the modules are available, no need to decorate
def decorator(func):
return func
else:
req = f'module: {missing[0]}' if len(missing) == 1 else f'modules: {missing}'
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires {req}')
return wrapped
return decorator
def deprecated(direction: str, version: Optional[str] = None):
"""Decorator to add deprecation message
Args:
direction (str): Migration steps to be given to users.
version (str or int): The version when the object will be removed
"""
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
message = (
f'{func.__module__}.{func.__name__} has been deprecated '
f'and will be removed from {"future" if version is None else version} release. '
f'{direction}')
warnings.warn(message, stacklevel=2)
return func(*args, **kwargs)
return wrapped
return decorator
def is_kaldi_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_kaldi_available()
def requires_kaldi():
if is_kaldi_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires kaldi')
return wrapped
return decorator
def _check_soundfile_importable():
if not is_module_available('soundfile'):
return False
try:
import soundfile # noqa: F401
return True
except Exception:
warnings.warn("Failed to import soundfile. 'soundfile' backend is not available.")
return False
_is_soundfile_importable = _check_soundfile_importable()
def is_soundfile_available():
return _is_soundfile_importable
def requires_soundfile():
if is_soundfile_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires soundfile')
return wrapped
return decorator
def is_sox_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_sox_available()
def requires_sox():
if is_sox_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires sox')
return wrapped
return decorator
# flake8: noqa
from . import utils
from .utils import (
list_audio_backends,
get_audio_backend,
set_audio_backend,
)
utils._init_audio_backend()
class AudioMetaData:
"""Return type of ``torchaudio.info`` function.
This class is used by :ref:`"sox_io" backend<sox_io_backend>` and
:ref:`"soundfile" backend with the new interface<soundfile_backend>`.
:ivar int sample_rate: Sample rate
:ivar int num_frames: The number of frames
:ivar int num_channels: The number of channels
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
or when it cannot be accurately inferred.
:ivar str encoding: Audio encoding
The values encoding can take are one of the following:
* ``PCM_S``: Signed integer linear PCM
* ``PCM_U``: Unsigned integer linear PCM
* ``PCM_F``: Floating point linear PCM
* ``FLAC``: Flac, Free Lossless Audio Codec
* ``ULAW``: Mu-law
* ``ALAW``: A-law
* ``MP3`` : MP3, MPEG-1 Audio Layer III
* ``VORBIS``: OGG Vorbis
* ``AMR_WB``: Adaptive Multi-Rate
* ``AMR_NB``: Adaptive Multi-Rate Wideband
* ``OPUS``: Opus
* ``UNKNOWN`` : None of above
"""
def __init__(
self,
sample_rate: int,
num_frames: int,
num_channels: int,
bits_per_sample: int,
encoding: str,
):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
self.bits_per_sample = bits_per_sample
self.encoding = encoding
def __str__(self):
return (
f"AudioMetaData("
f"sample_rate={self.sample_rate}, "
f"num_frames={self.num_frames}, "
f"num_channels={self.num_channels}, "
f"bits_per_sample={self.bits_per_sample}, "
f"encoding={self.encoding}"
f")"
)
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
from torch import Tensor
def load(filepath: Union[str, Path],
out: Optional[Tensor] = None,
normalization: Union[bool, float, Callable] = True,
channels_first: bool = True,
num_frames: int = 0,
offset: int = 0,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
raise RuntimeError('No audio I/O backend is available.')
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
raise RuntimeError('No audio I/O backend is available.')
def info(filepath: str) -> None:
raise RuntimeError('No audio I/O backend is available.')
"""The new soundfile backend which will become default in 0.8.0 onward"""
from typing import Tuple, Optional
import warnings
import torch
from torchaudio._internal import module_utils as _mod_utils
from .common import AudioMetaData
if _mod_utils.is_soundfile_available():
import soundfile
# Mapping from soundfile subtype to number of bits per sample.
# This is mostly heuristical and the value is set to 0 when it is irrelevant
# (lossy formats) or when it can't be inferred.
# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
# the default seems to be 8 bits but it can be compressed further to 4 bits.
# The dict is inspired from
# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
_SUBTYPE_TO_BITS_PER_SAMPLE = {
'PCM_S8': 8, # Signed 8 bit data
'PCM_16': 16, # Signed 16 bit data
'PCM_24': 24, # Signed 24 bit data
'PCM_32': 32, # Signed 32 bit data
'PCM_U8': 8, # Unsigned 8 bit data (WAV and RAW only)
'FLOAT': 32, # 32 bit float data
'DOUBLE': 64, # 64 bit float data
'ULAW': 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
'ALAW': 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
'IMA_ADPCM': 0, # IMA ADPCM.
'MS_ADPCM': 0, # Microsoft ADPCM.
'GSM610': 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
'VOX_ADPCM': 0, # OKI / Dialogix ADPCM
'G721_32': 0, # 32kbs G721 ADPCM encoding.
'G723_24': 0, # 24kbs G723 ADPCM encoding.
'G723_40': 0, # 40kbs G723 ADPCM encoding.
'DWVW_12': 12, # 12 bit Delta Width Variable Word encoding.
'DWVW_16': 16, # 16 bit Delta Width Variable Word encoding.
'DWVW_24': 24, # 24 bit Delta Width Variable Word encoding.
'DWVW_N': 0, # N bit Delta Width Variable Word encoding.
'DPCM_8': 8, # 8 bit differential PCM (XI only)
'DPCM_16': 16, # 16 bit differential PCM (XI only)
'VORBIS': 0, # Xiph Vorbis encoding. (lossy)
'ALAC_16': 16, # Apple Lossless Audio Codec (16 bit).
'ALAC_20': 20, # Apple Lossless Audio Codec (20 bit).
'ALAC_24': 24, # Apple Lossless Audio Codec (24 bit).
'ALAC_32': 32, # Apple Lossless Audio Codec (32 bit).
}
def _get_bit_depth(subtype):
if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
warnings.warn(
f"The {subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
)
return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0)
_SUBTYPE_TO_ENCODING = {
'PCM_S8': 'PCM_S',
'PCM_16': 'PCM_S',
'PCM_24': 'PCM_S',
'PCM_32': 'PCM_S',
'PCM_U8': 'PCM_U',
'FLOAT': 'PCM_F',
'DOUBLE': 'PCM_F',
'ULAW': 'ULAW',
'ALAW': 'ALAW',
'VORBIS': 'VORBIS',
}
def _get_encoding(format: str, subtype: str):
if format == 'FLAC':
return 'FLAC'
return _SUBTYPE_TO_ENCODING.get(subtype, 'UNKNOWN')
@_mod_utils.requires_soundfile()
def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.
Note:
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
which has a restriction on type annotation due to TorchScript compiler compatiblity.
Args:
filepath (path-like object or file-like object):
Source of audio data.
format (str or None, optional):
Not used. PySoundFile does not accept format hint.
Returns:
AudioMetaData: meta data of the given audio.
"""
sinfo = soundfile.info(filepath)
return AudioMetaData(
sinfo.samplerate,
sinfo.frames,
sinfo.channels,
bits_per_sample=_get_bit_depth(sinfo.subtype),
encoding=_get_encoding(sinfo.format, sinfo.subtype),
)
_SUBTYPE2DTYPE = {
"PCM_S8": "int8",
"PCM_U8": "uint8",
"PCM_16": "int16",
"PCM_32": "int32",
"FLOAT": "float32",
"DOUBLE": "float64",
}
@_mod_utils.requires_soundfile()
def load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
"""Load audio data from file.
Note:
The formats this function can handle depend on the soundfile installation.
This function is tested on the following formats;
* WAV
* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer
* FLAC
* OGG/VORBIS
* SPHERE
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
``float32`` dtype and the shape of `[channel, time]`.
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
signed integer and 8-bit unsigned integer (24-bit signed integer is not supported),
by providing ``normalize=False``, this function can return integer Tensor, where the samples
are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor
for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM.
``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as
``flac`` and ``mp3``.
For these formats, this function always returns ``float32`` Tensor with values normalized to
``[-1.0, 1.0]``.
Note:
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
which has a restriction on type annotation due to TorchScript compiler compatiblity.
Args:
filepath (path-like object or file-like object):
Source of audio data.
frame_offset (int, optional):
Number of frames to skip before start reading data.
num_frames (int, optional):
Maximum number of frames to read. ``-1`` reads all the remaining samples,
starting from ``frame_offset``.
This function may return the less number of frames if there is not enough
frames in the given file.
normalize (bool, optional):
When ``True``, this function always return ``float32``, and sample values are
normalized to ``[-1.0, 1.0]``.
If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
integer type.
This argument has no effect for formats other than integer WAV type.
channels_first (bool, optional):
When True, the returned Tensor has dimension `[channel, time]`.
Otherwise, the returned Tensor's dimension is `[time, channel]`.
format (str or None, optional):
Not used. PySoundFile does not accept format hint.
Returns:
(torch.Tensor, int): Resulting Tensor and sample rate.
If the input file has integer wav format and normalization is off, then it has
integer type, else ``float32`` type. If ``channels_first=True``, it has
`[channel, time]` else `[time, channel]`.
"""
with soundfile.SoundFile(filepath, "r") as file_:
if file_.format != "WAV" or normalize:
dtype = "float32"
elif file_.subtype not in _SUBTYPE2DTYPE:
raise ValueError(f"Unsupported subtype: {file_.subtype}")
else:
dtype = _SUBTYPE2DTYPE[file_.subtype]
frames = file_._prepare_read(frame_offset, None, num_frames)
waveform = file_.read(frames, dtype, always_2d=True)
sample_rate = file_.samplerate
waveform = torch.from_numpy(waveform)
if channels_first:
waveform = waveform.t()
return waveform, sample_rate
def _get_subtype_for_wav(
dtype: torch.dtype,
encoding: str,
bits_per_sample: int):
if not encoding:
if not bits_per_sample:
subtype = {
torch.uint8: "PCM_U8",
torch.int16: "PCM_16",
torch.int32: "PCM_32",
torch.float32: "FLOAT",
torch.float64: "DOUBLE",
}.get(dtype)
if not subtype:
raise ValueError(f"Unsupported dtype for wav: {dtype}")
return subtype
if bits_per_sample == 8:
return "PCM_U8"
return f"PCM_{bits_per_sample}"
if encoding == "PCM_S":
if not bits_per_sample:
return "PCM_32"
if bits_per_sample == 8:
raise ValueError("wav does not support 8-bit signed PCM encoding.")
return f"PCM_{bits_per_sample}"
if encoding == "PCM_U":
if bits_per_sample in (None, 8):
return "PCM_U8"
raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
if encoding == "PCM_F":
if bits_per_sample in (None, 32):
return "FLOAT"
if bits_per_sample == 64:
return "DOUBLE"
raise ValueError("wav only supports 32/64-bit float PCM encoding.")
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "ULAW"
raise ValueError("wav only supports 8-bit mu-law encoding.")
if encoding == "ALAW":
if bits_per_sample in (None, 8):
return "ALAW"
raise ValueError("wav only supports 8-bit a-law encoding.")
raise ValueError(f"wav does not support {encoding}.")
def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
if encoding in (None, "PCM_S"):
return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
if encoding in ("PCM_U", "PCM_F"):
raise ValueError(f"sph does not support {encoding} encoding.")
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "ULAW"
raise ValueError("sph only supports 8-bit for mu-law encoding.")
if encoding == "ALAW":
return "ALAW"
raise ValueError(f"sph does not support {encoding}.")
def _get_subtype(
dtype: torch.dtype,
format: str,
encoding: str,
bits_per_sample: int):
if format == "wav":
return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
if format == "flac":
if encoding:
raise ValueError("flac does not support encoding.")
if not bits_per_sample:
return "PCM_16"
if bits_per_sample > 24:
raise ValueError("flac does not support bits_per_sample > 24.")
return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
if format in ("ogg", "vorbis"):
if encoding or bits_per_sample:
raise ValueError(
"ogg/vorbis does not support encoding/bits_per_sample.")
return "VORBIS"
if format == "sph":
return _get_subtype_for_sphere(encoding, bits_per_sample)
if format in ("nis", "nist"):
return "PCM_16"
raise ValueError(f"Unsupported format: {format}")
@_mod_utils.requires_soundfile()
def save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
"""Save audio data to file.
Note:
The formats this function can handle depend on the soundfile installation.
This function is tested on the following formats;
* WAV
* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer
* FLAC
* OGG/VORBIS
* SPHERE
Note:
``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
which has a restriction on type annotation due to TorchScript compiler compatiblity.
Args:
filepath (str or pathlib.Path): Path to audio file.
src (torch.Tensor): Audio data to save. must be 2D tensor.
sample_rate (int): sampling rate
channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
otherwise `[time, channel]`.
compression (float of None, optional): Not used.
It is here only for interface compatibility reson with "sox_io" backend.
format (str or None, optional): Override the audio format.
When ``filepath`` argument is path-like object, audio format is
inferred from file extension. If the file extension is missing or
different, you can specify the correct format with this argument.
When ``filepath`` argument is file-like object,
this argument is required.
Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
``"flac"`` and ``"sph"``.
encoding (str or None, optional): Changes the encoding for supported formats.
This argument is effective only for supported formats, sush as
``"wav"``, ``""flac"`` and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)
bits_per_sample (int or None, optional): Changes the bit depth for the
supported formats.
When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
you can change the bit depth.
Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
Supported formats/encodings/bit depth/compression are:
``"wav"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law
Note: Default encoding/bit depth is determined by the dtype of
the input Tensor.
``"flac"``
- 8-bit
- 16-bit (default)
- 24-bit
``"ogg"``, ``"vorbis"``
- Doesn't accept changing configuration.
``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law
"""
if src.ndim != 2:
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
if compression is not None:
warnings.warn(
'`save` function of "soundfile" backend does not support "compression" parameter. '
"The argument is silently ignored."
)
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
ext = format.lower()
else:
ext = str(filepath).split(".")[-1].lower()
if bits_per_sample not in (None, 8, 16, 24, 32, 64):
raise ValueError("Invalid bits_per_sample.")
if bits_per_sample == 24:
warnings.warn("Saving audio with 24 bits per sample might warp samples near -1. "
"Using 16 bits per sample might be able to avoid this.")
subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# so we extend the extensions manually here
if ext in ["nis", "nist", "sph"] and format is None:
format = "NIST"
if channels_first:
src = src.t()
soundfile.write(
file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format
)
import os
from typing import Tuple, Optional
import torch
from torchaudio._internal import (
module_utils as _mod_utils,
)
import torchaudio
from .common import AudioMetaData
@_mod_utils.requires_sox()
def info(
filepath: str,
format: Optional[str] = None,
) -> AudioMetaData:
"""Get signal information of an audio file.
Args:
filepath (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript,
(e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.
Note:
* When the input type is file-like object, this function cannot
get the correct length (``num_samples``) for certain formats,
such as ``mp3`` and ``vorbis``.
In this case, the value of ``num_samples`` is ``0``.
* This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
format (str or None, optional):
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
from header or extension,
Returns:
AudioMetaData: Metadata of the given audio.
"""
if not torch.jit.is_scripting():
if hasattr(filepath, 'read'):
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo)
filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData(*sinfo)
@_mod_utils.requires_sox()
def load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
"""Load audio data from file.
Note:
This function can handle all the codecs that underlying libsox can handle,
however it is tested on the following formats;
* WAV, AMB
* 32-bit floating-point
* 32-bit signed integer
* 24-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer (WAV only)
* MP3
* FLAC
* OGG/VORBIS
* OPUS
* SPHERE
* AMR-NB
To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
``float32`` dtype and the shape of `[channel, time]`.
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
signed integer, 24-bit signed integer, and 8-bit unsigned integer, by providing ``normalize=False``,
this function can return integer Tensor, where the samples are expressed within the whole range
of the corresponding dtype, that is, ``int32`` tensor for 32-bit signed PCM,
``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. Since torch does not
support ``int24`` dtype, 24-bit signed PCM are converted to ``int32`` tensors.
``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as
``flac`` and ``mp3``.
For these formats, this function always returns ``float32`` Tensor with values normalized to
``[-1.0, 1.0]``.
Args:
filepath (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript,
(e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.
Note: This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
frame_offset (int):
Number of frames to skip before start reading data.
num_frames (int, optional):
Maximum number of frames to read. ``-1`` reads all the remaining samples,
starting from ``frame_offset``.
This function may return the less number of frames if there is not enough
frames in the given file.
normalize (bool, optional):
When ``True``, this function always return ``float32``, and sample values are
normalized to ``[-1.0, 1.0]``.
If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
integer type.
This argument has no effect for formats other than integer WAV type.
channels_first (bool, optional):
When True, the returned Tensor has dimension `[channel, time]`.
Otherwise, the returned Tensor's dimension is `[time, channel]`.
format (str or None, optional):
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
from header or extension,
Returns:
(torch.Tensor, int): Resulting Tensor and sample rate.
If the input file has integer wav format and normalization is off, then it has
integer type, else ``float32`` type. If ``channels_first=True``, it has
`[channel, time]` else `[time, channel]`.
"""
if not torch.jit.is_scripting():
if hasattr(filepath, 'read'):
return torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath = os.fspath(filepath)
return torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format)
@_mod_utils.requires_sox()
def save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
"""Save audio data to file.
Args:
filepath (str or pathlib.Path): Path to save file.
This function also handles ``pathlib.Path`` objects, but is annotated
as ``str`` for TorchScript compiler compatibility.
src (torch.Tensor): Audio data to save. must be 2D tensor.
sample_rate (int): sampling rate
channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
otherwise `[time, channel]`.
compression (float or None, optional): Used for formats other than WAV.
This corresponds to ``-C`` option of ``sox`` command.
``"mp3"``
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
``"flac"``
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.
See the detail at http://sox.sourceforge.net/soxformat.html.
format (str or None, optional): Override the audio format.
When ``filepath`` argument is path-like object, audio format is infered from
file extension. If file extension is missing or different, you can specify the
correct format with this argument.
When ``filepath`` argument is file-like object, this argument is required.
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``.
encoding (str or None, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, such as ``"wav"``, ``""amb"``
and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)
Default values
If not provided, the default value is picked based on ``format`` and ``bits_per_sample``.
``"wav"``, ``"amb"``
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used to determine the default value.
- ``"PCM_U"`` if dtype is ``uint8``
- ``"PCM_S"`` if dtype is ``int16`` or ``int32`
- ``"PCM_F"`` if dtype is ``float32``
- ``"PCM_U"`` if ``bits_per_sample=8``
- ``"PCM_S"`` otherwise
``"sph"`` format;
- the default value is ``"PCM_S"``
bits_per_sample (int or None, optional): Changes the bit depth for the supported formats.
When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the
bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``.
Default Value;
If not provided, the default values are picked based on ``format`` and ``"encoding"``;
``"wav"``, ``"amb"``;
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used.
- ``8`` if dtype is ``uint8``
- ``16`` if dtype is ``int16``
- ``32`` if dtype is ``int32`` or ``float32``
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"``
- ``32`` if ``encoding`` is ``"PCM_F"``
``"flac"`` format;
- the default value is ``24``
``"sph"`` format;
- ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided.
- ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"``
``"amb"`` format;
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"`` or not provided.
- ``32`` if ``encoding`` is ``"PCM_F"``
Supported formats/encodings/bit depth/compression are;
``"wav"``, ``"amb"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law
Note: Default encoding/bit depth is determined by the dtype of the input Tensor.
``"mp3"``
Fixed bit rate (such as 128kHz) and variable bit rate compression.
Default: VBR with high quality.
``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)
``"ogg"``, ``"vorbis"``
- Different quality level. Default: approx. 112kbps
``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law
``"amr-nb"``
Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s
``"gsm"``
Lossy Speech Compression, CPU intensive.
``"htk"``
Uses a default single-channel 16-bit PCM format.
Note:
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
to be linked to ``libsox`` and corresponding codec libraries such as ``libmad``
or ``libmp3lame`` etc.
"""
if not torch.jit.is_scripting():
if hasattr(filepath, 'write'):
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression,
format, encoding, bits_per_sample)
return
filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)
"""Defines utilities for switching audio backends"""
import warnings
from typing import Optional, List
import torchaudio
from torchaudio._internal import module_utils as _mod_utils
from . import (
no_backend,
sox_io_backend,
soundfile_backend,
)
__all__ = [
'list_audio_backends',
'get_audio_backend',
'set_audio_backend',
]
def list_audio_backends() -> List[str]:
"""List available backends
Returns:
List[str]: The list of available backends.
"""
backends = []
if _mod_utils.is_module_available('soundfile'):
backends.append('soundfile')
if _mod_utils.is_sox_available():
backends.append('sox_io')
return backends
def set_audio_backend(backend: Optional[str]):
"""Set the backend for I/O operation
Args:
backend (str or None): Name of the backend.
One of ``"sox_io"`` or ``"soundfile"`` based on availability
of the system. If ``None`` is provided the current backend is unassigned.
"""
if backend is not None and backend not in list_audio_backends():
raise RuntimeError(
f'Backend "{backend}" is not one of '
f'available backends: {list_audio_backends()}.')
if backend is None:
module = no_backend
elif backend == 'sox_io':
module = sox_io_backend
elif backend == 'soundfile':
module = soundfile_backend
else:
raise NotImplementedError(f'Unexpected backend "{backend}"')
for func in ['save', 'load', 'info']:
setattr(torchaudio, func, getattr(module, func))
def _init_audio_backend():
backends = list_audio_backends()
if 'sox_io' in backends:
set_audio_backend('sox_io')
elif 'soundfile' in backends:
set_audio_backend('soundfile')
else:
warnings.warn('No audio backend is available.')
set_audio_backend(None)
def get_audio_backend() -> Optional[str]:
"""Get the name of the current backend
Returns:
Optional[str]: The name of the current backend or ``None`` if no backend is assigned.
"""
if torchaudio.load == no_backend.load:
return None
if torchaudio.load == sox_io_backend.load:
return 'sox_io'
if torchaudio.load == soundfile_backend.load:
return 'soundfile'
raise ValueError('Unknown backend.')
from . import kaldi
__all__ = [
'kaldi',
]
from typing import Tuple
import math
import torch
from torch import Tensor
import torchaudio
__all__ = [
'get_mel_banks',
'inverse_mel_scale',
'inverse_mel_scale_scalar',
'mel_scale',
'mel_scale_scalar',
'spectrogram',
'fbank',
'mfcc',
'vtln_warp_freq',
'vtln_warp_mel_freq',
]
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
# 1 milliseconds = 0.001 seconds
MILLISECONDS_TO_SECONDS = 0.001
# window types
HAMMING = 'hamming'
HANNING = 'hanning'
POVEY = 'povey'
RECTANGULAR = 'rectangular'
BLACKMAN = 'blackman'
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
def _get_epsilon(device, dtype):
return EPSILON.to(device=device, dtype=dtype)
def _next_power_of_2(x: int) -> int:
r"""Returns the smallest power of 2 that is greater than x
"""
return 1 if x == 0 else 2 ** (x - 1).bit_length()
def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
representing how the window is shifted along the waveform. Each row is a frame.
Args:
waveform (Tensor): Tensor of size ``num_samples``
window_size (int): Frame length
window_shift (int): Frame shift
snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends.
Returns:
Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
"""
assert waveform.dim() == 1
num_samples = waveform.size(0)
strides = (window_shift * waveform.stride(0), waveform.stride(0))
if snip_edges:
if num_samples < window_size:
return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
else:
m = 1 + (num_samples - window_size) // window_shift
else:
reversed_waveform = torch.flip(waveform, [0])
m = (num_samples + (window_shift // 2)) // window_shift
pad = window_size // 2 - window_shift // 2
pad_right = reversed_waveform
if pad > 0:
# torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
# but we want [2, 1, 0, 0, 1, 2]
pad_left = reversed_waveform[-pad:]
waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
else:
# pad is negative so we want to trim the waveform at the front
waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
sizes = (m, window_size)
return waveform.as_strided(sizes, strides)
def _feature_window_function(window_type: str,
window_size: int,
blackman_coeff: float,
device: torch.device,
dtype: int,
) -> Tensor:
r"""Returns a window function with the given type and size
"""
if window_type == HANNING:
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
elif window_type == HAMMING:
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
elif window_type == POVEY:
# like hanning but goes to zero at edges
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
elif window_type == RECTANGULAR:
return torch.ones(window_size, device=device, dtype=dtype)
elif window_type == BLACKMAN:
a = 2 * math.pi / (window_size - 1)
window_function = torch.arange(window_size, device=device, dtype=dtype)
# can't use torch.blackman_window as they use different coefficients
return (blackman_coeff - 0.5 * torch.cos(a * window_function) +
(0.5 - blackman_coeff) * torch.cos(2 * a * window_function)).to(device=device, dtype=dtype)
else:
raise Exception('Invalid window type ' + window_type)
def _get_log_energy(strided_input: Tensor,
epsilon: Tensor,
energy_floor: float) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)
"""
device, dtype = strided_input.device, strided_input.dtype
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
if energy_floor == 0.0:
return log_energy
return torch.max(
log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
def _get_waveform_and_window_properties(waveform: Tensor,
channel: int,
sample_frequency: float,
frame_shift: float,
frame_length: float,
round_to_power_of_two: bool,
preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]:
r"""Gets the waveform and window properties
"""
channel = max(channel, 0)
assert channel < waveform.size(0), ('Invalid channel {} for size {}'.format(channel, waveform.size(0)))
waveform = waveform[channel, :] # size (n)
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
assert 2 <= window_size <= len(
waveform), ('choose a window size {} that is [2, {}]'
.format(window_size, len(waveform)))
assert 0 < window_shift, '`window_shift` must be greater than 0'
assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
' use `round_to_power_of_two` or change `frame_length`'
assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
assert sample_frequency > 0, '`sample_frequency` must be greater than zero'
return waveform, window_shift, window_size, padded_window_size
def _get_window(waveform: Tensor,
padded_window_size: int,
window_size: int,
window_shift: int,
window_type: str,
blackman_coeff: float,
snip_edges: bool,
raw_energy: bool,
energy_floor: float,
dither: float,
remove_dc_offset: bool,
preemphasis_coefficient: float) -> Tuple[Tensor, Tensor]:
r"""Gets a window and its log energy
Returns:
(Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
"""
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)
# size (m, window_size)
strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
if dither != 0.0:
# Returns a random number strictly between 0 and 1
x = torch.max(epsilon, torch.rand(strided_input.shape, device=device, dtype=dtype))
rand_gauss = torch.sqrt(-2 * x.log()) * torch.cos(2 * math.pi * x)
strided_input = strided_input + rand_gauss * dither
if remove_dc_offset:
# Subtract each row/frame by its mean
row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
strided_input = strided_input - row_means
if raw_energy:
# Compute the log energy of each row/frame before applying preemphasis and
# window function
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
if preemphasis_coefficient != 0.0:
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
offset_strided_input = torch.nn.functional.pad(
strided_input.unsqueeze(0), (1, 0), mode='replicate').squeeze(0) # size (m, window_size + 1)
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
# Apply window_function to each row/frame
window_function = _feature_window_function(
window_type, window_size, blackman_coeff, device, dtype).unsqueeze(0) # size (1, window_size)
strided_input = strided_input * window_function # size (m, window_size)
# Pad columns with zero until we reach size (m, padded_window_size)
if padded_window_size != window_size:
padding_right = padded_window_size - window_size
strided_input = torch.nn.functional.pad(
strided_input.unsqueeze(0), (0, padding_right), mode='constant', value=0).squeeze(0)
# Compute energy after window function (not the raw one)
if not raw_energy:
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
return strided_input, signal_log_energy
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
# it returns size (m, n)
if subtract_mean:
col_means = torch.mean(tensor, dim=0).unsqueeze(0)
tensor = tensor - col_means
return tensor
def spectrogram(waveform: Tensor,
blackman_coeff: float = 0.42,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
min_duration: float = 0.0,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
window_type: str = POVEY) -> Tensor:
r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
compute-spectrogram-feats.
Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. (Default: ``True``)
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
specified there) (Default: ``16000.0``)
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (Default: ``False``)
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
(Default: ``'povey'``)
Returns:
Tensor: A spectrogram identical to what Kaldi would output. The shape is
(m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
"""
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return torch.empty(0)
strided_input, signal_log_energy = _get_window(
waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff,
snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
# size (m, padded_window_size // 2 + 1, 2)
fft = torch.fft.rfft(strided_input)
# Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log() # size (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
return power_spectrum
def inverse_mel_scale_scalar(mel_freq: float) -> float:
return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
def mel_scale_scalar(freq: float) -> float:
return 1127.0 * math.log(1.0 + freq / 700.0)
def mel_scale(freq: Tensor) -> Tensor:
return 1127.0 * (1.0 + freq / 700.0).log()
def vtln_warp_freq(vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq: float,
high_freq: float,
vtln_warp_factor: float,
freq: Tensor) -> Tensor:
r"""This computes a VTLN warping function that is not the same as HTK's one,
but has similar inputs (this function has the advantage of never producing
empty bins).
This function computes a warp function F(freq), defined between low_freq
and high_freq inclusive, with the following properties:
F(low_freq) == low_freq
F(high_freq) == high_freq
The function is continuous and piecewise linear with two inflection
points.
The lower inflection point (measured in terms of the unwarped
frequency) is at frequency l, determined as described below.
The higher inflection point is at a frequency h, determined as
described below.
If l <= f <= h, then F(f) = f/vtln_warp_factor.
If the higher inflection point (measured in terms of the unwarped
frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
Since (by the last point) F(h) == h/vtln_warp_factor, then
max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
= vtln_high_cutoff * min(1, vtln_warp_factor).
If the lower inflection point (measured in terms of the unwarped
frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
= vtln_low_cutoff * max(1, vtln_warp_factor)
Args:
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
low_freq (float): Lower frequency cutoffs in mel computation
high_freq (float): Upper frequency cutoffs in mel computation
vtln_warp_factor (float): Vtln warp factor
freq (Tensor): given frequency in Hz
Returns:
Tensor: Freq after vtln warp
"""
assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq'
assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]'
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
scale = 1.0 / vtln_warp_factor
Fl = scale * l # F(l)
Fh = scale * h # F(h)
assert l > low_freq and h < high_freq
# slope of left part of the 3-piece linear function
scale_left = (Fl - low_freq) / (l - low_freq)
# [slope of center part is just "scale"]
# slope of right part of the 3-piece linear function
scale_right = (high_freq - Fh) / (high_freq - h)
res = torch.empty_like(freq)
outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
before_l = torch.lt(freq, l) # freq < l
before_h = torch.lt(freq, h) # freq < h
after_h = torch.ge(freq, h) # freq >= h
# order of operations matter here (since there is overlapping frequency regions)
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
res[before_h] = scale * freq[before_h]
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
res[outside_low_high_freq] = freq[outside_low_high_freq]
return res
def vtln_warp_mel_freq(vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq, high_freq: float,
vtln_warp_factor: float,
mel_freq: Tensor) -> Tensor:
r"""
Args:
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
low_freq (float): Lower frequency cutoffs in mel computation
high_freq (float): Upper frequency cutoffs in mel computation
vtln_warp_factor (float): Vtln warp factor
mel_freq (Tensor): Given frequency in Mel
Returns:
Tensor: ``mel_freq`` after vtln warp
"""
return mel_scale(vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq,
vtln_warp_factor, inverse_mel_scale(mel_freq)))
def get_mel_banks(num_bins: int,
window_length_padded: int,
sample_freq: float,
low_freq: float,
high_freq: float,
vtln_low: float,
vtln_high: float,
vtln_warp_factor: float) -> Tuple[Tensor, Tensor]:
"""
Returns:
(Tensor, Tensor): The tuple consists of ``bins`` (which is
melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
center frequencies of bins of size (``num_bins``)).
"""
assert num_bins > 3, 'Must have at least 3 mel bins'
assert window_length_padded % 2 == 0
num_fft_bins = window_length_padded / 2
nyquist = 0.5 * sample_freq
if high_freq <= 0.0:
high_freq += nyquist
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded
mel_low_freq = mel_scale_scalar(low_freq)
mel_high_freq = mel_scale_scalar(high_freq)
# divide by num_bins+1 in next line because of end-effects where the bins
# spread out to the sides.
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
if vtln_high < 0.0:
vtln_high += nyquist
assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
(0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
('Bad values in options: vtln-low {} and vtln-high {}, versus '
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
bin = torch.arange(num_bins).unsqueeze(1)
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
if vtln_warp_factor != 1.0:
left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
# size(1, num_fft_bins)
mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
# size (num_bins, num_fft_bins)
up_slope = (mel - left_mel) / (center_mel - left_mel)
down_slope = (right_mel - mel) / (right_mel - center_mel)
if vtln_warp_factor == 1.0:
# left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
else:
# warping can move the order of left_mel, center_mel, right_mel anywhere
bins = torch.zeros_like(up_slope)
up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
bins[up_idx] = up_slope[up_idx]
bins[down_idx] = down_slope[down_idx]
return bins, center_freqs
def fbank(waveform: Tensor,
blackman_coeff: float = 0.42,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
min_duration: float = 0.0,
num_mel_bins: int = 23,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
use_log_fbank: bool = True,
use_power: bool = True,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY) -> Tensor:
r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
compute-fbank-feats.
Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
(Default: ``0.0``)
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
(need to change other parameters). (Default: ``False``)
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. (Default: ``True``)
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
specified there) (Default: ``16000.0``)
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (Default: ``False``)
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``)
use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
negative, offset from high-mel-freq (Default: ``-500.0``)
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
(Default: ``'povey'``)
Returns:
Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
where m is calculated in _get_strided
"""
device, dtype = waveform.device, waveform.dtype
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return torch.empty(0, device=device, dtype=dtype)
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
strided_input, signal_log_energy = _get_window(
waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff,
snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
# size (m, padded_window_size // 2 + 1)
spectrum = torch.fft.rfft(strided_input).abs()
if use_power:
spectrum = spectrum.pow(2.)
# size (num_mel_bins, padded_window_size // 2)
mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency,
low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
mel_energies = mel_energies.to(device=device, dtype=dtype)
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0)
# sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
mel_energies = torch.mm(spectrum, mel_energies.T)
if use_log_fbank:
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
# if use_energy then add it as the last column for htk_compat == true else first column
if use_energy:
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
# returns size (m, num_mel_bins + 1)
if htk_compat:
mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
else:
mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
return mel_energies
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
# returns a dct matrix of size (num_mel_bins, num_ceps)
# size (num_mel_bins, num_mel_bins)
dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, 'ortho')
# kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
# this would be the first column in the dct_matrix for torchaudio as it expects a
# right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
# expects a left multiply e.g. dct_matrix * vector).
dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
dct_matrix = dct_matrix[:, :num_ceps]
return dct_matrix
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
# returns size (num_ceps)
# Compute liftering coefficients (scaling on cepstral coeffs)
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
i = torch.arange(num_ceps)
return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
def mfcc(
waveform: Tensor,
blackman_coeff: float = 0.42,
cepstral_lifter: float = 22.0,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
num_ceps: int = 13,
min_duration: float = 0.0,
num_mel_bins: int = 23,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY) -> Tensor:
r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
compute-mfcc-feats.
Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
(Default: ``0.0``)
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
features (need to change other parameters). (Default: ``False``)
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. (Default: ``True``)
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
specified there) (Default: ``16000.0``)
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (Default: ``False``)
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
negative, offset from high-mel-freq (Default: ``-500.0``)
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
(Default: ``"povey"``)
Returns:
Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
where m is calculated in _get_strided
"""
assert num_ceps <= num_mel_bins, 'num_ceps cannot be larger than num_mel_bins: %d vs %d' % (num_ceps, num_mel_bins)
device, dtype = waveform.device, waveform.dtype
# The mel_energies should not be squared (use_power=True), not have mean subtracted
# (subtract_mean=False), and use log (use_log_fbank=True).
# size (m, num_mel_bins + use_energy)
feature = fbank(waveform=waveform, blackman_coeff=blackman_coeff, channel=channel,
dither=dither, energy_floor=energy_floor, frame_length=frame_length,
frame_shift=frame_shift, high_freq=high_freq, htk_compat=htk_compat,
low_freq=low_freq, min_duration=min_duration, num_mel_bins=num_mel_bins,
preemphasis_coefficient=preemphasis_coefficient, raw_energy=raw_energy,
remove_dc_offset=remove_dc_offset, round_to_power_of_two=round_to_power_of_two,
sample_frequency=sample_frequency, snip_edges=snip_edges, subtract_mean=False,
use_energy=use_energy, use_log_fbank=True, use_power=True,
vtln_high=vtln_high, vtln_low=vtln_low, vtln_warp=vtln_warp, window_type=window_type)
if use_energy:
# size (m)
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
# offset is 0 if htk_compat==True else 1
mel_offset = int(not htk_compat)
feature = feature[:, mel_offset:(num_mel_bins + mel_offset)]
# size (num_mel_bins, num_ceps)
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
# size (m, num_ceps)
feature = feature.matmul(dct_matrix)
if cepstral_lifter != 0.0:
# size (1, num_ceps)
lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
feature *= lifter_coeffs.to(device=device, dtype=dtype)
# if use_energy then replace the last column for htk_compat == true else first column
if use_energy:
feature[:, 0] = signal_log_energy
if htk_compat:
energy = feature[:, 0].unsqueeze(1) # size (m, 1)
feature = feature[:, 1:] # size (m, num_ceps - 1)
if not use_energy:
# scale on C0 (actually removing a scale we previously added that's
# part of one common definition of the cosine transform.)
energy *= math.sqrt(2)
feature = torch.cat((feature, energy), dim=1)
feature = _subtract_column_mean(feature, subtract_mean)
return feature
get_property(TORCHAUDIO_THIRD_PARTIES GLOBAL PROPERTY TORCHAUDIO_THIRD_PARTIES)
################################################################################
# libtorchaudio
################################################################################
set(
LIBTORCHAUDIO_SOURCES
lfilter.cpp
overdrive.cpp
utils.cpp
)
if(BUILD_RNNT)
list(
APPEND
LIBTORCHAUDIO_SOURCES
rnnt/cpu/compute_alphas.cpp
rnnt/cpu/compute_betas.cpp
rnnt/cpu/compute.cpp
rnnt/compute_alphas.cpp
rnnt/compute_betas.cpp
rnnt/compute.cpp
rnnt/autograd.cpp
)
if (USE_CUDA)
list(
APPEND
LIBTORCHAUDIO_SOURCES
rnnt/gpu/compute_alphas.cu
rnnt/gpu/compute_betas.cu
rnnt/gpu/compute.cu
)
endif()
endif()
if(BUILD_KALDI)
list(APPEND LIBTORCHAUDIO_SOURCES kaldi.cpp)
endif()
if(BUILD_SOX)
list(
APPEND
LIBTORCHAUDIO_SOURCES
sox/io.cpp
sox/utils.cpp
sox/effects.cpp
sox/effects_chain.cpp
sox/types.cpp
)
endif()
add_library(
libtorchaudio
SHARED
${LIBTORCHAUDIO_SOURCES}
)
set_target_properties(libtorchaudio PROPERTIES PREFIX "")
target_include_directories(
libtorchaudio
PRIVATE
${PROJECT_SOURCE_DIR}
)
target_link_libraries(
libtorchaudio
torch
${TORCHAUDIO_THIRD_PARTIES}
)
if (BUILD_SOX)
target_compile_definitions(libtorchaudio PUBLIC INCLUDE_SOX)
endif()
if (BUILD_KALDI)
target_compile_definitions(libtorchaudio PUBLIC INCLUDE_KALDI)
endif()
if(USE_CUDA)
target_compile_definitions(libtorchaudio PRIVATE USE_CUDA)
target_include_directories(
libtorchaudio
PRIVATE
${CUDA_TOOLKIT_INCLUDE}
)
target_link_libraries(
libtorchaudio
${C10_CUDA_LIBRARY}
${CUDA_CUDART_LIBRARY}
)
endif()
if (MSVC)
set_target_properties(libtorchaudio PROPERTIES SUFFIX ".pyd")
endif(MSVC)
install(
TARGETS libtorchaudio
LIBRARY DESTINATION lib
RUNTIME DESTINATION lib # For Windows
)
if (APPLE)
set(TORCHAUDIO_LIBRARY libtorchaudio CACHE INTERNAL "")
else()
set(TORCHAUDIO_LIBRARY -Wl,--no-as-needed libtorchaudio -Wl,--as-needed CACHE INTERNAL "")
endif()
################################################################################
# _torchaudio.so
################################################################################
if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
set(
EXTENSION_SOURCES
pybind/pybind.cpp
)
if(BUILD_SOX)
list(
APPEND
EXTENSION_SOURCES
pybind/sox/effects.cpp
pybind/sox/effects_chain.cpp
pybind/sox/io.cpp
pybind/sox/utils.cpp
)
endif()
add_library(
_torchaudio
SHARED
${EXTENSION_SOURCES}
)
set_target_properties(_torchaudio PROPERTIES PREFIX "")
if (MSVC)
set_target_properties(_torchaudio PROPERTIES SUFFIX ".pyd")
endif(MSVC)
if (APPLE)
# https://github.com/facebookarchive/caffe2/issues/854#issuecomment-364538485
# https://github.com/pytorch/pytorch/commit/73f6715f4725a0723d8171d3131e09ac7abf0666
set_target_properties(_torchaudio PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif()
target_include_directories(
_torchaudio
PRIVATE
${PROJECT_SOURCE_DIR}
${Python_INCLUDE_DIR}
)
# See https://github.com/pytorch/pytorch/issues/38122
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
if (WIN32)
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
set(ADDITIONAL_ITEMS Python3::Python)
endif()
target_link_libraries(
_torchaudio
libtorchaudio
${TORCH_PYTHON_LIBRARY}
${ADDITIONAL_ITEMS}
)
install(
TARGETS _torchaudio
LIBRARY DESTINATION .
RUNTIME DESTINATION . # For Windows
)
endif()
#include <torch/script.h>
#include "feat/pitch-functions.h"
namespace torchaudio {
namespace kaldi {
namespace {
torch::Tensor denormalize(const torch::Tensor& t) {
auto ret = t;
auto pos = t > 0, neg = t < 0;
ret.index_put({pos}, t.index({pos}) * 32767);
ret.index_put({neg}, t.index({neg}) * 32768);
return ret;
}
torch::Tensor compute_kaldi_pitch(
const torch::Tensor& wave,
const ::kaldi::PitchExtractionOptions& opts) {
::kaldi::VectorBase<::kaldi::BaseFloat> input(wave);
::kaldi::Matrix<::kaldi::BaseFloat> output;
::kaldi::ComputeKaldiPitch(opts, input, &output);
return output.tensor_;
}
} // namespace
torch::Tensor ComputeKaldiPitch(
const torch::Tensor& wave,
double sample_frequency,
double frame_length,
double frame_shift,
double min_f0,
double max_f0,
double soft_min_f0,
double penalty_factor,
double lowpass_cutoff,
double resample_frequency,
double delta_pitch,
double nccf_ballast,
int64_t lowpass_filter_width,
int64_t upsample_filter_width,
int64_t max_frames_latency,
int64_t frames_per_chunk,
bool simulate_first_pass_online,
int64_t recompute_frame,
bool snip_edges) {
TORCH_CHECK(wave.ndimension() == 2, "Input tensor must be 2 dimentional.");
TORCH_CHECK(wave.device().is_cpu(), "Input tensor must be on CPU.");
TORCH_CHECK(
wave.dtype() == torch::kFloat32, "Input tensor must be float32 type.");
::kaldi::PitchExtractionOptions opts;
opts.samp_freq = static_cast<::kaldi::BaseFloat>(sample_frequency);
opts.frame_shift_ms = static_cast<::kaldi::BaseFloat>(frame_shift);
opts.frame_length_ms = static_cast<::kaldi::BaseFloat>(frame_length);
opts.min_f0 = static_cast<::kaldi::BaseFloat>(min_f0);
opts.max_f0 = static_cast<::kaldi::BaseFloat>(max_f0);
opts.soft_min_f0 = static_cast<::kaldi::BaseFloat>(soft_min_f0);
opts.penalty_factor = static_cast<::kaldi::BaseFloat>(penalty_factor);
opts.lowpass_cutoff = static_cast<::kaldi::BaseFloat>(lowpass_cutoff);
opts.resample_freq = static_cast<::kaldi::BaseFloat>(resample_frequency);
opts.delta_pitch = static_cast<::kaldi::BaseFloat>(delta_pitch);
opts.lowpass_filter_width = static_cast<::kaldi::int32>(lowpass_filter_width);
opts.upsample_filter_width =
static_cast<::kaldi::int32>(upsample_filter_width);
opts.max_frames_latency = static_cast<::kaldi::int32>(max_frames_latency);
opts.frames_per_chunk = static_cast<::kaldi::int32>(frames_per_chunk);
opts.simulate_first_pass_online = simulate_first_pass_online;
opts.recompute_frame = static_cast<::kaldi::int32>(recompute_frame);
opts.snip_edges = snip_edges;
// Kaldi's float type expects value range of int16 expressed as float
torch::Tensor wave_ = denormalize(wave);
auto batch_size = wave_.size(0);
std::vector<torch::Tensor> results(batch_size);
at::parallel_for(0, batch_size, 1, [&](int64_t begin, int64_t end) {
for (auto i = begin; i < end; ++i) {
results[i] = compute_kaldi_pitch(wave_.index({i}), opts);
}
});
return torch::stack(results, 0);
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::kaldi_ComputeKaldiPitch",
&torchaudio::kaldi::ComputeKaldiPitch);
}
} // namespace kaldi
} // namespace torchaudio
#include <torch/script.h>
#include <torch/torch.h>
namespace {
template <typename scalar_t>
void host_lfilter_core_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
int64_t n_batch = input_signal_windows.size(0);
int64_t n_channel = input_signal_windows.size(1);
int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_samples_output = padded_output_waveform.size(2);
int64_t n_order = a_coeff_flipped.size(1);
scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
at::parallel_for(0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) {
for (auto i = begin; i < end; i++) {
int64_t offset_input = i * n_samples_input;
int64_t offset_output = i * n_samples_output;
int64_t i_channel = i % n_channel;
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
scalar_t a0 = input_data[offset_input + i_sample];
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
a0 -= output_data[offset_output + i_sample + i_coeff] *
a_coeff_flipped_data[i_coeff + i_channel * n_order];
}
output_data[offset_output + i_sample + n_order - 1] = a0;
}
}
});
}
void cpu_lfilter_core_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
TORCH_CHECK(
input_signal_windows.device().is_cpu() &&
a_coeff_flipped.device().is_cpu() &&
padded_output_waveform.device().is_cpu());
TORCH_CHECK(
input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() &&
padded_output_waveform.is_contiguous());
TORCH_CHECK(
(input_signal_windows.dtype() == torch::kFloat32 ||
input_signal_windows.dtype() == torch::kFloat64) &&
(a_coeff_flipped.dtype() == torch::kFloat32 ||
a_coeff_flipped.dtype() == torch::kFloat64) &&
(padded_output_waveform.dtype() == torch::kFloat32 ||
padded_output_waveform.dtype() == torch::kFloat64));
TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1));
TORCH_CHECK(
input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 ==
padded_output_waveform.size(2));
AT_DISPATCH_FLOATING_TYPES(
input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
host_lfilter_core_loop<scalar_t>(
input_signal_windows, a_coeff_flipped, padded_output_waveform);
});
}
void lfilter_core_generic_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_order = a_coeff_flipped.size(1);
auto coeff = a_coeff_flipped.unsqueeze(2);
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
auto windowed_output_signal =
padded_output_waveform
.index(
{torch::indexing::Slice(),
torch::indexing::Slice(),
torch::indexing::Slice(i_sample, i_sample + n_order)})
.transpose(0, 1);
auto o0 =
input_signal_windows.index(
{torch::indexing::Slice(), torch::indexing::Slice(), i_sample}) -
at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
padded_output_waveform.index_put_(
{torch::indexing::Slice(),
torch::indexing::Slice(),
i_sample + n_order - 1},
o0);
}
}
class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs_normalized) {
auto device = waveform.device();
auto dtype = waveform.dtype();
int64_t n_batch = waveform.size(0);
int64_t n_channel = waveform.size(1);
int64_t n_sample = waveform.size(2);
int64_t n_order = a_coeffs_normalized.size(1);
int64_t n_sample_padded = n_sample + n_order - 1;
auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous();
auto options = torch::TensorOptions().dtype(dtype).device(device);
auto padded_output_waveform =
torch::zeros({n_batch, n_channel, n_sample_padded}, options);
if (device.is_cpu()) {
cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
} else {
lfilter_core_generic_loop(
waveform, a_coeff_flipped, padded_output_waveform);
}
auto output = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)});
ctx->save_for_backward({waveform, a_coeffs_normalized, output});
return output;
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto x = saved[0];
auto a_coeffs_normalized = saved[1];
auto y = saved[2];
int64_t n_batch = x.size(0);
int64_t n_channel = x.size(1);
int64_t n_order = a_coeffs_normalized.size(1);
auto dx = torch::Tensor();
auto da = torch::Tensor();
auto dy = grad_outputs[0];
namespace F = torch::nn::functional;
if (a_coeffs_normalized.requires_grad()) {
auto dyda = F::pad(
DifferentiableIIR::apply(-y, a_coeffs_normalized),
F::PadFuncOptions({n_order - 1, 0}));
da = F::conv1d(
dyda.view({1, n_batch * n_channel, -1}),
dy.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}
if (x.requires_grad()) {
dx = DifferentiableIIR::apply(dy.flip(2), a_coeffs_normalized).flip(2);
}
return {dx, da};
}
};
class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& b_coeffs) {
int64_t n_order = b_coeffs.size(1);
int64_t n_channel = b_coeffs.size(0);
namespace F = torch::nn::functional;
auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
auto padded_waveform =
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
auto output = F::conv1d(
padded_waveform,
b_coeff_flipped.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));
ctx->save_for_backward({waveform, b_coeffs, output});
return output;
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto x = saved[0];
auto b_coeffs = saved[1];
auto y = saved[2];
int64_t n_batch = x.size(0);
int64_t n_channel = x.size(1);
int64_t n_order = b_coeffs.size(1);
auto dx = torch::Tensor();
auto db = torch::Tensor();
auto dy = grad_outputs[0];
namespace F = torch::nn::functional;
if (b_coeffs.requires_grad()) {
db = F::conv1d(
F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
.view({1, n_batch * n_channel, -1}),
dy.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}
if (x.requires_grad()) {
dx = F::conv1d(
F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
b_coeffs.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));
}
return {dx, db};
}
};
torch::Tensor lfilter_core(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());
TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));
int64_t n_order = b_coeffs.size(1);
TORCH_INTERNAL_ASSERT(n_order > 0);
auto filtered_waveform = DifferentiableFIR::apply(
waveform,
b_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
auto output = DifferentiableIIR::apply(
filtered_waveform,
a_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
return output;
}
} // namespace
// Note: We want to avoid using "catch-all" kernel.
// The following registration should be replaced with CPU specific registration.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
}
TORCH_LIBRARY(torchaudio, m) {
m.def(
"torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
}
TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
m.impl("torchaudio::_lfilter", lfilter_core);
}
#include <torch/script.h>
#include <torch/torch.h>
namespace {
template <typename scalar_t>
void overdrive_cpu_kernel(
at::TensorAccessor<scalar_t, 2> waveform_accessor,
at::TensorAccessor<scalar_t, 2> temp_accessor,
at::TensorAccessor<scalar_t, 1> last_in_accessor,
at::TensorAccessor<scalar_t, 1> last_out_accessor,
at::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
int64_t n_frames = waveform_accessor.size(1);
int64_t n_channels = waveform_accessor.size(0);
at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) {
for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel];
last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
output_waveform_accessor[i_channel][i_frame] =
waveform_accessor[i_channel][i_frame] * 0.5 +
last_out_accessor[i_channel] * 0.75;
}
}
});
}
void overdrive_core_loop_cpu(
at::Tensor& waveform,
at::Tensor& temp,
at::Tensor& last_in,
at::Tensor& last_out,
at::Tensor& output_waveform) {
AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] {
overdrive_cpu_kernel<scalar_t>(
waveform.accessor<scalar_t, 2>(),
temp.accessor<scalar_t, 2>(),
last_in.accessor<scalar_t, 1>(),
last_out.accessor<scalar_t, 1>(),
output_waveform.accessor<scalar_t, 2>());
}));
}
} // namespace
// Note: We want to avoid using "catch-all" kernel.
// The following registration should be replaced with CPU specific registration.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu);
}
#include <torch/extension.h>
#ifdef INCLUDE_SOX
#include <torchaudio/csrc/pybind/sox/effects.h>
#include <torchaudio/csrc/pybind/sox/io.h>
#endif
PYBIND11_MODULE(_torchaudio, m) {
#ifdef INCLUDE_SOX
m.def(
"get_info_fileobj",
&torchaudio::sox_io::get_info_fileobj,
"Get metadata of audio in file object.");
m.def(
"load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj,
"Load audio from file object.");
m.def(
"save_audio_fileobj",
&torchaudio::sox_io::save_audio_fileobj,
"Save audio to file obj.");
m.def(
"apply_effects_fileobj",
&torchaudio::sox_effects::apply_effects_fileobj,
"Decode audio data from file-like obj and apply effects.");
#endif
}
#include <torchaudio/csrc/pybind/sox/effects.h>
#include <torchaudio/csrc/pybind/sox/effects_chain.h>
#include <torchaudio/csrc/pybind/sox/utils.h>
using namespace torchaudio::sox_utils;
namespace torchaudio::sox_effects {
// Streaming decoding over file-like object is tricky because libsox operates on
// FILE pointer. The folloing is what `sox` and `play` commands do
// - file input -> FILE pointer
// - URL input -> call wget in suprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer
//
// We want to, instead, fetch byte strings chunk by chunk, consume them, and
// discard.
//
// Here is the approach
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial
// chunk of byte string
// This will perform header-based format detection, if necessary, then fill
// the metadata of sox_format_t. Internally, sox_open_mem_read uses fmemopen,
// which returns FILE* which points the buffer of the provided byte string.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying
// buffer in a way that it
// starts with unseen data, and append the new data read from the given
// fileobj. This will trick libsox as if it keeps reading from the FILE*
// continuously.
// For Step 2. see `fileobj_input_drain` function in effects_chain.cpp
auto apply_effects_fileobj(
py::object fileobj,
const std::vector<std::vector<std::string>>& effects,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
c10::optional<std::string> format) -> std::tuple<torch::Tensor, int64_t> {
// Prepare the buffer used throughout the lifecycle of SoxEffectChain.
//
// For certain format (such as FLAC), libsox keeps reading the content at
// the initialization unless it reaches EOF even when the header is properly
// parsed. (Making buffer size 8192, which is way bigger than the header,
// resulted in libsox consuming all the buffer content at the time it opens
// the file.) Therefore buffer has to always contain valid data, except after
// EOF. We default to `sox_get_globals()->bufsiz`* for buffer size and we
// first check if there is enough data to fill the buffer. `read_fileobj`
// repeatedly calls `read` method until it receives the requested length of
// bytes or it reaches EOF. If we get bytes shorter than requested, that means
// the whole audio data are fetched.
//
// * This can be changed with `torchaudio.utils.sox_utils.set_buffer_size`.
const auto capacity = [&]() {
// NOTE:
// Use the abstraction provided by `libtorchaudio` to access the global
// config defined by libsox. Directly using `sox_get_globals` function will
// end up retrieving the static variable defined in `_torchaudio`, which is
// not correct.
const auto bufsiz = get_buffer_size();
const int64_t kDefaultCapacityInBytes = 256;
return (bufsiz > kDefaultCapacityInBytes) ? bufsiz
: kDefaultCapacityInBytes;
}();
std::string buffer(capacity, '\0');
auto* in_buf = const_cast<char*>(buffer.data());
auto num_read = read_fileobj(&fileobj, capacity, in_buf);
// If the file is shorter than 256, then libsox cannot read the header.
auto in_buffer_size = (num_read > 256) ? num_read : 256;
// Open file (this starts reading the header)
// When opening a file there are two functions that can touches FILE*.
// * `auto_detect_format`
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
// * `startread` handler of detected format.
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
// To see the handler of a particular format, go to
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
// For example, voribs can be found
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
SoxFormat sf(sox_open_mem_read(
in_buf,
in_buffer_size,
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0
validate_input_memfile(sf);
// Prepare output buffer
std::vector<sox_sample_t> out_buffer;
out_buffer.reserve(sf->signal.length);
// Create and run SoxEffectsChain
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
torchaudio::sox_effects_chain::SoxEffectsChainPyBind chain(
/*input_encoding=*/sf->encoding,
/*output_encoding=*/get_tensor_encodinginfo(dtype));
chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj);
for (const auto& effect : effects) {
chain.addEffect(effect);
}
chain.addOutputBuffer(&out_buffer);
chain.run();
// Create tensor from buffer
bool channels_first_ = channels_first.value_or(true);
auto tensor = convert_to_tensor(
/*buffer=*/out_buffer.data(),
/*num_samples=*/out_buffer.size(),
/*num_channels=*/chain.getOutputNumChannels(),
dtype,
normalize.value_or(true),
channels_first_);
return std::make_tuple(
tensor, static_cast<int64_t>(chain.getOutputSampleRate()));
}
} // namespace torchaudio::sox_effects
#ifndef TORCHAUDIO_PYBIND_SOX_EFFECTS_H
#define TORCHAUDIO_PYBIND_SOX_EFFECTS_H
#include <torch/extension.h>
namespace torchaudio::sox_effects {
auto apply_effects_fileobj(
py::object fileobj,
const std::vector<std::vector<std::string>>& effects,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
c10::optional<std::string> format) -> std::tuple<torch::Tensor, int64_t>;
} // namespace torchaudio::sox_effects
#endif
#include <sox.h>
#include <torchaudio/csrc/pybind/sox/effects_chain.h>
#include <torchaudio/csrc/pybind/sox/utils.h>
using namespace torchaudio::sox_utils;
namespace torchaudio::sox_effects_chain {
namespace {
/// helper classes for passing file-like object to SoxEffectChain
struct FileObjInputPriv {
sox_format_t* sf;
py::object* fileobj;
bool eof_reached;
char* buffer;
uint64_t buffer_size;
};
struct FileObjOutputPriv {
sox_format_t* sf;
py::object* fileobj;
char** buffer;
size_t* buffer_size;
};
/// Callback function to feed byte string
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
auto fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp)
-> int {
auto priv = static_cast<FileObjInputPriv*>(effp->priv);
auto sf = priv->sf;
auto buffer = priv->buffer;
// 1. Refresh the buffer
//
// NOTE:
// Since the underlying FILE* was opened with `fmemopen`, the only way
// libsox detect EOF is reaching the end of the buffer. (null byte won't
// help) Therefore we need to align the content at the end of buffer,
// otherwise, libsox will keep reading the content beyond intended length.
//
// Before:
//
// |<-------consumed------>|<---remaining--->|
// |***********************|-----------------|
// ^ ftell
//
// After:
//
// |<-offset->|<---remaining--->|<-new data->|
// |**********|-----------------|++++++++++++|
// ^ ftell
// NOTE:
// Do not use `sf->tell_off` here. Presumably, `tell_off` and `fseek` are
// supposed to be in sync, but there are cases (Vorbis) they are not
// in sync and `tell_off` has seemingly uninitialized value, which
// leads num_remain to be negative and cause segmentation fault
// in `memmove`.
const auto tell = ftell((FILE*)sf->fp);
if (tell < 0) {
throw std::runtime_error("Internal Error: ftell failed.");
}
const auto num_consumed = static_cast<size_t>(tell);
if (num_consumed > priv->buffer_size) {
throw std::runtime_error("Internal Error: buffer overrun.");
}
const auto num_remain = priv->buffer_size - num_consumed;
// 1.1. Fetch the data to see if there is data to fill the buffer
size_t num_refill = 0;
std::string chunk(num_consumed, '\0');
if (num_consumed && !priv->eof_reached) {
num_refill = read_fileobj(
priv->fileobj, num_consumed, const_cast<char*>(chunk.data()));
if (num_refill < num_consumed) {
priv->eof_reached = true;
}
}
const auto offset = num_consumed - num_refill;
// 1.2. Move the unconsumed data towards the beginning of buffer.
if (num_remain) {
auto src = static_cast<void*>(buffer + num_consumed);
auto dst = static_cast<void*>(buffer + offset);
memmove(dst, src, num_remain);
}
// 1.3. Refill the remaining buffer.
if (num_refill) {
auto src = static_cast<void*>(const_cast<char*>(chunk.c_str()));
auto dst = buffer + offset + num_remain;
memcpy(dst, src, num_refill);
}
// 1.4. Set the file pointer to the new offset
sf->tell_off = offset;
fseek((FILE*)sf->fp, offset, SEEK_SET);
// 2. Perform decoding operation
// The following part is practically same as "input" effect
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/input.c#L30-L48
// Ensure that it's a multiple of the number of channels
*osamp -= *osamp % effp->out_signal.channels;
// Read up to *osamp samples into obuf;
// store the actual number read back to *osamp
*osamp = sox_read(sf, obuf, *osamp);
// Decoding is finished when fileobject is exhausted and sox can no longer
// decode a sample.
return (priv->eof_reached && !*osamp) ? SOX_EOF : SOX_SUCCESS;
}
auto fileobj_output_flow(
sox_effect_t* effp,
sox_sample_t const* ibuf,
sox_sample_t* obuf LSX_UNUSED,
size_t* isamp,
size_t* osamp) -> int {
*osamp = 0;
if (*isamp) {
auto priv = static_cast<FileObjOutputPriv*>(effp->priv);
auto sf = priv->sf;
auto fp = static_cast<FILE*>(sf->fp);
auto fileobj = priv->fileobj;
auto buffer = priv->buffer;
// Encode chunk
auto num_samples_written = sox_write(sf, ibuf, *isamp);
fflush(fp);
// Copy the encoded chunk to python object.
fileobj->attr("write")(py::bytes(*buffer, ftell(fp)));
// Reset FILE*
sf->tell_off = 0;
fseek(fp, 0, SEEK_SET);
if (num_samples_written != *isamp) {
if (sf->sox_errno) {
std::ostringstream stream;
stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " "
<< sf->filename;
throw std::runtime_error(stream.str());
}
return SOX_EOF;
}
}
return SOX_SUCCESS;
}
auto get_fileobj_input_handler() -> sox_effect_handler_t* {
static sox_effect_handler_t handler{
/*name=*/"input_fileobj_object",
/*usage=*/nullptr,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/nullptr,
/*start=*/nullptr,
/*flow=*/nullptr,
/*drain=*/fileobj_input_drain,
/*stop=*/nullptr,
/*kill=*/nullptr,
/*priv_size=*/sizeof(FileObjInputPriv)};
return &handler;
}
auto get_fileobj_output_handler() -> sox_effect_handler_t* {
static sox_effect_handler_t handler{
/*name=*/"output_fileobj_object",
/*usage=*/nullptr,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/nullptr,
/*start=*/nullptr,
/*flow=*/fileobj_output_flow,
/*drain=*/nullptr,
/*stop=*/nullptr,
/*kill=*/nullptr,
/*priv_size=*/sizeof(FileObjOutputPriv)};
return &handler;
}
} // namespace
void SoxEffectsChainPyBind::addInputFileObj(
sox_format_t* sf,
char* buffer,
uint64_t buffer_size,
py::object* fileobj) {
in_sig_ = sf->signal;
interm_sig_ = in_sig_;
SoxEffect e(sox_create_effect(get_fileobj_input_handler()));
auto priv = static_cast<FileObjInputPriv*>(e->priv);
priv->sf = sf;
priv->fileobj = fileobj;
priv->eof_reached = false;
priv->buffer = buffer;
priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error(
"Internal Error: Failed to add effect: input fileobj");
}
}
void SoxEffectsChainPyBind::addOutputFileObj(
sox_format_t* sf,
char** buffer,
size_t* buffer_size,
py::object* fileobj) {
out_sig_ = sf->signal;
SoxEffect e(sox_create_effect(get_fileobj_output_handler()));
auto priv = static_cast<FileObjOutputPriv*>(e->priv);
priv->sf = sf;
priv->fileobj = fileobj;
priv->buffer = buffer;
priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) {
throw std::runtime_error(
"Internal Error: Failed to add effect: output fileobj");
}
}
} // namespace torchaudio::sox_effects_chain
#ifndef TORCHAUDIO_PYBIND_SOX_EFFECTS_CHAIN_H
#define TORCHAUDIO_PYBIND_SOX_EFFECTS_CHAIN_H
#include <torch/extension.h>
#include <torchaudio/csrc/sox/effects_chain.h>
namespace torchaudio::sox_effects_chain {
class SoxEffectsChainPyBind : public SoxEffectsChain {
using SoxEffectsChain::SoxEffectsChain;
public:
void addInputFileObj(
sox_format_t* sf,
char* buffer,
uint64_t buffer_size,
py::object* fileobj);
void addOutputFileObj(
sox_format_t* sf,
char** buffer,
size_t* buffer_size,
py::object* fileobj);
};
} // namespace torchaudio::sox_effects_chain
#endif
#include <torchaudio/csrc/pybind/sox/effects.h>
#include <torchaudio/csrc/pybind/sox/effects_chain.h>
#include <torchaudio/csrc/pybind/sox/io.h>
#include <torchaudio/csrc/pybind/sox/utils.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/types.h>
#include <utility>
using namespace torchaudio::sox_utils;
namespace torchaudio::sox_io {
auto get_info_fileobj(py::object fileobj, c10::optional<std::string> format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> {
// Prepare in-memory file object
// When libsox opens a file, it also reads the header.
// When opening a file there are two functions that might touch FILE* (and the
// underlying buffer).
// * `auto_detect_format`
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
// * `startread` handler of detected format.
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
// To see the handler of a particular format, go to
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
// For example, voribs can be found
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
//
// `auto_detect_format` function only requires 256 bytes, but format-dependent
// `startread` handler might require more data. In case of vorbis, the size of
// header is unbounded, but typically 4kB maximum.
//
// "The header size is unbounded, although for streaming a rule-of-thumb of
// 4kB or less is recommended (and Xiph.Org's Vorbis encoder follows this
// suggestion)."
//
// See:
// https://xiph.org/vorbis/doc/Vorbis_I_spec.html
const auto capacity = [&]() {
// NOTE:
// Use the abstraction provided by `libtorchaudio` to access the global
// config defined by libsox. Directly using `sox_get_globals` function will
// end up retrieving the static variable defined in `_torchaudio`, which is
// not correct.
const auto bufsiz = get_buffer_size();
const int64_t kDefaultCapacityInBytes = 4096;
return (bufsiz > kDefaultCapacityInBytes) ? bufsiz
: kDefaultCapacityInBytes;
}();
std::string buffer(capacity, '\0');
auto* buf = const_cast<char*>(buffer.data());
auto num_read = read_fileobj(&fileobj, capacity, buf);
// If the file is shorter than 256, then libsox cannot read the header.
auto buf_size = (num_read > 256) ? num_read : 256;
SoxFormat sf(sox_open_mem_read(
buf,
buf_size,
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0
validate_input_memfile(sf);
return std::make_tuple(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.length / sf->signal.channels),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->encoding.bits_per_sample),
get_encoding(sf->encoding.encoding));
}
auto load_audio_fileobj(
py::object fileobj,
c10::optional<int64_t> frame_offset,
c10::optional<int64_t> num_frames,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
c10::optional<std::string> format) -> std::tuple<torch::Tensor, int64_t> {
auto effects = get_effects(frame_offset, num_frames);
return torchaudio::sox_effects::apply_effects_fileobj(
std::move(fileobj),
effects,
normalize,
channels_first,
std::move(format));
}
namespace {
// helper class to automatically release buffer, to be used by
// save_audio_fileobj
struct AutoReleaseBuffer {
char* ptr;
size_t size;
AutoReleaseBuffer() : ptr(nullptr), size(0) {}
AutoReleaseBuffer(const AutoReleaseBuffer& other) = delete;
AutoReleaseBuffer(AutoReleaseBuffer&& other) = delete;
auto operator=(const AutoReleaseBuffer& other) -> AutoReleaseBuffer& = delete;
auto operator=(AutoReleaseBuffer&& other) -> AutoReleaseBuffer& = delete;
~AutoReleaseBuffer() {
if (ptr) {
free(ptr);
}
}
};
} // namespace
void save_audio_fileobj(
py::object fileobj,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format,
c10::optional<std::string> encoding,
c10::optional<int64_t> bits_per_sample) {
validate_input_tensor(tensor);
if (!format.has_value()) {
throw std::runtime_error(
"`format` is required when saving to file object.");
}
const auto filetype = format.value();
if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
if (num_channels != 1) {
throw std::runtime_error(
"amr-nb format only supports single channel audio.");
}
} else if (filetype == "htk") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
if (num_channels != 1) {
throw std::runtime_error(
"htk format only supports single channel audio.");
}
} else if (filetype == "gsm") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
if (num_channels != 1) {
throw std::runtime_error(
"gsm format only supports single channel audio.");
}
if (sample_rate != 8000) {
throw std::runtime_error(
"gsm format only supports a sampling rate of 8kHz.");
}
}
const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = get_encodinginfo_for_save(
filetype,
tensor.dtype(),
compression,
std::move(encoding),
bits_per_sample);
AutoReleaseBuffer buffer;
SoxFormat sf(sox_open_memstream_write(
&buffer.ptr,
&buffer.size,
&signal_info,
&encoding_info,
filetype.c_str(),
/*oob=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error(
"Error saving audio file: failed to open memory stream.");
}
torchaudio::sox_effects_chain::SoxEffectsChainPyBind chain(
/*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
/*output_encoding=*/sf->encoding);
chain.addInputTensor(&tensor, sample_rate, channels_first);
chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj);
chain.run();
// Closing the sox_format_t is necessary for flushing the last chunk to the
// buffer
sf.close();
fileobj.attr("write")(py::bytes(buffer.ptr, buffer.size));
}
} // namespace torchaudio::sox_io
#ifndef TORCHAUDIO_PYBIND_SOX_IO_H
#define TORCHAUDIO_PYBIND_SOX_IO_H
#include <torch/extension.h>
namespace torchaudio::sox_io {
auto get_info_fileobj(py::object fileobj, c10::optional<std::string> format)
-> std::tuple<int64_t, int64_t, int64_t, int64_t, std::string>;
auto load_audio_fileobj(
py::object fileobj,
c10::optional<int64_t> frame_offset,
c10::optional<int64_t> num_frames,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
c10::optional<std::string> format) -> std::tuple<torch::Tensor, int64_t>;
void save_audio_fileobj(
py::object fileobj,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format,
c10::optional<std::string> encoding,
c10::optional<int64_t> bits_per_sample);
} // namespace torchaudio::sox_io
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment