Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
 # flake8: noqa
 from . import utils
 from .utils import (
     list_audio_backends,
...
@@ -26,13 +26,14 @@ class AudioMetaData:
         * ``HTK``: Single channel 16-bit PCM
         * ``UNKNOWN`` : None of above
     """
+
     def __init__(
         self,
         sample_rate: int,
         num_frames: int,
         num_channels: int,
         bits_per_sample: int,
         encoding: str,
     ):
         self.sample_rate = sample_rate
         self.num_frames = num_frames
...
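The AudioMetaData class touched above is the return type of each backend's `info` function. A minimal usage sketch (the file name here is hypothetical):

    import torchaudio

    metadata = torchaudio.info("sample.wav")  # hypothetical file
    print(metadata.sample_rate, metadata.num_frames, metadata.num_channels)
    print(metadata.bits_per_sample, metadata.encoding)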
@@ -4,19 +4,21 @@ from typing import Callable, Optional, Tuple, Union
 from torch import Tensor


-def load(filepath: Union[str, Path],
-         out: Optional[Tensor] = None,
-         normalization: Union[bool, float, Callable] = True,
-         channels_first: bool = True,
-         num_frames: int = 0,
-         offset: int = 0,
-         filetype: Optional[str] = None) -> Tuple[Tensor, int]:
-    raise RuntimeError('No audio I/O backend is available.')
+def load(
+    filepath: Union[str, Path],
+    out: Optional[Tensor] = None,
+    normalization: Union[bool, float, Callable] = True,
+    channels_first: bool = True,
+    num_frames: int = 0,
+    offset: int = 0,
+    filetype: Optional[str] = None,
+) -> Tuple[Tensor, int]:
+    raise RuntimeError("No audio I/O backend is available.")


 def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
-    raise RuntimeError('No audio I/O backend is available.')
+    raise RuntimeError("No audio I/O backend is available.")


 def info(filepath: str) -> None:
-    raise RuntimeError('No audio I/O backend is available.')
+    raise RuntimeError("No audio I/O backend is available.")

"""The new soundfile backend which will become default in 0.8.0 onward""" """The new soundfile backend which will become default in 0.8.0 onward"""
from typing import Tuple, Optional
import warnings import warnings
from typing import Tuple, Optional
import torch import torch
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from .common import AudioMetaData from .common import AudioMetaData
@@ -19,33 +20,33 @@ if _mod_utils.is_soundfile_available():
 # The dict is inspired from
 # https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
 _SUBTYPE_TO_BITS_PER_SAMPLE = {
-    'PCM_S8': 8,  # Signed 8 bit data
-    'PCM_16': 16,  # Signed 16 bit data
-    'PCM_24': 24,  # Signed 24 bit data
-    'PCM_32': 32,  # Signed 32 bit data
-    'PCM_U8': 8,  # Unsigned 8 bit data (WAV and RAW only)
-    'FLOAT': 32,  # 32 bit float data
-    'DOUBLE': 64,  # 64 bit float data
-    'ULAW': 8,  # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
-    'ALAW': 8,  # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
-    'IMA_ADPCM': 0,  # IMA ADPCM.
-    'MS_ADPCM': 0,  # Microsoft ADPCM.
-    'GSM610': 0,  # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
-    'VOX_ADPCM': 0,  # OKI / Dialogix ADPCM
-    'G721_32': 0,  # 32kbs G721 ADPCM encoding.
-    'G723_24': 0,  # 24kbs G723 ADPCM encoding.
-    'G723_40': 0,  # 40kbs G723 ADPCM encoding.
-    'DWVW_12': 12,  # 12 bit Delta Width Variable Word encoding.
-    'DWVW_16': 16,  # 16 bit Delta Width Variable Word encoding.
-    'DWVW_24': 24,  # 24 bit Delta Width Variable Word encoding.
-    'DWVW_N': 0,  # N bit Delta Width Variable Word encoding.
-    'DPCM_8': 8,  # 8 bit differential PCM (XI only)
-    'DPCM_16': 16,  # 16 bit differential PCM (XI only)
-    'VORBIS': 0,  # Xiph Vorbis encoding. (lossy)
-    'ALAC_16': 16,  # Apple Lossless Audio Codec (16 bit).
-    'ALAC_20': 20,  # Apple Lossless Audio Codec (20 bit).
-    'ALAC_24': 24,  # Apple Lossless Audio Codec (24 bit).
-    'ALAC_32': 32,  # Apple Lossless Audio Codec (32 bit).
+    "PCM_S8": 8,  # Signed 8 bit data
+    "PCM_16": 16,  # Signed 16 bit data
+    "PCM_24": 24,  # Signed 24 bit data
+    "PCM_32": 32,  # Signed 32 bit data
+    "PCM_U8": 8,  # Unsigned 8 bit data (WAV and RAW only)
+    "FLOAT": 32,  # 32 bit float data
+    "DOUBLE": 64,  # 64 bit float data
+    "ULAW": 8,  # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
+    "ALAW": 8,  # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
+    "IMA_ADPCM": 0,  # IMA ADPCM.
+    "MS_ADPCM": 0,  # Microsoft ADPCM.
+    "GSM610": 0,  # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
+    "VOX_ADPCM": 0,  # OKI / Dialogix ADPCM
+    "G721_32": 0,  # 32kbs G721 ADPCM encoding.
+    "G723_24": 0,  # 24kbs G723 ADPCM encoding.
+    "G723_40": 0,  # 40kbs G723 ADPCM encoding.
+    "DWVW_12": 12,  # 12 bit Delta Width Variable Word encoding.
+    "DWVW_16": 16,  # 16 bit Delta Width Variable Word encoding.
+    "DWVW_24": 24,  # 24 bit Delta Width Variable Word encoding.
+    "DWVW_N": 0,  # N bit Delta Width Variable Word encoding.
+    "DPCM_8": 8,  # 8 bit differential PCM (XI only)
+    "DPCM_16": 16,  # 16 bit differential PCM (XI only)
+    "VORBIS": 0,  # Xiph Vorbis encoding. (lossy)
+    "ALAC_16": 16,  # Apple Lossless Audio Codec (16 bit).
+    "ALAC_20": 20,  # Apple Lossless Audio Codec (20 bit).
+    "ALAC_24": 24,  # Apple Lossless Audio Codec (24 bit).
+    "ALAC_32": 32,  # Apple Lossless Audio Codec (32 bit).
 }
@@ -61,23 +62,23 @@ def _get_bit_depth(subtype):
 _SUBTYPE_TO_ENCODING = {
-    'PCM_S8': 'PCM_S',
-    'PCM_16': 'PCM_S',
-    'PCM_24': 'PCM_S',
-    'PCM_32': 'PCM_S',
-    'PCM_U8': 'PCM_U',
-    'FLOAT': 'PCM_F',
-    'DOUBLE': 'PCM_F',
-    'ULAW': 'ULAW',
-    'ALAW': 'ALAW',
-    'VORBIS': 'VORBIS',
+    "PCM_S8": "PCM_S",
+    "PCM_16": "PCM_S",
+    "PCM_24": "PCM_S",
+    "PCM_32": "PCM_S",
+    "PCM_U8": "PCM_U",
+    "FLOAT": "PCM_F",
+    "DOUBLE": "PCM_F",
+    "ULAW": "ULAW",
+    "ALAW": "ALAW",
+    "VORBIS": "VORBIS",
 }
 def _get_encoding(format: str, subtype: str):
-    if format == 'FLAC':
-        return 'FLAC'
-    return _SUBTYPE_TO_ENCODING.get(subtype, 'UNKNOWN')
+    if format == "FLAC":
+        return "FLAC"
+    return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")


 @_mod_utils.requires_soundfile()
@@ -211,10 +212,7 @@ def load(
     return waveform, sample_rate


-def _get_subtype_for_wav(
-        dtype: torch.dtype,
-        encoding: str,
-        bits_per_sample: int):
+def _get_subtype_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int):
     if not encoding:
         if not bits_per_sample:
             subtype = {
@@ -271,11 +269,7 @@ def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
     raise ValueError(f"sph does not support {encoding}.")


-def _get_subtype(
-        dtype: torch.dtype,
-        format: str,
-        encoding: str,
-        bits_per_sample: int):
+def _get_subtype(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int):
     if format == "wav":
         return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
     if format == "flac":
@@ -288,8 +282,7 @@ def _get_subtype(
         return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
     if format in ("ogg", "vorbis"):
         if encoding or bits_per_sample:
-            raise ValueError(
-                "ogg/vorbis does not support encoding/bits_per_sample.")
+            raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
         return "VORBIS"
     if format == "sph":
         return _get_subtype_for_sphere(encoding, bits_per_sample)
@@ -407,9 +400,9 @@ def save(
             '`save` function of "soundfile" backend does not support "compression" parameter. '
             "The argument is silently ignored."
         )
-    if hasattr(filepath, 'write'):
+    if hasattr(filepath, "write"):
         if format is None:
-            raise RuntimeError('`format` is required when saving to file object.')
+            raise RuntimeError("`format` is required when saving to file object.")
         ext = format.lower()
     else:
         ext = str(filepath).split(".")[-1].lower()
@@ -417,8 +410,10 @@ def save(
     if bits_per_sample not in (None, 8, 16, 24, 32, 64):
         raise ValueError("Invalid bits_per_sample.")
     if bits_per_sample == 24:
-        warnings.warn("Saving audio with 24 bits per sample might warp samples near -1. "
-                      "Using 16 bits per sample might be able to avoid this.")
+        warnings.warn(
+            "Saving audio with 24 bits per sample might warp samples near -1. "
+            "Using 16 bits per sample might be able to avoid this."
+        )
     subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)

     # sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
@@ -429,6 +424,4 @@ def save(
     if channels_first:
         src = src.t()

-    soundfile.write(
-        file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format
-    )
+    soundfile.write(file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format)
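As a quick check of the save/load paths reformatted above, a round-trip sketch (assumes the soundfile backend is installed; the file name is hypothetical):

    import torch
    import torchaudio

    torchaudio.set_audio_backend("soundfile")
    waveform = torch.rand(1, 16000)  # one channel, one second at 16 kHz
    torchaudio.save("example.wav", waveform, sample_rate=16000, bits_per_sample=16)
    reloaded, sample_rate = torchaudio.load("example.wav")
    assert reloaded.shape == waveform.shape and sample_rate == 16000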
@@ -2,18 +2,18 @@ import os
 from typing import Tuple, Optional

 import torch
+import torchaudio

 from torchaudio._internal import (
     module_utils as _mod_utils,
 )
-import torchaudio

 from .common import AudioMetaData


 @_mod_utils.requires_sox()
 def info(
     filepath: str,
     format: Optional[str] = None,
 ) -> AudioMetaData:
     """Get signal information of an audio file.
@@ -46,7 +46,7 @@ def info(
         AudioMetaData: Metadata of the given audio.
     """
     if not torch.jit.is_scripting():
-        if hasattr(filepath, 'read'):
+        if hasattr(filepath, "read"):
             sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
             return AudioMetaData(*sinfo)
         filepath = os.fspath(filepath)
@@ -56,12 +56,12 @@ def info(
 @_mod_utils.requires_sox()
 def load(
-        filepath: str,
-        frame_offset: int = 0,
-        num_frames: int = -1,
-        normalize: bool = True,
-        channels_first: bool = True,
-        format: Optional[str] = None,
+    filepath: str,
+    frame_offset: int = 0,
+    num_frames: int = -1,
+    normalize: bool = True,
+    channels_first: bool = True,
+    format: Optional[str] = None,
 ) -> Tuple[torch.Tensor, int]:
     """Load audio data from file.
@@ -145,24 +145,26 @@ def load(
             `[channel, time]` else `[time, channel]`.
     """
     if not torch.jit.is_scripting():
-        if hasattr(filepath, 'read'):
+        if hasattr(filepath, "read"):
             return torchaudio._torchaudio.load_audio_fileobj(
-                filepath, frame_offset, num_frames, normalize, channels_first, format)
+                filepath, frame_offset, num_frames, normalize, channels_first, format
+            )
         filepath = os.fspath(filepath)
     return torch.ops.torchaudio.sox_io_load_audio_file(
-        filepath, frame_offset, num_frames, normalize, channels_first, format)
+        filepath, frame_offset, num_frames, normalize, channels_first, format
+    )


 @_mod_utils.requires_sox()
 def save(
-        filepath: str,
-        src: torch.Tensor,
-        sample_rate: int,
-        channels_first: bool = True,
-        compression: Optional[float] = None,
-        format: Optional[str] = None,
-        encoding: Optional[str] = None,
-        bits_per_sample: Optional[int] = None,
+    filepath: str,
+    src: torch.Tensor,
+    sample_rate: int,
+    channels_first: bool = True,
+    compression: Optional[float] = None,
+    format: Optional[str] = None,
+    encoding: Optional[str] = None,
+    bits_per_sample: Optional[int] = None,
 ):
     """Save audio data to file.
@@ -309,11 +311,12 @@ def save(
         or ``libmp3lame`` etc.
     """
     if not torch.jit.is_scripting():
-        if hasattr(filepath, 'write'):
+        if hasattr(filepath, "write"):
             torchaudio._torchaudio.save_audio_fileobj(
-                filepath, src, sample_rate, channels_first, compression,
-                format, encoding, bits_per_sample)
+                filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample
+            )
             return
         filepath = os.fspath(filepath)
     torch.ops.torchaudio.sox_io_save_audio_file(
-        filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)
+        filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample
+    )

@@ -4,6 +4,7 @@ from typing import Optional, List
 import torchaudio
 from torchaudio._internal import module_utils as _mod_utils
+
 from . import (
     no_backend,
     sox_io_backend,
@@ -11,9 +12,9 @@ from . import (
 )

 __all__ = [
-    'list_audio_backends',
-    'get_audio_backend',
-    'set_audio_backend',
+    "list_audio_backends",
+    "get_audio_backend",
+    "set_audio_backend",
 ]
@@ -24,10 +25,10 @@ def list_audio_backends() -> List[str]:
         List[str]: The list of available backends.
     """
     backends = []
-    if _mod_utils.is_module_available('soundfile'):
-        backends.append('soundfile')
+    if _mod_utils.is_module_available("soundfile"):
+        backends.append("soundfile")
     if _mod_utils.is_sox_available():
-        backends.append('sox_io')
+        backends.append("sox_io")
     return backends
@@ -40,31 +41,29 @@ def set_audio_backend(backend: Optional[str]):
         of the system. If ``None`` is provided the current backend is unassigned.
     """
     if backend is not None and backend not in list_audio_backends():
-        raise RuntimeError(
-            f'Backend "{backend}" is not one of '
-            f'available backends: {list_audio_backends()}.')
+        raise RuntimeError(f'Backend "{backend}" is not one of ' f"available backends: {list_audio_backends()}.")
     if backend is None:
         module = no_backend
-    elif backend == 'sox_io':
+    elif backend == "sox_io":
         module = sox_io_backend
-    elif backend == 'soundfile':
+    elif backend == "soundfile":
         module = soundfile_backend
     else:
         raise NotImplementedError(f'Unexpected backend "{backend}"')

-    for func in ['save', 'load', 'info']:
+    for func in ["save", "load", "info"]:
         setattr(torchaudio, func, getattr(module, func))


 def _init_audio_backend():
     backends = list_audio_backends()
-    if 'sox_io' in backends:
-        set_audio_backend('sox_io')
-    elif 'soundfile' in backends:
-        set_audio_backend('soundfile')
+    if "sox_io" in backends:
+        set_audio_backend("sox_io")
+    elif "soundfile" in backends:
+        set_audio_backend("soundfile")
     else:
-        warnings.warn('No audio backend is available.')
+        warnings.warn("No audio backend is available.")
         set_audio_backend(None)
@@ -77,7 +76,7 @@ def get_audio_backend() -> Optional[str]:
     if torchaudio.load == no_backend.load:
         return None
     if torchaudio.load == sox_io_backend.load:
-        return 'sox_io'
+        return "sox_io"
     if torchaudio.load == soundfile_backend.load:
-        return 'soundfile'
-    raise ValueError('Unknown backend.')
+        return "soundfile"
+    raise ValueError("Unknown backend.")
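The hunks above cover the whole backend-dispatch surface: `set_audio_backend` simply rebinds `torchaudio.save`, `torchaudio.load`, and `torchaudio.info` to the chosen module. A short usage sketch (which backends appear depends on the local install):

    import torchaudio

    print(torchaudio.list_audio_backends())  # e.g. ["soundfile", "sox_io"]
    torchaudio.set_audio_backend("sox_io")  # raises RuntimeError if unavailable
    print(torchaudio.get_audio_backend())  # "sox_io"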
 from . import kaldi

 __all__ = [
-    'kaldi',
+    "kaldi",
 ]

+import math
 from typing import Tuple
-import math

 import torch
+from torch import Tensor
 import torchaudio
-from torch import Tensor

 __all__ = [
-    'get_mel_banks',
-    'inverse_mel_scale',
-    'inverse_mel_scale_scalar',
-    'mel_scale',
-    'mel_scale_scalar',
-    'spectrogram',
-    'fbank',
-    'mfcc',
-    'vtln_warp_freq',
-    'vtln_warp_mel_freq',
+    "get_mel_banks",
+    "inverse_mel_scale",
+    "inverse_mel_scale_scalar",
+    "mel_scale",
+    "mel_scale_scalar",
+    "spectrogram",
+    "fbank",
+    "mfcc",
+    "vtln_warp_freq",
+    "vtln_warp_mel_freq",
 ]

 # numeric_limits<float>::epsilon() 1.1920928955078125e-07
@@ -25,11 +24,11 @@ EPSILON = torch.tensor(torch.finfo(torch.float).eps)
 MILLISECONDS_TO_SECONDS = 0.001

 # window types
-HAMMING = 'hamming'
-HANNING = 'hanning'
-POVEY = 'povey'
-RECTANGULAR = 'rectangular'
-BLACKMAN = 'blackman'
+HAMMING = "hamming"
+HANNING = "hanning"
+POVEY = "povey"
+RECTANGULAR = "rectangular"
+BLACKMAN = "blackman"
 WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
@@ -38,8 +37,7 @@ def _get_epsilon(device, dtype):
 def _next_power_of_2(x: int) -> int:
-    r"""Returns the smallest power of 2 that is greater than x
-    """
+    r"""Returns the smallest power of 2 that is greater than x"""
     return 1 if x == 0 else 2 ** (x - 1).bit_length()
@@ -85,14 +83,14 @@ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edg
     return waveform.as_strided(sizes, strides)


-def _feature_window_function(window_type: str,
-                             window_size: int,
-                             blackman_coeff: float,
-                             device: torch.device,
-                             dtype: int,
-                             ) -> Tensor:
-    r"""Returns a window function with the given type and size
-    """
+def _feature_window_function(
+    window_type: str,
+    window_size: int,
+    blackman_coeff: float,
+    device: torch.device,
+    dtype: int,
+) -> Tensor:
+    r"""Returns a window function with the given type and size"""
     if window_type == HANNING:
         return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
     elif window_type == HAMMING:
@@ -106,64 +104,67 @@ def _feature_window_function(
         a = 2 * math.pi / (window_size - 1)
         window_function = torch.arange(window_size, device=device, dtype=dtype)
         # can't use torch.blackman_window as they use different coefficients
-        return (blackman_coeff - 0.5 * torch.cos(a * window_function) +
-                (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)).to(device=device, dtype=dtype)
+        return (
+            blackman_coeff
+            - 0.5 * torch.cos(a * window_function)
+            + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
+        ).to(device=device, dtype=dtype)
     else:
-        raise Exception('Invalid window type ' + window_type)
+        raise Exception("Invalid window type " + window_type)


-def _get_log_energy(strided_input: Tensor,
-                    epsilon: Tensor,
-                    energy_floor: float) -> Tensor:
-    r"""Returns the log energy of size (m) for a strided_input (m,*)
-    """
+def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
+    r"""Returns the log energy of size (m) for a strided_input (m,*)"""
     device, dtype = strided_input.device, strided_input.dtype
     log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log()  # size (m)
     if energy_floor == 0.0:
         return log_energy
-    return torch.max(
-        log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
+    return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))

-def _get_waveform_and_window_properties(waveform: Tensor,
-                                        channel: int,
-                                        sample_frequency: float,
-                                        frame_shift: float,
-                                        frame_length: float,
-                                        round_to_power_of_two: bool,
-                                        preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]:
-    r"""Gets the waveform and window properties
-    """
+def _get_waveform_and_window_properties(
+    waveform: Tensor,
+    channel: int,
+    sample_frequency: float,
+    frame_shift: float,
+    frame_length: float,
+    round_to_power_of_two: bool,
+    preemphasis_coefficient: float,
+) -> Tuple[Tensor, int, int, int]:
+    r"""Gets the waveform and window properties"""
     channel = max(channel, 0)
-    assert channel < waveform.size(0), ('Invalid channel {} for size {}'.format(channel, waveform.size(0)))
+    assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
     waveform = waveform[channel, :]  # size (n)
     window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
     window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
     padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size

-    assert 2 <= window_size <= len(
-        waveform), ('choose a window size {} that is [2, {}]'
-                    .format(window_size, len(waveform)))
-    assert 0 < window_shift, '`window_shift` must be greater than 0'
-    assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
-        ' use `round_to_power_of_two` or change `frame_length`'
-    assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
-    assert sample_frequency > 0, '`sample_frequency` must be greater than zero'
+    assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
+        window_size, len(waveform)
+    )
+    assert 0 < window_shift, "`window_shift` must be greater than 0"
+    assert padded_window_size % 2 == 0, (
+        "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
+    )
+    assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
+    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
     return waveform, window_shift, window_size, padded_window_size

-def _get_window(waveform: Tensor,
-                padded_window_size: int,
-                window_size: int,
-                window_shift: int,
-                window_type: str,
-                blackman_coeff: float,
-                snip_edges: bool,
-                raw_energy: bool,
-                energy_floor: float,
-                dither: float,
-                remove_dc_offset: bool,
-                preemphasis_coefficient: float) -> Tuple[Tensor, Tensor]:
+def _get_window(
+    waveform: Tensor,
+    padded_window_size: int,
+    window_size: int,
+    window_shift: int,
+    window_type: str,
+    blackman_coeff: float,
+    snip_edges: bool,
+    raw_energy: bool,
+    energy_floor: float,
+    dither: float,
+    remove_dc_offset: bool,
+    preemphasis_coefficient: float,
+) -> Tuple[Tensor, Tensor]:
     r"""Gets a window and its log energy

     Returns:
@@ -193,20 +194,23 @@ def _get_window(
     if preemphasis_coefficient != 0.0:
         # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
-        offset_strided_input = torch.nn.functional.pad(
-            strided_input.unsqueeze(0), (1, 0), mode='replicate').squeeze(0)  # size (m, window_size + 1)
+        offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
+            0
+        )  # size (m, window_size + 1)
         strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]

     # Apply window_function to each row/frame
-    window_function = _feature_window_function(
-        window_type, window_size, blackman_coeff, device, dtype).unsqueeze(0)  # size (1, window_size)
+    window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
+        0
+    )  # size (1, window_size)
     strided_input = strided_input * window_function  # size (m, window_size)

     # Pad columns with zero until we reach size (m, padded_window_size)
     if padded_window_size != window_size:
         padding_right = padded_window_size - window_size
-        strided_input = torch.nn.functional.pad(
-            strided_input.unsqueeze(0), (0, padding_right), mode='constant', value=0).squeeze(0)
+        strided_input = torch.nn.functional.pad(
+            strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
+        ).squeeze(0)

     # Compute energy after window function (not the raw one)
     if not raw_energy:
@@ -224,22 +228,24 @@ def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
     return tensor


-def spectrogram(waveform: Tensor,
-                blackman_coeff: float = 0.42,
-                channel: int = -1,
-                dither: float = 0.0,
-                energy_floor: float = 1.0,
-                frame_length: float = 25.0,
-                frame_shift: float = 10.0,
-                min_duration: float = 0.0,
-                preemphasis_coefficient: float = 0.97,
-                raw_energy: bool = True,
-                remove_dc_offset: bool = True,
-                round_to_power_of_two: bool = True,
-                sample_frequency: float = 16000.0,
-                snip_edges: bool = True,
-                subtract_mean: bool = False,
-                window_type: str = POVEY) -> Tensor:
+def spectrogram(
+    waveform: Tensor,
+    blackman_coeff: float = 0.42,
+    channel: int = -1,
+    dither: float = 0.0,
+    energy_floor: float = 1.0,
+    frame_length: float = 25.0,
+    frame_shift: float = 10.0,
+    min_duration: float = 0.0,
+    preemphasis_coefficient: float = 0.97,
+    raw_energy: bool = True,
+    remove_dc_offset: bool = True,
+    round_to_power_of_two: bool = True,
+    sample_frequency: float = 16000.0,
+    snip_edges: bool = True,
+    subtract_mean: bool = False,
+    window_type: str = POVEY,
+) -> Tensor:
     r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
     compute-spectrogram-feats.
@@ -278,21 +284,33 @@ def spectrogram(
     epsilon = _get_epsilon(device, dtype)

     waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
-        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)
+        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
+    )

     if len(waveform) < min_duration * sample_frequency:
         # signal is too short
         return torch.empty(0)

     strided_input, signal_log_energy = _get_window(
-        waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff,
-        snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
+        waveform,
+        padded_window_size,
+        window_size,
+        window_shift,
+        window_type,
+        blackman_coeff,
+        snip_edges,
+        raw_energy,
+        energy_floor,
+        dither,
+        remove_dc_offset,
+        preemphasis_coefficient,
+    )

     # size (m, padded_window_size // 2 + 1, 2)
     fft = torch.fft.rfft(strided_input)

     # Convert the FFT into a power spectrum
-    power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log()  # size (m, padded_window_size // 2 + 1)
+    power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log()  # size (m, padded_window_size // 2 + 1)
     power_spectrum[:, 0] = signal_log_energy

     power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
@@ -315,12 +333,14 @@ def mel_scale(freq: Tensor) -> Tensor:
     return 1127.0 * (1.0 + freq / 700.0).log()


-def vtln_warp_freq(vtln_low_cutoff: float,
-                   vtln_high_cutoff: float,
-                   low_freq: float,
-                   high_freq: float,
-                   vtln_warp_factor: float,
-                   freq: Tensor) -> Tensor:
+def vtln_warp_freq(
+    vtln_low_cutoff: float,
+    vtln_high_cutoff: float,
+    low_freq: float,
+    high_freq: float,
+    vtln_warp_factor: float,
+    freq: Tensor,
+) -> Tensor:
     r"""This computes a VTLN warping function that is not the same as HTK's one,
     but has similar inputs (this function has the advantage of never producing
     empty bins).
@@ -357,8 +377,8 @@ def vtln_warp_freq(
     Returns:
         Tensor: Freq after vtln warp
     """
-    assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq'
-    assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]'
+    assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
+    assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
     l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
     h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
     scale = 1.0 / vtln_warp_factor
@@ -388,11 +408,14 @@ def vtln_warp_freq(
     return res


-def vtln_warp_mel_freq(vtln_low_cutoff: float,
-                       vtln_high_cutoff: float,
-                       low_freq, high_freq: float,
-                       vtln_warp_factor: float,
-                       mel_freq: Tensor) -> Tensor:
+def vtln_warp_mel_freq(
+    vtln_low_cutoff: float,
+    vtln_high_cutoff: float,
+    low_freq,
+    high_freq: float,
+    vtln_warp_factor: float,
+    mel_freq: Tensor,
+) -> Tensor:
     r"""
     Args:
         vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
@@ -405,25 +428,30 @@ def vtln_warp_mel_freq(
     Returns:
         Tensor: ``mel_freq`` after vtln warp
     """
-    return mel_scale(vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq,
-                                    vtln_warp_factor, inverse_mel_scale(mel_freq)))
+    return mel_scale(
+        vtln_warp_freq(
+            vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
+        )
+    )


-def get_mel_banks(num_bins: int,
-                  window_length_padded: int,
-                  sample_freq: float,
-                  low_freq: float,
-                  high_freq: float,
-                  vtln_low: float,
-                  vtln_high: float,
-                  vtln_warp_factor: float) -> Tuple[Tensor, Tensor]:
+def get_mel_banks(
+    num_bins: int,
+    window_length_padded: int,
+    sample_freq: float,
+    low_freq: float,
+    high_freq: float,
+    vtln_low: float,
+    vtln_high: float,
+    vtln_warp_factor: float,
+) -> Tuple[Tensor, Tensor]:
     """
     Returns:
         (Tensor, Tensor): The tuple consists of ``bins`` (which is
        melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
        center frequencies of bins of size (``num_bins``)).
    """
-    assert num_bins > 3, 'Must have at least 3 mel bins'
+    assert num_bins > 3, "Must have at least 3 mel bins"
     assert window_length_padded % 2 == 0
     num_fft_bins = window_length_padded / 2
     nyquist = 0.5 * sample_freq
@@ -431,8 +459,9 @@ def get_mel_banks(
     if high_freq <= 0.0:
         high_freq += nyquist

-    assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
-        ('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
+    assert (
+        (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
+    ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)

     # fft-bin width [think of it as Nyquist-freq / half-window-length]
     fft_bin_width = sample_freq / window_length_padded
@@ -446,10 +475,11 @@ def get_mel_banks(
     if vtln_high < 0.0:
         vtln_high += nyquist

-    assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
-                                       (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
-        ('Bad values in options: vtln-low {} and vtln-high {}, versus '
-         'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
+    assert vtln_warp_factor == 1.0 or (
+        (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
+    ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
+        vtln_low, vtln_high, low_freq, high_freq
+    )

     bin = torch.arange(num_bins).unsqueeze(1)
     left_mel = mel_low_freq + bin * mel_freq_delta  # size(num_bins, 1)
@@ -483,32 +513,34 @@ def get_mel_banks(
     return bins, center_freqs


-def fbank(waveform: Tensor,
-          blackman_coeff: float = 0.42,
-          channel: int = -1,
-          dither: float = 0.0,
-          energy_floor: float = 1.0,
-          frame_length: float = 25.0,
-          frame_shift: float = 10.0,
-          high_freq: float = 0.0,
-          htk_compat: bool = False,
-          low_freq: float = 20.0,
-          min_duration: float = 0.0,
-          num_mel_bins: int = 23,
-          preemphasis_coefficient: float = 0.97,
-          raw_energy: bool = True,
-          remove_dc_offset: bool = True,
-          round_to_power_of_two: bool = True,
-          sample_frequency: float = 16000.0,
-          snip_edges: bool = True,
-          subtract_mean: bool = False,
-          use_energy: bool = False,
-          use_log_fbank: bool = True,
-          use_power: bool = True,
-          vtln_high: float = -500.0,
-          vtln_low: float = 100.0,
-          vtln_warp: float = 1.0,
-          window_type: str = POVEY) -> Tensor:
+def fbank(
+    waveform: Tensor,
+    blackman_coeff: float = 0.42,
+    channel: int = -1,
+    dither: float = 0.0,
+    energy_floor: float = 1.0,
+    frame_length: float = 25.0,
+    frame_shift: float = 10.0,
+    high_freq: float = 0.0,
+    htk_compat: bool = False,
+    low_freq: float = 20.0,
+    min_duration: float = 0.0,
+    num_mel_bins: int = 23,
+    preemphasis_coefficient: float = 0.97,
+    raw_energy: bool = True,
+    remove_dc_offset: bool = True,
+    round_to_power_of_two: bool = True,
+    sample_frequency: float = 16000.0,
+    snip_edges: bool = True,
+    subtract_mean: bool = False,
+    use_energy: bool = False,
+    use_log_fbank: bool = True,
+    use_power: bool = True,
+    vtln_high: float = -500.0,
+    vtln_low: float = 100.0,
+    vtln_warp: float = 1.0,
+    window_type: str = POVEY,
+) -> Tensor:
     r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
     compute-fbank-feats.
@@ -559,7 +591,8 @@ def fbank(
     device, dtype = waveform.device, waveform.dtype

     waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
-        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)
+        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
+    )

     if len(waveform) < min_duration * sample_frequency:
         # signal is too short
@@ -567,21 +600,33 @@ def fbank(
     # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
     strided_input, signal_log_energy = _get_window(
-        waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff,
-        snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
+        waveform,
+        padded_window_size,
+        window_size,
+        window_shift,
+        window_type,
+        blackman_coeff,
+        snip_edges,
+        raw_energy,
+        energy_floor,
+        dither,
+        remove_dc_offset,
+        preemphasis_coefficient,
+    )

     # size (m, padded_window_size // 2 + 1)
     spectrum = torch.fft.rfft(strided_input).abs()
     if use_power:
-        spectrum = spectrum.pow(2.)
+        spectrum = spectrum.pow(2.0)

     # size (num_mel_bins, padded_window_size // 2)
-    mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency,
-                                    low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
+    mel_energies, _ = get_mel_banks(
+        num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp
+    )
     mel_energies = mel_energies.to(device=device, dtype=dtype)

     # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
-    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0)
+    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)

     # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
     mel_energies = torch.mm(spectrum, mel_energies.T)
@@ -605,7 +650,7 @@ def fbank(
 def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
     # returns a dct matrix of size (num_mel_bins, num_ceps)
     # size (num_mel_bins, num_mel_bins)
-    dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, 'ortho')
+    dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
     # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
     # this would be the first column in the dct_matrix for torchaudio as it expects a
     # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
@@ -624,32 +669,33 @@ def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
 def mfcc(
-        waveform: Tensor,
-        blackman_coeff: float = 0.42,
-        cepstral_lifter: float = 22.0,
-        channel: int = -1,
-        dither: float = 0.0,
-        energy_floor: float = 1.0,
-        frame_length: float = 25.0,
-        frame_shift: float = 10.0,
-        high_freq: float = 0.0,
-        htk_compat: bool = False,
-        low_freq: float = 20.0,
-        num_ceps: int = 13,
-        min_duration: float = 0.0,
-        num_mel_bins: int = 23,
-        preemphasis_coefficient: float = 0.97,
-        raw_energy: bool = True,
-        remove_dc_offset: bool = True,
-        round_to_power_of_two: bool = True,
-        sample_frequency: float = 16000.0,
-        snip_edges: bool = True,
-        subtract_mean: bool = False,
-        use_energy: bool = False,
-        vtln_high: float = -500.0,
-        vtln_low: float = 100.0,
-        vtln_warp: float = 1.0,
-        window_type: str = POVEY) -> Tensor:
+    waveform: Tensor,
+    blackman_coeff: float = 0.42,
+    cepstral_lifter: float = 22.0,
+    channel: int = -1,
+    dither: float = 0.0,
+    energy_floor: float = 1.0,
+    frame_length: float = 25.0,
+    frame_shift: float = 10.0,
+    high_freq: float = 0.0,
+    htk_compat: bool = False,
+    low_freq: float = 20.0,
+    num_ceps: int = 13,
+    min_duration: float = 0.0,
+    num_mel_bins: int = 23,
+    preemphasis_coefficient: float = 0.97,
+    raw_energy: bool = True,
+    remove_dc_offset: bool = True,
+    round_to_power_of_two: bool = True,
+    sample_frequency: float = 16000.0,
+    snip_edges: bool = True,
+    subtract_mean: bool = False,
+    use_energy: bool = False,
+    vtln_high: float = -500.0,
+    vtln_low: float = 100.0,
+    vtln_warp: float = 1.0,
+    window_type: str = POVEY,
+) -> Tensor:
     r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
     compute-mfcc-feats.
@@ -697,29 +743,48 @@ def mfcc(
         Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
         where m is calculated in _get_strided
     """
-    assert num_ceps <= num_mel_bins, 'num_ceps cannot be larger than num_mel_bins: %d vs %d' % (num_ceps, num_mel_bins)
+    assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)

     device, dtype = waveform.device, waveform.dtype

     # The mel_energies should not be squared (use_power=True), not have mean subtracted
     # (subtract_mean=False), and use log (use_log_fbank=True).
     # size (m, num_mel_bins + use_energy)
-    feature = fbank(waveform=waveform, blackman_coeff=blackman_coeff, channel=channel,
-                    dither=dither, energy_floor=energy_floor, frame_length=frame_length,
-                    frame_shift=frame_shift, high_freq=high_freq, htk_compat=htk_compat,
-                    low_freq=low_freq, min_duration=min_duration, num_mel_bins=num_mel_bins,
-                    preemphasis_coefficient=preemphasis_coefficient, raw_energy=raw_energy,
-                    remove_dc_offset=remove_dc_offset, round_to_power_of_two=round_to_power_of_two,
-                    sample_frequency=sample_frequency, snip_edges=snip_edges, subtract_mean=False,
-                    use_energy=use_energy, use_log_fbank=True, use_power=True,
-                    vtln_high=vtln_high, vtln_low=vtln_low, vtln_warp=vtln_warp, window_type=window_type)
+    feature = fbank(
+        waveform=waveform,
+        blackman_coeff=blackman_coeff,
+        channel=channel,
+        dither=dither,
+        energy_floor=energy_floor,
+        frame_length=frame_length,
+        frame_shift=frame_shift,
+        high_freq=high_freq,
+        htk_compat=htk_compat,
+        low_freq=low_freq,
+        min_duration=min_duration,
+        num_mel_bins=num_mel_bins,
+        preemphasis_coefficient=preemphasis_coefficient,
+        raw_energy=raw_energy,
+        remove_dc_offset=remove_dc_offset,
+        round_to_power_of_two=round_to_power_of_two,
+        sample_frequency=sample_frequency,
+        snip_edges=snip_edges,
+        subtract_mean=False,
+        use_energy=use_energy,
+        use_log_fbank=True,
+        use_power=True,
+        vtln_high=vtln_high,
+        vtln_low=vtln_low,
+        vtln_warp=vtln_warp,
+        window_type=window_type,
+    )

     if use_energy:
         # size (m)
         signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
         # offset is 0 if htk_compat==True else 1
         mel_offset = int(not htk_compat)
-        feature = feature[:, mel_offset:(num_mel_bins + mel_offset)]
+        feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]

     # size (num_mel_bins, num_ceps)
     dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
...
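The kaldi changes above are pure reformatting of the Kaldi-compatible feature extractors; for orientation, a sketch of the public entry points on a synthetic signal (real input would come from `torchaudio.load`):

    import torch
    import torchaudio.compliance.kaldi as kaldi

    waveform = torch.rand(1, 16000) - 0.5  # (channel, time), one second at 16 kHz
    fbank_feats = kaldi.fbank(waveform, num_mel_bins=23, sample_frequency=16000.0)  # (m, 23)
    mfcc_feats = kaldi.mfcc(waveform, num_ceps=13, num_mel_bins=23, sample_frequency=16000.0)  # (m, 13)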
+from .cmuarctic import CMUARCTIC
+from .cmudict import CMUDict
 from .commonvoice import COMMONVOICE
-from .librispeech import LIBRISPEECH
-from .speechcommands import SPEECHCOMMANDS
-from .vctk import VCTK_092
 from .dr_vctk import DR_VCTK
 from .gtzan import GTZAN
-from .yesno import YESNO
-from .ljspeech import LJSPEECH
-from .cmuarctic import CMUARCTIC
-from .cmudict import CMUDict
 from .librimix import LibriMix
+from .librispeech import LIBRISPEECH
 from .libritts import LIBRITTS
+from .ljspeech import LJSPEECH
+from .speechcommands import SPEECHCOMMANDS
 from .tedlium import TEDLIUM
+from .vctk import VCTK_092
+from .yesno import YESNO

 __all__ = [
...
-import os
 import csv
+import os
 from pathlib import Path
 from typing import Tuple, Union

 import torchaudio
 from torch import Tensor
-from torch.utils.data import Dataset
 from torch.hub import download_url_to_file
+from torch.utils.data import Dataset

 from torchaudio.datasets.utils import (
     extract_archive,
 )
@@ -14,49 +14,28 @@ from torchaudio.datasets.utils import (
 URL = "aew"
 FOLDER_IN_ARCHIVE = "ARCTIC"
 _CHECKSUMS = {
-    "http://festvox.org/cmu_arctic/packed/cmu_us_aew_arctic.tar.bz2":
-    "645cb33c0f0b2ce41384fdd8d3db2c3f5fc15c1e688baeb74d2e08cab18ab406",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_ahw_arctic.tar.bz2":
-    "024664adeb892809d646a3efd043625b46b5bfa3e6189b3500b2d0d59dfab06c",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_aup_arctic.tar.bz2":
-    "2c55bc3050caa996758869126ad10cf42e1441212111db034b3a45189c18b6fc",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_awb_arctic.tar.bz2":
-    "d74a950c9739a65f7bfc4dfa6187f2730fa03de5b8eb3f2da97a51b74df64d3c",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_axb_arctic.tar.bz2":
-    "dd65c3d2907d1ee52f86e44f578319159e60f4bf722a9142be01161d84e330ff",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_bdl_arctic.tar.bz2":
-    "26b91aaf48b2799b2956792b4632c2f926cd0542f402b5452d5adecb60942904",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_clb_arctic.tar.bz2":
-    "3f16dc3f3b97955ea22623efb33b444341013fc660677b2e170efdcc959fa7c6",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_eey_arctic.tar.bz2":
-    "8a0ee4e5acbd4b2f61a4fb947c1730ab3adcc9dc50b195981d99391d29928e8a",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_fem_arctic.tar.bz2":
-    "3fcff629412b57233589cdb058f730594a62c4f3a75c20de14afe06621ef45e2",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_gka_arctic.tar.bz2":
-    "dc82e7967cbd5eddbed33074b0699128dbd4482b41711916d58103707e38c67f",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_jmk_arctic.tar.bz2":
-    "3a37c0e1dfc91e734fdbc88b562d9e2ebca621772402cdc693bbc9b09b211d73",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_ksp_arctic.tar.bz2":
-    "8029cafce8296f9bed3022c44ef1e7953332b6bf6943c14b929f468122532717",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_ljm_arctic.tar.bz2":
-    "b23993765cbf2b9e7bbc3c85b6c56eaf292ac81ee4bb887b638a24d104f921a0",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_lnh_arctic.tar.bz2":
-    "4faf34d71aa7112813252fb20c5433e2fdd9a9de55a00701ffcbf05f24a5991a",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_rms_arctic.tar.bz2":
-    "c6dc11235629c58441c071a7ba8a2d067903dfefbaabc4056d87da35b72ecda4",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_rxr_arctic.tar.bz2":
-    "1fa4271c393e5998d200e56c102ff46fcfea169aaa2148ad9e9469616fbfdd9b",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_slp_arctic.tar.bz2":
-    "54345ed55e45c23d419e9a823eef427f1cc93c83a710735ec667d068c916abf1",
-    "http://festvox.org/cmu_arctic/packed/cmu_us_slt_arctic.tar.bz2":
-    "7c173297916acf3cc7fcab2713be4c60b27312316765a90934651d367226b4ea",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_aew_arctic.tar.bz2": "645cb33c0f0b2ce41384fdd8d3db2c3f5fc15c1e688baeb74d2e08cab18ab406",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_ahw_arctic.tar.bz2": "024664adeb892809d646a3efd043625b46b5bfa3e6189b3500b2d0d59dfab06c",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_aup_arctic.tar.bz2": "2c55bc3050caa996758869126ad10cf42e1441212111db034b3a45189c18b6fc",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_awb_arctic.tar.bz2": "d74a950c9739a65f7bfc4dfa6187f2730fa03de5b8eb3f2da97a51b74df64d3c",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_axb_arctic.tar.bz2": "dd65c3d2907d1ee52f86e44f578319159e60f4bf722a9142be01161d84e330ff",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_bdl_arctic.tar.bz2": "26b91aaf48b2799b2956792b4632c2f926cd0542f402b5452d5adecb60942904",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_clb_arctic.tar.bz2": "3f16dc3f3b97955ea22623efb33b444341013fc660677b2e170efdcc959fa7c6",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_eey_arctic.tar.bz2": "8a0ee4e5acbd4b2f61a4fb947c1730ab3adcc9dc50b195981d99391d29928e8a",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_fem_arctic.tar.bz2": "3fcff629412b57233589cdb058f730594a62c4f3a75c20de14afe06621ef45e2",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_gka_arctic.tar.bz2": "dc82e7967cbd5eddbed33074b0699128dbd4482b41711916d58103707e38c67f",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_jmk_arctic.tar.bz2": "3a37c0e1dfc91e734fdbc88b562d9e2ebca621772402cdc693bbc9b09b211d73",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_ksp_arctic.tar.bz2": "8029cafce8296f9bed3022c44ef1e7953332b6bf6943c14b929f468122532717",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_ljm_arctic.tar.bz2": "b23993765cbf2b9e7bbc3c85b6c56eaf292ac81ee4bb887b638a24d104f921a0",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_lnh_arctic.tar.bz2": "4faf34d71aa7112813252fb20c5433e2fdd9a9de55a00701ffcbf05f24a5991a",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_rms_arctic.tar.bz2": "c6dc11235629c58441c071a7ba8a2d067903dfefbaabc4056d87da35b72ecda4",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_rxr_arctic.tar.bz2": "1fa4271c393e5998d200e56c102ff46fcfea169aaa2148ad9e9469616fbfdd9b",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_slp_arctic.tar.bz2": "54345ed55e45c23d419e9a823eef427f1cc93c83a710735ec667d068c916abf1",
+    "http://festvox.org/cmu_arctic/packed/cmu_us_slt_arctic.tar.bz2": "7c173297916acf3cc7fcab2713be4c60b27312316765a90934651d367226b4ea",
 }


-def load_cmuarctic_item(line: str,
-                        path: str,
-                        folder_audio: str,
-                        ext_audio: str) -> Tuple[Tensor, int, str, str]:
+def load_cmuarctic_item(line: str, path: str, folder_audio: str, ext_audio: str) -> Tuple[Tensor, int, str, str]:

     utterance_id, transcript = line[0].strip().split(" ", 2)[1:]
@@ -68,12 +47,7 @@ def load_cmuarctic_item(
     # Load audio
     waveform, sample_rate = torchaudio.load(file_audio)

-    return (
-        waveform,
-        sample_rate,
-        transcript,
-        utterance_id.split("_")[1]
-    )
+    return (waveform, sample_rate, transcript, utterance_id.split("_")[1])


 class CMUARCTIC(Dataset):
@@ -98,11 +72,9 @@ class CMUARCTIC(Dataset):
     _ext_audio = ".wav"
     _folder_audio = "wav"

-    def __init__(self,
-                 root: Union[str, Path],
-                 url: str = URL,
-                 folder_in_archive: str = FOLDER_IN_ARCHIVE,
-                 download: bool = False) -> None:
+    def __init__(
+        self, root: Union[str, Path], url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False
+    ) -> None:

         if url in [
             "aew",
@@ -122,7 +94,7 @@ class CMUARCTIC(Dataset):
             "rms",
             "rxr",
             "slp",
-            "slt"
+            "slt",
         ]:

             url = "cmu_us_" + url + "_arctic"
...
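The CMUARCTIC constructor and item layout are both visible in the hunks above; a minimal usage sketch (the root directory is hypothetical, and `download=True` fetches the selected subset on first use):

    from torchaudio.datasets import CMUARCTIC

    dataset = CMUARCTIC("./data", url="aew", download=True)
    waveform, sample_rate, transcript, utterance_id = dataset[0]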
@@ -3,83 +3,83 @@ import re
 from pathlib import Path
 from typing import Iterable, Tuple, Union, List

-from torch.utils.data import Dataset
 from torch.hub import download_url_to_file
+from torch.utils.data import Dataset

 _CHECKSUMS = {
-    "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b":
-    "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4",
-    "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols":
-    "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027",
+    "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4",
+    "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027",
 }

-_PUNCTUATIONS = set([
-    "!EXCLAMATION-POINT",
-    "\"CLOSE-QUOTE",
-    "\"DOUBLE-QUOTE",
-    "\"END-OF-QUOTE",
-    "\"END-QUOTE",
-    "\"IN-QUOTES",
-    "\"QUOTE",
-    "\"UNQUOTE",
-    "#HASH-MARK",
-    "#POUND-SIGN",
-    "#SHARP-SIGN",
-    "%PERCENT",
-    "&AMPERSAND",
-    "'END-INNER-QUOTE",
-    "'END-QUOTE",
-    "'INNER-QUOTE",
-    "'QUOTE",
-    "'SINGLE-QUOTE",
-    "(BEGIN-PARENS",
-    "(IN-PARENTHESES",
-    "(LEFT-PAREN",
-    "(OPEN-PARENTHESES",
-    "(PAREN",
-    "(PARENS",
-    "(PARENTHESES",
-    ")CLOSE-PAREN",
-    ")CLOSE-PARENTHESES",
-    ")END-PAREN",
-    ")END-PARENS",
-    ")END-PARENTHESES",
-    ")END-THE-PAREN",
-    ")PAREN",
-    ")PARENS",
-    ")RIGHT-PAREN",
-    ")UN-PARENTHESES",
-    "+PLUS",
-    ",COMMA",
-    "--DASH",
-    "-DASH",
-    "-HYPHEN",
-    "...ELLIPSIS",
-    ".DECIMAL",
-    ".DOT",
-    ".FULL-STOP",
-    ".PERIOD",
-    ".POINT",
-    "/SLASH",
-    ":COLON",
-    ";SEMI-COLON",
-    ";SEMI-COLON(1)",
-    "?QUESTION-MARK",
-    "{BRACE",
-    "{LEFT-BRACE",
-    "{OPEN-BRACE",
-    "}CLOSE-BRACE",
-    "}RIGHT-BRACE",
-])
+_PUNCTUATIONS = set(
+    [
+        "!EXCLAMATION-POINT",
+        '"CLOSE-QUOTE',
+        '"DOUBLE-QUOTE',
+        '"END-OF-QUOTE',
+        '"END-QUOTE',
+        '"IN-QUOTES',
+        '"QUOTE',
+        '"UNQUOTE',
+        "#HASH-MARK",
+        "#POUND-SIGN",
+        "#SHARP-SIGN",
+        "%PERCENT",
+        "&AMPERSAND",
+        "'END-INNER-QUOTE",
+        "'END-QUOTE",
+        "'INNER-QUOTE",
+        "'QUOTE",
+        "'SINGLE-QUOTE",
+        "(BEGIN-PARENS",
+        "(IN-PARENTHESES",
+        "(LEFT-PAREN",
+        "(OPEN-PARENTHESES",
+        "(PAREN",
+        "(PARENS",
+        "(PARENTHESES",
+        ")CLOSE-PAREN",
+        ")CLOSE-PARENTHESES",
+        ")END-PAREN",
+        ")END-PARENS",
+        ")END-PARENTHESES",
+        ")END-THE-PAREN",
+        ")PAREN",
+        ")PARENS",
+        ")RIGHT-PAREN",
+        ")UN-PARENTHESES",
+        "+PLUS",
+        ",COMMA",
+        "--DASH",
+        "-DASH",
+        "-HYPHEN",
+        "...ELLIPSIS",
+        ".DECIMAL",
+        ".DOT",
+        ".FULL-STOP",
+        ".PERIOD",
+        ".POINT",
+        "/SLASH",
+        ":COLON",
+        ";SEMI-COLON",
+        ";SEMI-COLON(1)",
+        "?QUESTION-MARK",
+        "{BRACE",
+        "{LEFT-BRACE",
+        "{OPEN-BRACE",
+        "}CLOSE-BRACE",
+        "}RIGHT-BRACE",
+    ]
+)
def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[str]:
    _alt_re = re.compile(r"\([0-9]+\)")
    cmudict: List[Tuple[str, List[str]]] = list()
    for line in lines:
        if not line or line.startswith(";;;"):  # ignore comments
            continue
        word, phones = line.strip().split("  ")  # the word and its phones are separated by two spaces
        if word in _PUNCTUATIONS:
            if exclude_punctuations:
                continue
@@ -96,7 +96,7 @@ def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[
        # if a word has multiple pronunciations, a number in parentheses is appended to it;
        # for example, DATAPOINTS and DATAPOINTS(1).
        # the regular expression `_alt_re` removes the "(1)" and changes DATAPOINTS(1) to DATAPOINTS
        word = re.sub(_alt_re, "", word)
        phones = phones.split(" ")
        cmudict.append((word, phones))
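A quick illustration of what `_alt_re` does (a standalone sketch, not part of the diff):

import re

_alt_re = re.compile(r"\([0-9]+\)")
print(re.sub(_alt_re, "", "DATAPOINTS(1)"))  # DATAPOINTS
print(re.sub(_alt_re, "", "DATAPOINTS"))  # DATAPOINTS (unchanged)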
@@ -121,44 +121,46 @@ class CMUDict(Dataset):
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        exclude_punctuations: bool = True,
        *,
        download: bool = False,
        url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
        url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
    ) -> None:

        self.exclude_punctuations = exclude_punctuations

        self._root_path = Path(root)
        if not os.path.isdir(self._root_path):
            raise RuntimeError(f"The root directory does not exist; {root}")

        dict_file = self._root_path / os.path.basename(url)
        symbol_file = self._root_path / os.path.basename(url_symbols)
        if not os.path.exists(dict_file):
            if not download:
                raise RuntimeError(
                    "The dictionary file is not found in the following location. "
                    f"Set `download=True` to download it. {dict_file}"
                )
            checksum = _CHECKSUMS.get(url, None)
            download_url_to_file(url, dict_file, checksum)
        if not os.path.exists(symbol_file):
            if not download:
                raise RuntimeError(
                    "The symbol file is not found in the following location. "
                    f"Set `download=True` to download it. {symbol_file}"
                )
            checksum = _CHECKSUMS.get(url_symbols, None)
            download_url_to_file(url_symbols, symbol_file, checksum)

        with open(symbol_file, "r") as text:
            self._symbols = [line.strip() for line in text.readlines()]

        with open(dict_file, "r", encoding="latin-1") as text:
            self._dictionary = _parse_dictionary(text.readlines(), exclude_punctuations=self.exclude_punctuations)

    def __getitem__(self, n: int) -> Tuple[str, List[str]]:
        """Load the n-th sample from the dataset.

@@ -177,6 +179,5 @@ class CMUDict(Dataset):
    @property
    def symbols(self) -> List[str]:
        """list[str]: A list of phoneme symbols, such as `AA`, `AE`, `AH`."""
        return self._symbols.copy()
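For reference, a minimal usage sketch (assumes network access and an existing ./data directory, since the constructor requires the root to exist):

from torchaudio.datasets import CMUDict

cmudict = CMUDict("./data", download=True)
word, phones = cmudict[0]  # a (word, list-of-phonemes) pair
print(word, phones)
print(cmudict.symbols[:5])  # the phoneme symbol inventory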
@@ -3,17 +3,14 @@ import os
from pathlib import Path
from typing import List, Dict, Tuple, Union

from torch import Tensor
from torch.utils.data import Dataset

import torchaudio


def load_commonvoice_item(
    line: List[str], header: List[str], path: str, folder_audio: str, ext_audio: str
) -> Tuple[Tensor, int, Dict[str, str]]:
    # Each line has the following data:
    # client_id, path, sentence, up_votes, down_votes, age, gender, accent

@@ -45,9 +42,7 @@ class COMMONVOICE(Dataset):
    _ext_audio = ".mp3"
    _folder_audio = "clips"

    def __init__(self, root: Union[str, Path], tsv: str = "train.tsv") -> None:

        # Get string representation of 'root' in case Path object is passed
        self._path = os.fspath(root)
...
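COMMONVOICE takes no download flag, so the corpus must already be on disk; a minimal sketch (the path below is a hypothetical extracted release):

from torchaudio.datasets import COMMONVOICE

dataset = COMMONVOICE("./cv-corpus/en", tsv="train.tsv")
# return layout follows load_commonvoice_item: (waveform, sample_rate, metadata dict)
waveform, sample_rate, metadata = dataset[0]
print(metadata["sentence"])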
from pathlib import Path
from typing import Dict, Tuple, Union

from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset

import torchaudio
from torchaudio.datasets.utils import (
    extract_archive,
)
...
@@ -4,8 +4,8 @@ from typing import Tuple, Optional, Union
import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset

from torchaudio.datasets.utils import (
    extract_archive,
)

@@ -1039,8 +1039,7 @@ class GTZAN(Dataset):
        self.subset = subset

        assert subset is None or subset in ["training", "validation", "testing"], (
            "When `subset` not None, it must take a value from " + "{'training', 'validation', 'testing'}."
        )

        archive = os.path.basename(url)
@@ -1055,9 +1054,7 @@ class GTZAN(Dataset):
                extract_archive(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

        if self.subset is None:
            # Check every subdirectory under dataset root
...
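A minimal GTZAN usage sketch (the (waveform, sample_rate, genre) layout of each sample is not shown in these hunks and is assumed from the upstream class):

from torchaudio.datasets import GTZAN

dataset = GTZAN("./data", download=True, subset="training")
waveform, sample_rate, genre = dataset[0]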
@@ -2,9 +2,8 @@ from pathlib import Path
from typing import Union, Tuple, List

import torch
import torchaudio
from torch.utils.data import Dataset

SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]

@@ -30,6 +29,7 @@ class LibriMix(Dataset):
    Note:
        The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
    """

    def __init__(
        self,
        root: Union[str, Path],
@@ -44,9 +44,7 @@ class LibriMix(Dataset):
        elif sample_rate == 16000:
            self.root = self.root / "wav16k/min" / subset
        else:
            raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")
        self.sample_rate = sample_rate
        self.task = task
        self.mix_dir = (self.root / f"mix_{task.split('_')[1]}").resolve()
@@ -70,9 +68,7 @@ class LibriMix(Dataset):
        for i, dir_ in enumerate(self.src_dirs):
            src = self._load_audio(str(dir_ / filename))
            if mixed.shape != src.shape:
                raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
            srcs.append(src)
        return self.sample_rate, mixed, srcs
...
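Since LibriMix cannot be downloaded automatically, a sketch assuming ./LibriMix was generated with the upstream scripts (argument names other than root are assumptions based on the hunks above):

from torchaudio.datasets import LibriMix

dataset = LibriMix("./LibriMix", sample_rate=8000, task="sep_clean")
# SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
sample_rate, mixture, sources = dataset[0]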
import os
from pathlib import Path
from typing import Tuple, Union

import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset

from torchaudio.datasets.utils import (
    extract_archive,
)

@@ -14,27 +13,19 @@ from torchaudio.datasets.utils import (
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
_CHECKSUMS = {
    "http://www.openslr.org/resources/12/dev-clean.tar.gz": "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3",
    "http://www.openslr.org/resources/12/dev-other.tar.gz": "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365",
    "http://www.openslr.org/resources/12/test-clean.tar.gz": "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23",
    "http://www.openslr.org/resources/12/test-other.tar.gz": "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29",
    "http://www.openslr.org/resources/12/train-clean-100.tar.gz": "d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2",
    "http://www.openslr.org/resources/12/train-clean-360.tar.gz": "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf",
    "http://www.openslr.org/resources/12/train-other-500.tar.gz": "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2",
}


def load_librispeech_item(
    fileid: str, path: str, ext_audio: str, ext_txt: str
) -> Tuple[Tensor, int, str, int, int, int]:
    speaker_id, chapter_id, utterance_id = fileid.split("-")
    file_text = speaker_id + "-" + chapter_id + ext_txt
@@ -86,11 +77,9 @@ class LIBRISPEECH(Dataset):
    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(
        self, root: Union[str, Path], url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False
    ) -> None:

        if url in [
            "dev-clean",
@@ -125,7 +114,7 @@ class LIBRISPEECH(Dataset):
                    download_url_to_file(url, archive, hash_prefix=checksum)
                extract_archive(archive)

        self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
        """Load the n-th sample from the dataset.
...
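A minimal LIBRISPEECH usage sketch (assumes network access; the six return fields follow load_librispeech_item above):

from torch.utils.data import DataLoader
from torchaudio.datasets import LIBRISPEECH

dataset = LIBRISPEECH("./data", url="dev-clean", download=True)
waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id = dataset[0]
# batch_size=1 avoids collating variable-length waveforms
loader = DataLoader(dataset, batch_size=1, shuffle=True)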
import os
from pathlib import Path
from typing import Tuple, Union

import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset

from torchaudio.datasets.utils import (
    extract_archive,
)

@@ -13,20 +13,13 @@ from torchaudio.datasets.utils import (
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriTTS"
_CHECKSUMS = {
    "http://www.openslr.org/resources/60/dev-clean.tar.gz": "da0864e1bd26debed35da8a869dd5c04dfc27682921936de7cff9c8a254dbe1a",
    "http://www.openslr.org/resources/60/dev-other.tar.gz": "d413eda26f3a152ac7c9cf3658ef85504dfb1b625296e5fa83727f5186cca79c",
    "http://www.openslr.org/resources/60/test-clean.tar.gz": "234ea5b25859102a87024a4b9b86641f5b5aaaf1197335c95090cde04fe9a4f5",
    "http://www.openslr.org/resources/60/test-other.tar.gz": "33a5342094f3bba7ccc2e0500b9e72d558f72eb99328ac8debe1d9080402f10d",
    "http://www.openslr.org/resources/60/train-clean-100.tar.gz": "c5608bf1ef74bb621935382b8399c5cdd51cd3ee47cec51f00f885a64c6c7f6b",
    "http://www.openslr.org/resources/60/train-clean-360.tar.gz": "ce7cff44dcac46009d18379f37ef36551123a1dc4e5c8e4eb73ae57260de4886",
    "http://www.openslr.org/resources/60/train-other-500.tar.gz": "e35f7e34deeb2e2bdfe4403d88c8fdd5fbf64865cae41f027a185a6965f0a5df",
}

@@ -132,7 +125,7 @@ class LIBRITTS(Dataset):
                    download_url_to_file(url, archive, hash_prefix=checksum)
                extract_archive(archive)

        self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]:
        """Load the n-th sample from the dataset.
...
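The checksum tables feed torch.hub.download_url_to_file, which verifies that the SHA256 of the downloaded file starts with hash_prefix; a standalone sketch using the dev-clean entry above:

from torch.hub import download_url_to_file

url = "http://www.openslr.org/resources/60/dev-clean.tar.gz"
download_url_to_file(
    url,
    "dev-clean.tar.gz",
    hash_prefix="da0864e1bd26debed35da8a869dd5c04dfc27682921936de7cff9c8a254dbe1a",
)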
import csv
import os
from pathlib import Path
from typing import Tuple, Union

import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset

from torchaudio.datasets.utils import extract_archive

_RELEASE_CONFIGS = {
@@ -32,11 +32,13 @@ class LJSPEECH(Dataset):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    def __init__(
        self,
        root: Union[str, Path],
        url: str = _RELEASE_CONFIGS["release1"]["url"],
        folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
        download: bool = False,
    ) -> None:

        self._parse_filesystem(root, url, folder_in_archive, download)

@@ -50,7 +52,7 @@ class LJSPEECH(Dataset):
        folder_in_archive = basename / folder_in_archive

        self._path = root / folder_in_archive
        self._metadata_path = root / basename / "metadata.csv"

        if download:
            if not os.path.isdir(self._path):
@@ -59,7 +61,7 @@ class LJSPEECH(Dataset):
                download_url_to_file(url, archive, hash_prefix=checksum)
                extract_archive(archive)

        with open(self._metadata_path, "r", newline="") as metadata:
            flist = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
            self._flist = list(flist)
...
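For reference, the metadata parsing above in isolation (a sketch; the three pipe-separated columns are the LJSpeech release layout, not shown in these hunks):

import csv

with open("metadata.csv", "r", newline="") as metadata:
    flist = list(csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE))
fileid, transcript, normalized_transcript = flist[0]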
import os
from pathlib import Path
from typing import Tuple, Optional, Union

import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset

from torchaudio.datasets.utils import (
    extract_archive,
)

@@ -16,10 +15,8 @@ URL = "speech_commands_v0.02"
HASH_DIVIDER = "_nohash_"
EXCEPT_FOLDER = "_background_noise_"
_CHECKSUMS = {
    "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz": "743935421bb51cccdb6bdd152e04c5c70274e935c82119ad7faeec31780d811d",
    "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz": "af14739ee7dc311471de98f5f9d2c9191b18aedfe957f4a6ff791c709868ff58",
}

@@ -75,17 +72,17 @@ class SPEECHCOMMANDS(Dataset):
            original paper can be found `here <https://arxiv.org/pdf/1804.03209.pdf>`_. (Default: ``None``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
        subset: Optional[str] = None,
    ) -> None:

        assert subset is None or subset in ["training", "validation", "testing"], (
            "When `subset` not None, it must take a value from " + "{'training', 'validation', 'testing'}."
        )

        if url in [
@@ -121,15 +118,14 @@ class SPEECHCOMMANDS(Dataset):
            self._walker = _load_list(self._path, "testing_list.txt")
        elif subset == "training":
            excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt"))
            walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav"))
            self._walker = [
                w
                for w in walker
                if HASH_DIVIDER in w and EXCEPT_FOLDER not in w and os.path.normpath(w) not in excludes
            ]
        else:
            walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav"))
            self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
...
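A minimal SPEECHCOMMANDS usage sketch (the field names of the returned tuple are assumptions; only its types are shown above):

from torchaudio.datasets import SPEECHCOMMANDS

train_set = SPEECHCOMMANDS("./data", download=True, subset="training")
test_set = SPEECHCOMMANDS("./data", download=True, subset="testing")
waveform, sample_rate, label, speaker_id, utterance_number = train_set[0]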
import os
from pathlib import Path
from typing import Tuple, Union

import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset

from torchaudio.datasets.utils import (
    extract_archive,
)

@@ -58,13 +57,14 @@ class TEDLIUM(Dataset):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        audio_ext (str, optional): extension for audio file (default: ``".sph"``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        release: str = "release1",
        subset: str = None,
        download: bool = False,
        audio_ext: str = ".sph",
    ) -> None:
        self._ext_audio = audio_ext
        if release in _RELEASE_CONFIGS.keys():
@@ -75,14 +75,16 @@ class TEDLIUM(Dataset):
            # Raise error for unsupported release
            raise RuntimeError(
                "The release {} does not match any of the supported tedlium releases{} ".format(
                    release,
                    _RELEASE_CONFIGS.keys(),
                )
            )
        if subset not in _RELEASE_CONFIGS[release]["supported_subsets"]:
            # Raise error for unsupported subset
            raise RuntimeError(
                "The subset {} does not match any of the supported tedlium subsets{} ".format(
                    subset,
                    _RELEASE_CONFIGS[release]["supported_subsets"],
                )
            )
...
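A minimal TEDLIUM usage sketch (the subset name "train" is an assumption; valid values come from _RELEASE_CONFIGS[release]["supported_subsets"]):

from torchaudio.datasets import TEDLIUM

dataset = TEDLIUM("./data", release="release1", subset="train", download=True)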