Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`
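(For checkouts without the Facebook tooling, a roughly equivalent pass, stated here as an assumption rather than as part of this PR, is `usort format .` followed by `black --line-length 120 .`; the changes below are consistent with black-style quote and line-length normalization plus import sorting.)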

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
# flake8: noqa
from . import utils
from .utils import (
list_audio_backends,
......
@@ -26,13 +26,14 @@ class AudioMetaData:
* ``HTK``: Single channel 16-bit PCM
* ``UNKNOWN``: None of the above
"""
def __init__(
self,
sample_rate: int,
num_frames: int,
num_channels: int,
bits_per_sample: int,
encoding: str,
self,
sample_rate: int,
num_frames: int,
num_channels: int,
bits_per_sample: int,
encoding: str,
):
self.sample_rate = sample_rate
self.num_frames = num_frames
......
@@ -4,19 +4,21 @@ from typing import Callable, Optional, Tuple, Union
from torch import Tensor
def load(filepath: Union[str, Path],
out: Optional[Tensor] = None,
normalization: Union[bool, float, Callable] = True,
channels_first: bool = True,
num_frames: int = 0,
offset: int = 0,
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
raise RuntimeError('No audio I/O backend is available.')
def load(
filepath: Union[str, Path],
out: Optional[Tensor] = None,
normalization: Union[bool, float, Callable] = True,
channels_first: bool = True,
num_frames: int = 0,
offset: int = 0,
filetype: Optional[str] = None,
) -> Tuple[Tensor, int]:
raise RuntimeError("No audio I/O backend is available.")
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
raise RuntimeError('No audio I/O backend is available.')
raise RuntimeError("No audio I/O backend is available.")
def info(filepath: str) -> None:
raise RuntimeError('No audio I/O backend is available.')
raise RuntimeError("No audio I/O backend is available.")
"""The new soundfile backend which will become default in 0.8.0 onward"""
from typing import Tuple, Optional
import warnings
from typing import Tuple, Optional
import torch
from torchaudio._internal import module_utils as _mod_utils
from .common import AudioMetaData
@@ -19,33 +20,33 @@ if _mod_utils.is_soundfile_available():
# The dict is inspired from
# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
_SUBTYPE_TO_BITS_PER_SAMPLE = {
'PCM_S8': 8, # Signed 8 bit data
'PCM_16': 16, # Signed 16 bit data
'PCM_24': 24, # Signed 24 bit data
'PCM_32': 32, # Signed 32 bit data
'PCM_U8': 8, # Unsigned 8 bit data (WAV and RAW only)
'FLOAT': 32, # 32 bit float data
'DOUBLE': 64, # 64 bit float data
'ULAW': 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
'ALAW': 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
'IMA_ADPCM': 0, # IMA ADPCM.
'MS_ADPCM': 0, # Microsoft ADPCM.
'GSM610': 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
'VOX_ADPCM': 0, # OKI / Dialogix ADPCM
'G721_32': 0, # 32kbs G721 ADPCM encoding.
'G723_24': 0, # 24kbs G723 ADPCM encoding.
'G723_40': 0, # 40kbs G723 ADPCM encoding.
'DWVW_12': 12, # 12 bit Delta Width Variable Word encoding.
'DWVW_16': 16, # 16 bit Delta Width Variable Word encoding.
'DWVW_24': 24, # 24 bit Delta Width Variable Word encoding.
'DWVW_N': 0, # N bit Delta Width Variable Word encoding.
'DPCM_8': 8, # 8 bit differential PCM (XI only)
'DPCM_16': 16, # 16 bit differential PCM (XI only)
'VORBIS': 0, # Xiph Vorbis encoding. (lossy)
'ALAC_16': 16, # Apple Lossless Audio Codec (16 bit).
'ALAC_20': 20, # Apple Lossless Audio Codec (20 bit).
'ALAC_24': 24, # Apple Lossless Audio Codec (24 bit).
'ALAC_32': 32, # Apple Lossless Audio Codec (32 bit).
"PCM_S8": 8, # Signed 8 bit data
"PCM_16": 16, # Signed 16 bit data
"PCM_24": 24, # Signed 24 bit data
"PCM_32": 32, # Signed 32 bit data
"PCM_U8": 8, # Unsigned 8 bit data (WAV and RAW only)
"FLOAT": 32, # 32 bit float data
"DOUBLE": 64, # 64 bit float data
"ULAW": 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
"ALAW": 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
"IMA_ADPCM": 0, # IMA ADPCM.
"MS_ADPCM": 0, # Microsoft ADPCM.
"GSM610": 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
"VOX_ADPCM": 0, # OKI / Dialogix ADPCM
"G721_32": 0, # 32kbs G721 ADPCM encoding.
"G723_24": 0, # 24kbs G723 ADPCM encoding.
"G723_40": 0, # 40kbs G723 ADPCM encoding.
"DWVW_12": 12, # 12 bit Delta Width Variable Word encoding.
"DWVW_16": 16, # 16 bit Delta Width Variable Word encoding.
"DWVW_24": 24, # 24 bit Delta Width Variable Word encoding.
"DWVW_N": 0, # N bit Delta Width Variable Word encoding.
"DPCM_8": 8, # 8 bit differential PCM (XI only)
"DPCM_16": 16, # 16 bit differential PCM (XI only)
"VORBIS": 0, # Xiph Vorbis encoding. (lossy)
"ALAC_16": 16, # Apple Lossless Audio Codec (16 bit).
"ALAC_20": 20, # Apple Lossless Audio Codec (20 bit).
"ALAC_24": 24, # Apple Lossless Audio Codec (24 bit).
"ALAC_32": 32, # Apple Lossless Audio Codec (32 bit).
}
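As an illustration (not part of this diff), subtypes with a fixed width resolve directly, while entries mapped to 0 signal that no fixed bit depth applies:

_SUBTYPE_TO_BITS_PER_SAMPLE.get("PCM_24", 0)  # -> 24
_SUBTYPE_TO_BITS_PER_SAMPLE.get("VORBIS", 0)  # -> 0: lossy codec, no fixed bit depth
_SUBTYPE_TO_BITS_PER_SAMPLE.get("OPUS", 0)    # -> 0: subtype absent from the table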
@@ -61,23 +62,23 @@ def _get_bit_depth(subtype):
_SUBTYPE_TO_ENCODING = {
'PCM_S8': 'PCM_S',
'PCM_16': 'PCM_S',
'PCM_24': 'PCM_S',
'PCM_32': 'PCM_S',
'PCM_U8': 'PCM_U',
'FLOAT': 'PCM_F',
'DOUBLE': 'PCM_F',
'ULAW': 'ULAW',
'ALAW': 'ALAW',
'VORBIS': 'VORBIS',
"PCM_S8": "PCM_S",
"PCM_16": "PCM_S",
"PCM_24": "PCM_S",
"PCM_32": "PCM_S",
"PCM_U8": "PCM_U",
"FLOAT": "PCM_F",
"DOUBLE": "PCM_F",
"ULAW": "ULAW",
"ALAW": "ALAW",
"VORBIS": "VORBIS",
}
def _get_encoding(format: str, subtype: str):
if format == 'FLAC':
return 'FLAC'
return _SUBTYPE_TO_ENCODING.get(subtype, 'UNKNOWN')
if format == "FLAC":
return "FLAC"
return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")
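A quick sketch of the dispatch above (illustrative values, not part of this diff):

_get_encoding("FLAC", "PCM_16")  # -> "FLAC": the container determines the encoding
_get_encoding("WAV", "PCM_U8")   # -> "PCM_U"
_get_encoding("WAV", "GSM610")   # -> "UNKNOWN": no entry in _SUBTYPE_TO_ENCODING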
@_mod_utils.requires_soundfile()
@@ -211,10 +212,7 @@ def load(
return waveform, sample_rate
def _get_subtype_for_wav(
dtype: torch.dtype,
encoding: str,
bits_per_sample: int):
def _get_subtype_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int):
if not encoding:
if not bits_per_sample:
subtype = {
@@ -271,11 +269,7 @@ def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
raise ValueError(f"sph does not support {encoding}.")
def _get_subtype(
dtype: torch.dtype,
format: str,
encoding: str,
bits_per_sample: int):
def _get_subtype(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int):
if format == "wav":
return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
if format == "flac":
@@ -288,8 +282,7 @@ def _get_subtype(
return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
if format in ("ogg", "vorbis"):
if encoding or bits_per_sample:
raise ValueError(
"ogg/vorbis does not support encoding/bits_per_sample.")
raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
return "VORBIS"
if format == "sph":
return _get_subtype_for_sphere(encoding, bits_per_sample)
@@ -407,9 +400,9 @@ def save(
'`save` function of "soundfile" backend does not support "compression" parameter. '
"The argument is silently ignored."
)
if hasattr(filepath, 'write'):
if hasattr(filepath, "write"):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
raise RuntimeError("`format` is required when saving to file object.")
ext = format.lower()
else:
ext = str(filepath).split(".")[-1].lower()
@@ -417,8 +410,10 @@
if bits_per_sample not in (None, 8, 16, 24, 32, 64):
raise ValueError("Invalid bits_per_sample.")
if bits_per_sample == 24:
warnings.warn("Saving audio with 24 bits per sample might warp samples near -1. "
"Using 16 bits per sample might be able to avoid this.")
warnings.warn(
"Saving audio with 24 bits per sample might warp samples near -1. "
"Using 16 bits per sample might be able to avoid this."
)
subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
# sph is an extension used in TED-LIUM, but soundfile does not recognize it as NIST format,
@@ -429,6 +424,4 @@ def save(
if channels_first:
src = src.t()
soundfile.write(
file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format
)
soundfile.write(file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format)
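For reference, a minimal round trip through this backend; a sketch assuming torchaudio with soundfile installed, with an illustrative file name:

import torch
import torchaudio

torchaudio.set_audio_backend("soundfile")
waveform = 0.1 * torch.randn(2, 16000)           # one second of stereo noise
torchaudio.save("example.wav", waveform, 16000)  # dispatches to the save() above
metadata = torchaudio.info("example.wav")        # AudioMetaData with sample_rate=16000
loaded, sample_rate = torchaudio.load("example.wav")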
@@ -2,18 +2,18 @@ import os
from typing import Tuple, Optional
import torch
import torchaudio
from torchaudio._internal import (
module_utils as _mod_utils,
)
import torchaudio
from .common import AudioMetaData
@_mod_utils.requires_sox()
def info(
filepath: str,
format: Optional[str] = None,
filepath: str,
format: Optional[str] = None,
) -> AudioMetaData:
"""Get signal information of an audio file.
@@ -46,7 +46,7 @@ def info(
AudioMetaData: Metadata of the given audio.
"""
if not torch.jit.is_scripting():
if hasattr(filepath, 'read'):
if hasattr(filepath, "read"):
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo)
filepath = os.fspath(filepath)
@@ -56,12 +56,12 @@
@_mod_utils.requires_sox()
def load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
"""Load audio data from file.
@@ -145,24 +145,26 @@ def load(
`[channel, time]` else `[time, channel]`.
"""
if not torch.jit.is_scripting():
if hasattr(filepath, 'read'):
if hasattr(filepath, "read"):
return torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath, frame_offset, num_frames, normalize, channels_first, format
)
filepath = os.fspath(filepath)
return torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath, frame_offset, num_frames, normalize, channels_first, format
)
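A usage sketch of the offset arguments (illustrative file and values):

# Read one second starting half a second in, assuming a 16 kHz file:
waveform, sample_rate = torchaudio.load("example.wav", frame_offset=8000, num_frames=16000)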
@_mod_utils.requires_sox()
def save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
"""Save audio data to file.
@@ -309,11 +311,12 @@ def save(
or ``libmp3lame`` etc.
"""
if not torch.jit.is_scripting():
if hasattr(filepath, 'write'):
if hasattr(filepath, "write"):
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression,
format, encoding, bits_per_sample)
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
return
filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
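A sketch of the compression argument; the FLAC level range 0-8 is an assumption from sox conventions, not stated in this hunk:

torchaudio.save("example.flac", waveform, 16000, compression=8)  # highest FLAC compression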
@@ -4,6 +4,7 @@ from typing import Optional, List
import torchaudio
from torchaudio._internal import module_utils as _mod_utils
from . import (
no_backend,
sox_io_backend,
@@ -11,9 +12,9 @@ from . import (
)
__all__ = [
'list_audio_backends',
'get_audio_backend',
'set_audio_backend',
"list_audio_backends",
"get_audio_backend",
"set_audio_backend",
]
@@ -24,10 +25,10 @@ def list_audio_backends() -> List[str]:
List[str]: The list of available backends.
"""
backends = []
if _mod_utils.is_module_available('soundfile'):
backends.append('soundfile')
if _mod_utils.is_module_available("soundfile"):
backends.append("soundfile")
if _mod_utils.is_sox_available():
backends.append('sox_io')
backends.append("sox_io")
return backends
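For example, on a system with both optional dependencies installed:

import torchaudio

torchaudio.list_audio_backends()  # -> ["soundfile", "sox_io"]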
@@ -40,31 +41,29 @@ def set_audio_backend(backend: Optional[str]):
of the system. If ``None`` is provided the current backend is unassigned.
"""
if backend is not None and backend not in list_audio_backends():
raise RuntimeError(
f'Backend "{backend}" is not one of '
f'available backends: {list_audio_backends()}.')
raise RuntimeError(f'Backend "{backend}" is not one of ' f"available backends: {list_audio_backends()}.")
if backend is None:
module = no_backend
elif backend == 'sox_io':
elif backend == "sox_io":
module = sox_io_backend
elif backend == 'soundfile':
elif backend == "soundfile":
module = soundfile_backend
else:
raise NotImplementedError(f'Unexpected backend "{backend}"')
for func in ['save', 'load', 'info']:
for func in ["save", "load", "info"]:
setattr(torchaudio, func, getattr(module, func))
def _init_audio_backend():
backends = list_audio_backends()
if 'sox_io' in backends:
set_audio_backend('sox_io')
elif 'soundfile' in backends:
set_audio_backend('soundfile')
if "sox_io" in backends:
set_audio_backend("sox_io")
elif "soundfile" in backends:
set_audio_backend("soundfile")
else:
warnings.warn('No audio backend is available.')
warnings.warn("No audio backend is available.")
set_audio_backend(None)
@@ -77,7 +76,7 @@ def get_audio_backend() -> Optional[str]:
if torchaudio.load == no_backend.load:
return None
if torchaudio.load == sox_io_backend.load:
return 'sox_io'
return "sox_io"
if torchaudio.load == soundfile_backend.load:
return 'soundfile'
raise ValueError('Unknown backend.')
return "soundfile"
raise ValueError("Unknown backend.")
from . import kaldi
__all__ = [
'kaldi',
"kaldi",
]
import math
from typing import Tuple
import math
import torch
from torch import Tensor
import torchaudio
from torch import Tensor
__all__ = [
'get_mel_banks',
'inverse_mel_scale',
'inverse_mel_scale_scalar',
'mel_scale',
'mel_scale_scalar',
'spectrogram',
'fbank',
'mfcc',
'vtln_warp_freq',
'vtln_warp_mel_freq',
"get_mel_banks",
"inverse_mel_scale",
"inverse_mel_scale_scalar",
"mel_scale",
"mel_scale_scalar",
"spectrogram",
"fbank",
"mfcc",
"vtln_warp_freq",
"vtln_warp_mel_freq",
]
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
@@ -25,11 +24,11 @@ EPSILON = torch.tensor(torch.finfo(torch.float).eps)
MILLISECONDS_TO_SECONDS = 0.001
# window types
HAMMING = 'hamming'
HANNING = 'hanning'
POVEY = 'povey'
RECTANGULAR = 'rectangular'
BLACKMAN = 'blackman'
HAMMING = "hamming"
HANNING = "hanning"
POVEY = "povey"
RECTANGULAR = "rectangular"
BLACKMAN = "blackman"
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
@@ -38,8 +37,7 @@ def _get_epsilon(device, dtype):
def _next_power_of_2(x: int) -> int:
r"""Returns the smallest power of 2 that is greater than x
"""
r"""Returns the smallest power of 2 that is greater than x"""
return 1 if x == 0 else 2 ** (x - 1).bit_length()
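Illustrative values (note the function returns x itself when x is already a power of two):

_next_power_of_2(0)    # -> 1
_next_power_of_2(4)    # -> 4
_next_power_of_2(400)  # -> 512: a 25 ms window at 16 kHz pads to 512 samples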
@@ -85,14 +83,14 @@ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges
return waveform.as_strided(sizes, strides)
def _feature_window_function(window_type: str,
window_size: int,
blackman_coeff: float,
device: torch.device,
dtype: int,
) -> Tensor:
r"""Returns a window function with the given type and size
"""
def _feature_window_function(
window_type: str,
window_size: int,
blackman_coeff: float,
device: torch.device,
dtype: int,
) -> Tensor:
r"""Returns a window function with the given type and size"""
if window_type == HANNING:
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
elif window_type == HAMMING:
@@ -106,64 +104,67 @@ def _feature_window_function(window_type: str,
a = 2 * math.pi / (window_size - 1)
window_function = torch.arange(window_size, device=device, dtype=dtype)
# can't use torch.blackman_window as they use different coefficients
return (blackman_coeff - 0.5 * torch.cos(a * window_function) +
(0.5 - blackman_coeff) * torch.cos(2 * a * window_function)).to(device=device, dtype=dtype)
return (
blackman_coeff
- 0.5 * torch.cos(a * window_function)
+ (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
).to(device=device, dtype=dtype)
else:
raise Exception('Invalid window type ' + window_type)
raise Exception("Invalid window type " + window_type)
def _get_log_energy(strided_input: Tensor,
epsilon: Tensor,
energy_floor: float) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)
"""
def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
device, dtype = strided_input.device, strided_input.dtype
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
if energy_floor == 0.0:
return log_energy
return torch.max(
log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
def _get_waveform_and_window_properties(waveform: Tensor,
channel: int,
sample_frequency: float,
frame_shift: float,
frame_length: float,
round_to_power_of_two: bool,
preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]:
r"""Gets the waveform and window properties
"""
return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
def _get_waveform_and_window_properties(
waveform: Tensor,
channel: int,
sample_frequency: float,
frame_shift: float,
frame_length: float,
round_to_power_of_two: bool,
preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
r"""Gets the waveform and window properties"""
channel = max(channel, 0)
assert channel < waveform.size(0), ('Invalid channel {} for size {}'.format(channel, waveform.size(0)))
assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
waveform = waveform[channel, :] # size (n)
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
assert 2 <= window_size <= len(
waveform), ('choose a window size {} that is [2, {}]'
.format(window_size, len(waveform)))
assert 0 < window_shift, '`window_shift` must be greater than 0'
assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
' use `round_to_power_of_two` or change `frame_length`'
assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
assert sample_frequency > 0, '`sample_frequency` must be greater than zero'
assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
window_size, len(waveform)
)
assert 0 < window_shift, "`window_shift` must be greater than 0"
assert padded_window_size % 2 == 0, (
"the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
)
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
return waveform, window_shift, window_size, padded_window_size
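A worked instance of the arithmetic above, assuming the common 16 kHz, 25 ms frame, 10 ms shift configuration:

# window_shift       = int(16000 * 10.0 * 0.001)  = 160 samples
# window_size        = int(16000 * 25.0 * 0.001)  = 400 samples
# padded_window_size = _next_power_of_2(400)      = 512 (round_to_power_of_two=True)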
def _get_window(waveform: Tensor,
padded_window_size: int,
window_size: int,
window_shift: int,
window_type: str,
blackman_coeff: float,
snip_edges: bool,
raw_energy: bool,
energy_floor: float,
dither: float,
remove_dc_offset: bool,
preemphasis_coefficient: float) -> Tuple[Tensor, Tensor]:
def _get_window(
waveform: Tensor,
padded_window_size: int,
window_size: int,
window_shift: int,
window_type: str,
blackman_coeff: float,
snip_edges: bool,
raw_energy: bool,
energy_floor: float,
dither: float,
remove_dc_offset: bool,
preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
r"""Gets a window and its log energy
Returns:
@@ -193,20 +194,23 @@ def _get_window(waveform: Tensor,
if preemphasis_coefficient != 0.0:
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
offset_strided_input = torch.nn.functional.pad(
strided_input.unsqueeze(0), (1, 0), mode='replicate').squeeze(0) # size (m, window_size + 1)
offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
0
) # size (m, window_size + 1)
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
# Apply window_function to each row/frame
window_function = _feature_window_function(
window_type, window_size, blackman_coeff, device, dtype).unsqueeze(0) # size (1, window_size)
window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
0
) # size (1, window_size)
strided_input = strided_input * window_function # size (m, window_size)
# Pad columns with zero until we reach size (m, padded_window_size)
if padded_window_size != window_size:
padding_right = padded_window_size - window_size
strided_input = torch.nn.functional.pad(
strided_input.unsqueeze(0), (0, padding_right), mode='constant', value=0).squeeze(0)
strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
).squeeze(0)
# Compute energy after window function (not the raw one)
if not raw_energy:
@@ -224,22 +228,24 @@ def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
return tensor
def spectrogram(waveform: Tensor,
blackman_coeff: float = 0.42,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
min_duration: float = 0.0,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
window_type: str = POVEY) -> Tensor:
def spectrogram(
waveform: Tensor,
blackman_coeff: float = 0.42,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
min_duration: float = 0.0,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
window_type: str = POVEY,
) -> Tensor:
r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
compute-spectrogram-feats.
@@ -278,21 +284,33 @@ def spectrogram(waveform: Tensor,
epsilon = _get_epsilon(device, dtype)
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return torch.empty(0)
strided_input, signal_log_energy = _get_window(
waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff,
snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
waveform,
padded_window_size,
window_size,
window_shift,
window_type,
blackman_coeff,
snip_edges,
raw_energy,
energy_floor,
dither,
remove_dc_offset,
preemphasis_coefficient,
)
# size (m, padded_window_size // 2 + 1, 2)
fft = torch.fft.rfft(strided_input)
# Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log() # size (m, padded_window_size // 2 + 1)
power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
@@ -315,12 +333,14 @@ def mel_scale(freq: Tensor) -> Tensor:
return 1127.0 * (1.0 + freq / 700.0).log()
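A spot check of the formula (illustrative):

mel_scale(torch.tensor(700.0))                      # -> ~781.17, i.e. 1127 * ln(2)
inverse_mel_scale(mel_scale(torch.tensor(1000.0)))  # -> ~1000.0: the two functions invert each other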
def vtln_warp_freq(vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq: float,
high_freq: float,
vtln_warp_factor: float,
freq: Tensor) -> Tensor:
def vtln_warp_freq(
vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq: float,
high_freq: float,
vtln_warp_factor: float,
freq: Tensor,
) -> Tensor:
r"""This computes a VTLN warping function that is not the same as HTK's one,
but has similar inputs (this function has the advantage of never producing
empty bins).
@@ -357,8 +377,8 @@ def vtln_warp_freq(vtln_low_cutoff: float,
Returns:
Tensor: Freq after vtln warp
"""
assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq'
assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]'
assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
scale = 1.0 / vtln_warp_factor
@@ -388,11 +408,14 @@ def vtln_warp_freq(vtln_low_cutoff: float,
return res
def vtln_warp_mel_freq(vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq, high_freq: float,
vtln_warp_factor: float,
mel_freq: Tensor) -> Tensor:
def vtln_warp_mel_freq(
vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq,
high_freq: float,
vtln_warp_factor: float,
mel_freq: Tensor,
) -> Tensor:
r"""
Args:
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
@@ -405,25 +428,30 @@ def vtln_warp_mel_freq(vtln_low_cutoff: float,
Returns:
Tensor: ``mel_freq`` after vtln warp
"""
return mel_scale(vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq,
vtln_warp_factor, inverse_mel_scale(mel_freq)))
def get_mel_banks(num_bins: int,
window_length_padded: int,
sample_freq: float,
low_freq: float,
high_freq: float,
vtln_low: float,
vtln_high: float,
vtln_warp_factor: float) -> Tuple[Tensor, Tensor]:
return mel_scale(
vtln_warp_freq(
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
)
)
def get_mel_banks(
num_bins: int,
window_length_padded: int,
sample_freq: float,
low_freq: float,
high_freq: float,
vtln_low: float,
vtln_high: float,
vtln_warp_factor: float,
) -> Tuple[Tensor, Tensor]:
"""
Returns:
(Tensor, Tensor): The tuple consists of ``bins`` (which is
melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
center frequencies of bins of size (``num_bins``)).
"""
assert num_bins > 3, 'Must have at least 3 mel bins'
assert num_bins > 3, "Must have at least 3 mel bins"
assert window_length_padded % 2 == 0
num_fft_bins = window_length_padded / 2
nyquist = 0.5 * sample_freq
@@ -431,8 +459,9 @@ def get_mel_banks(num_bins: int,
if high_freq <= 0.0:
high_freq += nyquist
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
assert (
(0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded
@@ -446,10 +475,11 @@ def get_mel_banks(num_bins: int,
if vtln_high < 0.0:
vtln_high += nyquist
assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
(0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
('Bad values in options: vtln-low {} and vtln-high {}, versus '
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
assert vtln_warp_factor == 1.0 or (
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
vtln_low, vtln_high, low_freq, high_freq
)
bin = torch.arange(num_bins).unsqueeze(1)
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
@@ -483,32 +513,34 @@ def get_mel_banks(num_bins: int,
return bins, center_freqs
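A shape sketch using the defaults that fbank below passes in (assumed values, not part of this diff):

bins, center_freqs = get_mel_banks(23, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0)
# bins: (23, 256) triangular mel filters; center_freqs: (23,) filter centers on the mel scale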
def fbank(waveform: Tensor,
blackman_coeff: float = 0.42,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
min_duration: float = 0.0,
num_mel_bins: int = 23,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
use_log_fbank: bool = True,
use_power: bool = True,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY) -> Tensor:
def fbank(
waveform: Tensor,
blackman_coeff: float = 0.42,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
min_duration: float = 0.0,
num_mel_bins: int = 23,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
use_log_fbank: bool = True,
use_power: bool = True,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY,
) -> Tensor:
r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
compute-fbank-feats.
@@ -559,7 +591,8 @@ def fbank(waveform: Tensor,
device, dtype = waveform.device, waveform.dtype
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
@@ -567,21 +600,33 @@ def fbank(waveform: Tensor,
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
strided_input, signal_log_energy = _get_window(
waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff,
snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
waveform,
padded_window_size,
window_size,
window_shift,
window_type,
blackman_coeff,
snip_edges,
raw_energy,
energy_floor,
dither,
remove_dc_offset,
preemphasis_coefficient,
)
# size (m, padded_window_size // 2 + 1)
spectrum = torch.fft.rfft(strided_input).abs()
if use_power:
spectrum = spectrum.pow(2.)
spectrum = spectrum.pow(2.0)
# size (num_mel_bins, padded_window_size // 2)
mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency,
low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
mel_energies, _ = get_mel_banks(
num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp
)
mel_energies = mel_energies.to(device=device, dtype=dtype)
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
# sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
mel_energies = torch.mm(spectrum, mel_energies.T)
@@ -605,7 +650,7 @@ def fbank(waveform: Tensor,
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
# returns a dct matrix of size (num_mel_bins, num_ceps)
# size (num_mel_bins, num_mel_bins)
dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, 'ortho')
dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
# kaldi expects the first cepstral coefficient to be a weighted sum with factor sqrt(1/num_mel_bins)
# this would be the first column in the dct_matrix for torchaudio as it expects a
# right multiply (which would be the first column of kaldi's dct_matrix as kaldi
@@ -624,32 +669,33 @@ def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
def mfcc(
waveform: Tensor,
blackman_coeff: float = 0.42,
cepstral_lifter: float = 22.0,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
num_ceps: int = 13,
min_duration: float = 0.0,
num_mel_bins: int = 23,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY) -> Tensor:
waveform: Tensor,
blackman_coeff: float = 0.42,
cepstral_lifter: float = 22.0,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
num_ceps: int = 13,
min_duration: float = 0.0,
num_mel_bins: int = 23,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY,
) -> Tensor:
r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
compute-mfcc-feats.
@@ -697,29 +743,48 @@ def mfcc(
Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
where m is calculated in _get_strided
"""
assert num_ceps <= num_mel_bins, 'num_ceps cannot be larger than num_mel_bins: %d vs %d' % (num_ceps, num_mel_bins)
assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
device, dtype = waveform.device, waveform.dtype
# The mel_energies should not be squared (use_power=True), not have mean subtracted
# (subtract_mean=False), and use log (use_log_fbank=True).
# size (m, num_mel_bins + use_energy)
feature = fbank(waveform=waveform, blackman_coeff=blackman_coeff, channel=channel,
dither=dither, energy_floor=energy_floor, frame_length=frame_length,
frame_shift=frame_shift, high_freq=high_freq, htk_compat=htk_compat,
low_freq=low_freq, min_duration=min_duration, num_mel_bins=num_mel_bins,
preemphasis_coefficient=preemphasis_coefficient, raw_energy=raw_energy,
remove_dc_offset=remove_dc_offset, round_to_power_of_two=round_to_power_of_two,
sample_frequency=sample_frequency, snip_edges=snip_edges, subtract_mean=False,
use_energy=use_energy, use_log_fbank=True, use_power=True,
vtln_high=vtln_high, vtln_low=vtln_low, vtln_warp=vtln_warp, window_type=window_type)
feature = fbank(
waveform=waveform,
blackman_coeff=blackman_coeff,
channel=channel,
dither=dither,
energy_floor=energy_floor,
frame_length=frame_length,
frame_shift=frame_shift,
high_freq=high_freq,
htk_compat=htk_compat,
low_freq=low_freq,
min_duration=min_duration,
num_mel_bins=num_mel_bins,
preemphasis_coefficient=preemphasis_coefficient,
raw_energy=raw_energy,
remove_dc_offset=remove_dc_offset,
round_to_power_of_two=round_to_power_of_two,
sample_frequency=sample_frequency,
snip_edges=snip_edges,
subtract_mean=False,
use_energy=use_energy,
use_log_fbank=True,
use_power=True,
vtln_high=vtln_high,
vtln_low=vtln_low,
vtln_warp=vtln_warp,
window_type=window_type,
)
if use_energy:
# size (m)
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
# offset is 0 if htk_compat==True else 1
mel_offset = int(not htk_compat)
feature = feature[:, mel_offset:(num_mel_bins + mel_offset)]
feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
# size (num_mel_bins, num_ceps)
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
......
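For reference, a hedged sketch of the three Kaldi-compatible entry points in this module (illustrative file name):

import torchaudio
import torchaudio.compliance.kaldi as kaldi

waveform, sr = torchaudio.load("example.wav")
spec = kaldi.spectrogram(waveform, sample_frequency=sr)           # (m, padded_window_size // 2 + 1)
fb = kaldi.fbank(waveform, num_mel_bins=23, sample_frequency=sr)  # (m, 23)
ceps = kaldi.mfcc(waveform, num_ceps=13, sample_frequency=sr)     # (m, 13)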
from .cmuarctic import CMUARCTIC
from .cmudict import CMUDict
from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH
from .speechcommands import SPEECHCOMMANDS
from .vctk import VCTK_092
from .dr_vctk import DR_VCTK
from .gtzan import GTZAN
from .yesno import YESNO
from .ljspeech import LJSPEECH
from .cmuarctic import CMUARCTIC
from .cmudict import CMUDict
from .librimix import LibriMix
from .librispeech import LIBRISPEECH
from .libritts import LIBRITTS
from .ljspeech import LJSPEECH
from .speechcommands import SPEECHCOMMANDS
from .tedlium import TEDLIUM
from .vctk import VCTK_092
from .yesno import YESNO
__all__ = [
......
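A minimal dataset usage sketch (YESNO chosen for its small size; not part of this diff):

from torchaudio.datasets import YESNO

dataset = YESNO("./data", download=True)
waveform, sample_rate, labels = dataset[0]  # labels: list of 0/1 flags, one per spoken word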
import os
import csv
import os
from pathlib import Path
from typing import Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
extract_archive,
)
@@ -14,49 +14,28 @@ from torchaudio.datasets.utils import (
URL = "aew"
FOLDER_IN_ARCHIVE = "ARCTIC"
_CHECKSUMS = {
"http://festvox.org/cmu_arctic/packed/cmu_us_aew_arctic.tar.bz2":
"645cb33c0f0b2ce41384fdd8d3db2c3f5fc15c1e688baeb74d2e08cab18ab406",
"http://festvox.org/cmu_arctic/packed/cmu_us_ahw_arctic.tar.bz2":
"024664adeb892809d646a3efd043625b46b5bfa3e6189b3500b2d0d59dfab06c",
"http://festvox.org/cmu_arctic/packed/cmu_us_aup_arctic.tar.bz2":
"2c55bc3050caa996758869126ad10cf42e1441212111db034b3a45189c18b6fc",
"http://festvox.org/cmu_arctic/packed/cmu_us_awb_arctic.tar.bz2":
"d74a950c9739a65f7bfc4dfa6187f2730fa03de5b8eb3f2da97a51b74df64d3c",
"http://festvox.org/cmu_arctic/packed/cmu_us_axb_arctic.tar.bz2":
"dd65c3d2907d1ee52f86e44f578319159e60f4bf722a9142be01161d84e330ff",
"http://festvox.org/cmu_arctic/packed/cmu_us_bdl_arctic.tar.bz2":
"26b91aaf48b2799b2956792b4632c2f926cd0542f402b5452d5adecb60942904",
"http://festvox.org/cmu_arctic/packed/cmu_us_clb_arctic.tar.bz2":
"3f16dc3f3b97955ea22623efb33b444341013fc660677b2e170efdcc959fa7c6",
"http://festvox.org/cmu_arctic/packed/cmu_us_eey_arctic.tar.bz2":
"8a0ee4e5acbd4b2f61a4fb947c1730ab3adcc9dc50b195981d99391d29928e8a",
"http://festvox.org/cmu_arctic/packed/cmu_us_fem_arctic.tar.bz2":
"3fcff629412b57233589cdb058f730594a62c4f3a75c20de14afe06621ef45e2",
"http://festvox.org/cmu_arctic/packed/cmu_us_gka_arctic.tar.bz2":
"dc82e7967cbd5eddbed33074b0699128dbd4482b41711916d58103707e38c67f",
"http://festvox.org/cmu_arctic/packed/cmu_us_jmk_arctic.tar.bz2":
"3a37c0e1dfc91e734fdbc88b562d9e2ebca621772402cdc693bbc9b09b211d73",
"http://festvox.org/cmu_arctic/packed/cmu_us_ksp_arctic.tar.bz2":
"8029cafce8296f9bed3022c44ef1e7953332b6bf6943c14b929f468122532717",
"http://festvox.org/cmu_arctic/packed/cmu_us_ljm_arctic.tar.bz2":
"b23993765cbf2b9e7bbc3c85b6c56eaf292ac81ee4bb887b638a24d104f921a0",
"http://festvox.org/cmu_arctic/packed/cmu_us_lnh_arctic.tar.bz2":
"4faf34d71aa7112813252fb20c5433e2fdd9a9de55a00701ffcbf05f24a5991a",
"http://festvox.org/cmu_arctic/packed/cmu_us_rms_arctic.tar.bz2":
"c6dc11235629c58441c071a7ba8a2d067903dfefbaabc4056d87da35b72ecda4",
"http://festvox.org/cmu_arctic/packed/cmu_us_rxr_arctic.tar.bz2":
"1fa4271c393e5998d200e56c102ff46fcfea169aaa2148ad9e9469616fbfdd9b",
"http://festvox.org/cmu_arctic/packed/cmu_us_slp_arctic.tar.bz2":
"54345ed55e45c23d419e9a823eef427f1cc93c83a710735ec667d068c916abf1",
"http://festvox.org/cmu_arctic/packed/cmu_us_slt_arctic.tar.bz2":
"7c173297916acf3cc7fcab2713be4c60b27312316765a90934651d367226b4ea",
"http://festvox.org/cmu_arctic/packed/cmu_us_aew_arctic.tar.bz2": "645cb33c0f0b2ce41384fdd8d3db2c3f5fc15c1e688baeb74d2e08cab18ab406",
"http://festvox.org/cmu_arctic/packed/cmu_us_ahw_arctic.tar.bz2": "024664adeb892809d646a3efd043625b46b5bfa3e6189b3500b2d0d59dfab06c",
"http://festvox.org/cmu_arctic/packed/cmu_us_aup_arctic.tar.bz2": "2c55bc3050caa996758869126ad10cf42e1441212111db034b3a45189c18b6fc",
"http://festvox.org/cmu_arctic/packed/cmu_us_awb_arctic.tar.bz2": "d74a950c9739a65f7bfc4dfa6187f2730fa03de5b8eb3f2da97a51b74df64d3c",
"http://festvox.org/cmu_arctic/packed/cmu_us_axb_arctic.tar.bz2": "dd65c3d2907d1ee52f86e44f578319159e60f4bf722a9142be01161d84e330ff",
"http://festvox.org/cmu_arctic/packed/cmu_us_bdl_arctic.tar.bz2": "26b91aaf48b2799b2956792b4632c2f926cd0542f402b5452d5adecb60942904",
"http://festvox.org/cmu_arctic/packed/cmu_us_clb_arctic.tar.bz2": "3f16dc3f3b97955ea22623efb33b444341013fc660677b2e170efdcc959fa7c6",
"http://festvox.org/cmu_arctic/packed/cmu_us_eey_arctic.tar.bz2": "8a0ee4e5acbd4b2f61a4fb947c1730ab3adcc9dc50b195981d99391d29928e8a",
"http://festvox.org/cmu_arctic/packed/cmu_us_fem_arctic.tar.bz2": "3fcff629412b57233589cdb058f730594a62c4f3a75c20de14afe06621ef45e2",
"http://festvox.org/cmu_arctic/packed/cmu_us_gka_arctic.tar.bz2": "dc82e7967cbd5eddbed33074b0699128dbd4482b41711916d58103707e38c67f",
"http://festvox.org/cmu_arctic/packed/cmu_us_jmk_arctic.tar.bz2": "3a37c0e1dfc91e734fdbc88b562d9e2ebca621772402cdc693bbc9b09b211d73",
"http://festvox.org/cmu_arctic/packed/cmu_us_ksp_arctic.tar.bz2": "8029cafce8296f9bed3022c44ef1e7953332b6bf6943c14b929f468122532717",
"http://festvox.org/cmu_arctic/packed/cmu_us_ljm_arctic.tar.bz2": "b23993765cbf2b9e7bbc3c85b6c56eaf292ac81ee4bb887b638a24d104f921a0",
"http://festvox.org/cmu_arctic/packed/cmu_us_lnh_arctic.tar.bz2": "4faf34d71aa7112813252fb20c5433e2fdd9a9de55a00701ffcbf05f24a5991a",
"http://festvox.org/cmu_arctic/packed/cmu_us_rms_arctic.tar.bz2": "c6dc11235629c58441c071a7ba8a2d067903dfefbaabc4056d87da35b72ecda4",
"http://festvox.org/cmu_arctic/packed/cmu_us_rxr_arctic.tar.bz2": "1fa4271c393e5998d200e56c102ff46fcfea169aaa2148ad9e9469616fbfdd9b",
"http://festvox.org/cmu_arctic/packed/cmu_us_slp_arctic.tar.bz2": "54345ed55e45c23d419e9a823eef427f1cc93c83a710735ec667d068c916abf1",
"http://festvox.org/cmu_arctic/packed/cmu_us_slt_arctic.tar.bz2": "7c173297916acf3cc7fcab2713be4c60b27312316765a90934651d367226b4ea",
}
def load_cmuarctic_item(line: str,
path: str,
folder_audio: str,
ext_audio: str) -> Tuple[Tensor, int, str, str]:
def load_cmuarctic_item(line: str, path: str, folder_audio: str, ext_audio: str) -> Tuple[Tensor, int, str, str]:
utterance_id, transcript = line[0].strip().split(" ", 2)[1:]
@@ -68,12 +47,7 @@ def load_cmuarctic_item(line: str,
# Load audio
waveform, sample_rate = torchaudio.load(file_audio)
return (
waveform,
sample_rate,
transcript,
utterance_id.split("_")[1]
)
return (waveform, sample_rate, transcript, utterance_id.split("_")[1])
class CMUARCTIC(Dataset):
@@ -98,11 +72,9 @@ class CMUARCTIC(Dataset):
_ext_audio = ".wav"
_folder_audio = "wav"
def __init__(self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None:
def __init__(
self, root: Union[str, Path], url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False
) -> None:
if url in [
"aew",
@@ -122,7 +94,7 @@ class CMUARCTIC(Dataset):
"rms",
"rxr",
"slp",
"slt"
"slt",
]:
url = "cmu_us_" + url + "_arctic"
......
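A usage sketch for the class above; "slt" is one of the speaker IDs listed in __init__:

dataset = CMUARCTIC("./data", url="slt", download=True)
waveform, sample_rate, transcript, utterance_id = dataset[0]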
@@ -3,83 +3,83 @@ import re
from pathlib import Path
from typing import Iterable, Tuple, Union, List
from torch.utils.data import Dataset
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
_CHECKSUMS = {
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b":
"209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4",
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols":
"408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027",
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4",
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027",
}
_PUNCTUATIONS = set([
"!EXCLAMATION-POINT",
"\"CLOSE-QUOTE",
"\"DOUBLE-QUOTE",
"\"END-OF-QUOTE",
"\"END-QUOTE",
"\"IN-QUOTES",
"\"QUOTE",
"\"UNQUOTE",
"#HASH-MARK",
"#POUND-SIGN",
"#SHARP-SIGN",
"%PERCENT",
"&AMPERSAND",
"'END-INNER-QUOTE",
"'END-QUOTE",
"'INNER-QUOTE",
"'QUOTE",
"'SINGLE-QUOTE",
"(BEGIN-PARENS",
"(IN-PARENTHESES",
"(LEFT-PAREN",
"(OPEN-PARENTHESES",
"(PAREN",
"(PARENS",
"(PARENTHESES",
")CLOSE-PAREN",
")CLOSE-PARENTHESES",
")END-PAREN",
")END-PARENS",
")END-PARENTHESES",
")END-THE-PAREN",
")PAREN",
")PARENS",
")RIGHT-PAREN",
")UN-PARENTHESES",
"+PLUS",
",COMMA",
"--DASH",
"-DASH",
"-HYPHEN",
"...ELLIPSIS",
".DECIMAL",
".DOT",
".FULL-STOP",
".PERIOD",
".POINT",
"/SLASH",
":COLON",
";SEMI-COLON",
";SEMI-COLON(1)",
"?QUESTION-MARK",
"{BRACE",
"{LEFT-BRACE",
"{OPEN-BRACE",
"}CLOSE-BRACE",
"}RIGHT-BRACE",
])
_PUNCTUATIONS = set(
[
"!EXCLAMATION-POINT",
'"CLOSE-QUOTE',
'"DOUBLE-QUOTE',
'"END-OF-QUOTE',
'"END-QUOTE',
'"IN-QUOTES',
'"QUOTE',
'"UNQUOTE',
"#HASH-MARK",
"#POUND-SIGN",
"#SHARP-SIGN",
"%PERCENT",
"&AMPERSAND",
"'END-INNER-QUOTE",
"'END-QUOTE",
"'INNER-QUOTE",
"'QUOTE",
"'SINGLE-QUOTE",
"(BEGIN-PARENS",
"(IN-PARENTHESES",
"(LEFT-PAREN",
"(OPEN-PARENTHESES",
"(PAREN",
"(PARENS",
"(PARENTHESES",
")CLOSE-PAREN",
")CLOSE-PARENTHESES",
")END-PAREN",
")END-PARENS",
")END-PARENTHESES",
")END-THE-PAREN",
")PAREN",
")PARENS",
")RIGHT-PAREN",
")UN-PARENTHESES",
"+PLUS",
",COMMA",
"--DASH",
"-DASH",
"-HYPHEN",
"...ELLIPSIS",
".DECIMAL",
".DOT",
".FULL-STOP",
".PERIOD",
".POINT",
"/SLASH",
":COLON",
";SEMI-COLON",
";SEMI-COLON(1)",
"?QUESTION-MARK",
"{BRACE",
"{LEFT-BRACE",
"{OPEN-BRACE",
"}CLOSE-BRACE",
"}RIGHT-BRACE",
]
)
def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[str]:
_alt_re = re.compile(r'\([0-9]+\)')
_alt_re = re.compile(r"\([0-9]+\)")
cmudict: List[Tuple[str, List[str]]] = list()
for line in lines:
if not line or line.startswith(';;;'): # ignore comments
if not line or line.startswith(";;;"): # ignore comments
continue
word, phones = line.strip().split(' ')
word, phones = line.strip().split(" ")
if word in _PUNCTUATIONS:
if exclude_punctuations:
continue
@@ -96,7 +96,7 @@ def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[
# if a word has multiple pronunciations, a (number) is appended to it
# for example, DATAPOINTS and DATAPOINTS(1),
# the regular expression `_alt_re` removes the '(1)' and changes the word DATAPOINTS(1) to DATAPOINTS
word = re.sub(_alt_re, '', word)
word = re.sub(_alt_re, "", word)
phones = phones.split(" ")
cmudict.append((word, phones))
@@ -121,44 +121,46 @@ class CMUDict(Dataset):
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
"""
def __init__(self,
root: Union[str, Path],
exclude_punctuations: bool = True,
*,
download: bool = False,
url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
) -> None:
def __init__(
self,
root: Union[str, Path],
exclude_punctuations: bool = True,
*,
download: bool = False,
url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
) -> None:
self.exclude_punctuations = exclude_punctuations
self._root_path = Path(root)
if not os.path.isdir(self._root_path):
raise RuntimeError(f'The root directory does not exist; {root}')
raise RuntimeError(f"The root directory does not exist; {root}")
dict_file = self._root_path / os.path.basename(url)
symbol_file = self._root_path / os.path.basename(url_symbols)
if not os.path.exists(dict_file):
if not download:
raise RuntimeError(
'The dictionary file is not found in the following location. '
f'Set `download=True` to download it. {dict_file}')
"The dictionary file is not found in the following location. "
f"Set `download=True` to download it. {dict_file}"
)
checksum = _CHECKSUMS.get(url, None)
download_url_to_file(url, dict_file, checksum)
if not os.path.exists(symbol_file):
if not download:
raise RuntimeError(
'The symbol file is not found in the following location. '
f'Set `download=True` to download it. {symbol_file}')
"The symbol file is not found in the following location. "
f"Set `download=True` to download it. {symbol_file}"
)
checksum = _CHECKSUMS.get(url_symbols, None)
download_url_to_file(url_symbols, symbol_file, checksum)
with open(symbol_file, "r") as text:
self._symbols = [line.strip() for line in text.readlines()]
with open(dict_file, "r", encoding='latin-1') as text:
self._dictionary = _parse_dictionary(
text.readlines(), exclude_punctuations=self.exclude_punctuations)
with open(dict_file, "r", encoding="latin-1") as text:
self._dictionary = _parse_dictionary(text.readlines(), exclude_punctuations=self.exclude_punctuations)
def __getitem__(self, n: int) -> Tuple[str, List[str]]:
"""Load the n-th sample from the dataset.
@@ -177,6 +179,5 @@ class CMUDict(Dataset):
@property
def symbols(self) -> List[str]:
"""list[str]: A list of phonemes symbols, such as `AA`, `AE`, `AH`.
"""
"""list[str]: A list of phonemes symbols, such as `AA`, `AE`, `AH`."""
return self._symbols.copy()
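A usage sketch, assuming downloading is permitted:

cmudict = CMUDict("./data", download=True)
word, phones = cmudict[0]   # e.g. ("A", ["AH0"])
symbols = cmudict.symbols   # ["AA", "AE", "AH", ...]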
@@ -3,17 +3,14 @@ import os
from pathlib import Path
from typing import List, Dict, Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
import torchaudio
def load_commonvoice_item(line: List[str],
header: List[str],
path: str,
folder_audio: str,
ext_audio: str) -> Tuple[Tensor, int, Dict[str, str]]:
def load_commonvoice_item(
line: List[str], header: List[str], path: str, folder_audio: str, ext_audio: str
) -> Tuple[Tensor, int, Dict[str, str]]:
# Each line has the following data:
# client_id, path, sentence, up_votes, down_votes, age, gender, accent
@@ -45,9 +42,7 @@ class COMMONVOICE(Dataset):
_ext_audio = ".mp3"
_folder_audio = "clips"
def __init__(self,
root: Union[str, Path],
tsv: str = "train.tsv") -> None:
def __init__(self, root: Union[str, Path], tsv: str = "train.tsv") -> None:
# Get string representation of 'root' in case Path object is passed
self._path = os.fspath(root)
......
from pathlib import Path
from typing import Dict, Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torch.hub import download_url_to_file
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
extract_archive,
)
......
@@ -4,8 +4,8 @@ from typing import Tuple, Optional, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
extract_archive,
)
@@ -1039,8 +1039,7 @@ class GTZAN(Dataset):
self.subset = subset
assert subset is None or subset in ["training", "validation", "testing"], (
"When `subset` not None, it must take a value from "
+ "{'training', 'validation', 'testing'}."
"When `subset` not None, it must take a value from " + "{'training', 'validation', 'testing'}."
)
archive = os.path.basename(url)
@@ -1055,9 +1054,7 @@ class GTZAN(Dataset):
extract_archive(archive)
if not os.path.isdir(self._path):
raise RuntimeError(
"Dataset not found. Please use `download=True` to download it."
)
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
if self.subset is None:
# Check every subdirectory under dataset root
......
@@ -2,9 +2,8 @@ from pathlib import Path
from typing import Union, Tuple, List
import torch
from torch.utils.data import Dataset
import torchaudio
from torch.utils.data import Dataset
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
@@ -30,6 +29,7 @@ class LibriMix(Dataset):
Note:
The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
"""
def __init__(
self,
root: Union[str, Path],
@@ -44,9 +44,7 @@ class LibriMix(Dataset):
elif sample_rate == 16000:
self.root = self.root / "wav16k/min" / subset
else:
raise ValueError(
f"Unsupported sample rate. Found {sample_rate}."
)
raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")
self.sample_rate = sample_rate
self.task = task
self.mix_dir = (self.root / f"mix_{task.split('_')[1]}").resolve()
@@ -70,9 +68,7 @@ class LibriMix(Dataset):
for i, dir_ in enumerate(self.src_dirs):
src = self._load_audio(str(dir_ / filename))
if mixed.shape != src.shape:
raise ValueError(
f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
)
raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
srcs.append(src)
return self.sample_rate, mixed, srcs
......
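A hedged usage sketch, assuming the data has been generated under ./LibriMix as the class note describes (argument names taken from the visible code; the full signature is elided above):

dataset = LibriMix("./LibriMix", sample_rate=8000, task="sep_clean")
sample_rate, mixture, sources = dataset[0]  # SampleType: (int, Tensor, List[Tensor])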
import os
from typing import Tuple, Union
from pathlib import Path
from typing import Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
extract_archive,
)
@@ -14,27 +13,19 @@ from torchaudio.datasets.utils import (
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
_CHECKSUMS = {
"http://www.openslr.org/resources/12/dev-clean.tar.gz":
"76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3",
"http://www.openslr.org/resources/12/dev-other.tar.gz":
"12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365",
"http://www.openslr.org/resources/12/test-clean.tar.gz":
"39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23",
"http://www.openslr.org/resources/12/test-other.tar.gz":
"d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29",
"http://www.openslr.org/resources/12/train-clean-100.tar.gz":
"d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2",
"http://www.openslr.org/resources/12/train-clean-360.tar.gz":
"146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf",
"http://www.openslr.org/resources/12/train-other-500.tar.gz":
"ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2"
"http://www.openslr.org/resources/12/dev-clean.tar.gz": "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3",
"http://www.openslr.org/resources/12/dev-other.tar.gz": "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365",
"http://www.openslr.org/resources/12/test-clean.tar.gz": "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23",
"http://www.openslr.org/resources/12/test-other.tar.gz": "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29",
"http://www.openslr.org/resources/12/train-clean-100.tar.gz": "d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2",
"http://www.openslr.org/resources/12/train-clean-360.tar.gz": "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf",
"http://www.openslr.org/resources/12/train-other-500.tar.gz": "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2",
}
def load_librispeech_item(fileid: str,
path: str,
ext_audio: str,
ext_txt: str) -> Tuple[Tensor, int, str, int, int, int]:
def load_librispeech_item(
fileid: str, path: str, ext_audio: str, ext_txt: str
) -> Tuple[Tensor, int, str, int, int, int]:
speaker_id, chapter_id, utterance_id = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + ext_txt
@@ -86,11 +77,9 @@ class LIBRISPEECH(Dataset):
_ext_txt = ".trans.txt"
_ext_audio = ".flac"
def __init__(self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None:
def __init__(
self, root: Union[str, Path], url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False
) -> None:
if url in [
"dev-clean",
@@ -125,7 +114,7 @@ class LIBRISPEECH(Dataset):
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive)
self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*/*/*' + self._ext_audio))
self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
"""Load the n-th sample from the dataset.
......
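A usage sketch for the smallest split (illustrative root path):

dataset = LIBRISPEECH("./data", url="dev-clean", download=True)
waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id = dataset[0]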
import os
from typing import Tuple, Union
from pathlib import Path
from typing import Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
extract_archive,
)
@@ -13,20 +13,13 @@ from torchaudio.datasets.utils import (
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriTTS"
_CHECKSUMS = {
"http://www.openslr.org/resources/60/dev-clean.tar.gz":
"da0864e1bd26debed35da8a869dd5c04dfc27682921936de7cff9c8a254dbe1a",
"http://www.openslr.org/resources/60/dev-other.tar.gz":
"d413eda26f3a152ac7c9cf3658ef85504dfb1b625296e5fa83727f5186cca79c",
"http://www.openslr.org/resources/60/test-clean.tar.gz":
"234ea5b25859102a87024a4b9b86641f5b5aaaf1197335c95090cde04fe9a4f5",
"http://www.openslr.org/resources/60/test-other.tar.gz":
"33a5342094f3bba7ccc2e0500b9e72d558f72eb99328ac8debe1d9080402f10d",
"http://www.openslr.org/resources/60/train-clean-100.tar.gz":
"c5608bf1ef74bb621935382b8399c5cdd51cd3ee47cec51f00f885a64c6c7f6b",
"http://www.openslr.org/resources/60/train-clean-360.tar.gz":
"ce7cff44dcac46009d18379f37ef36551123a1dc4e5c8e4eb73ae57260de4886",
"http://www.openslr.org/resources/60/train-other-500.tar.gz":
"e35f7e34deeb2e2bdfe4403d88c8fdd5fbf64865cae41f027a185a6965f0a5df",
"http://www.openslr.org/resources/60/dev-clean.tar.gz": "da0864e1bd26debed35da8a869dd5c04dfc27682921936de7cff9c8a254dbe1a",
"http://www.openslr.org/resources/60/dev-other.tar.gz": "d413eda26f3a152ac7c9cf3658ef85504dfb1b625296e5fa83727f5186cca79c",
"http://www.openslr.org/resources/60/test-clean.tar.gz": "234ea5b25859102a87024a4b9b86641f5b5aaaf1197335c95090cde04fe9a4f5",
"http://www.openslr.org/resources/60/test-other.tar.gz": "33a5342094f3bba7ccc2e0500b9e72d558f72eb99328ac8debe1d9080402f10d",
"http://www.openslr.org/resources/60/train-clean-100.tar.gz": "c5608bf1ef74bb621935382b8399c5cdd51cd3ee47cec51f00f885a64c6c7f6b",
"http://www.openslr.org/resources/60/train-clean-360.tar.gz": "ce7cff44dcac46009d18379f37ef36551123a1dc4e5c8e4eb73ae57260de4886",
"http://www.openslr.org/resources/60/train-other-500.tar.gz": "e35f7e34deeb2e2bdfe4403d88c8fdd5fbf64865cae41f027a185a6965f0a5df",
}
......@@ -132,7 +125,7 @@ class LIBRITTS(Dataset):
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive)
self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*/*/*' + self._ext_audio))
self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]:
"""Load the n-th sample from the dataset.
......
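The one-line walker above only indexes ids; a hedged illustration with an invented layout, assuming files are stored as subset/speaker/chapter/utterance with a .wav extension:

from pathlib import Path

# glob("*/*/*" + ext) matches speaker/chapter/utterance files; .stem drops
# the directories and the extension, leaving sorted utterance ids that
# __getitem__ later resolves back to audio and transcript paths.
walker = sorted(p.stem for p in Path("/tmp/data/LibriTTS/dev-clean").glob("*/*/*.wav"))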
import os
import csv
from typing import Tuple, Union
import os
from pathlib import Path
from typing import Tuple, Union
import torchaudio
from torchaudio.datasets.utils import extract_archive
from torch import Tensor
from torch.utils.data import Dataset
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
_RELEASE_CONFIGS = {
......@@ -32,11 +32,13 @@ class LJSPEECH(Dataset):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
def __init__(self,
root: Union[str, Path],
url: str = _RELEASE_CONFIGS["release1"]["url"],
folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
download: bool = False) -> None:
def __init__(
self,
root: Union[str, Path],
url: str = _RELEASE_CONFIGS["release1"]["url"],
folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
download: bool = False,
) -> None:
self._parse_filesystem(root, url, folder_in_archive, download)
......@@ -50,7 +52,7 @@ class LJSPEECH(Dataset):
folder_in_archive = basename / folder_in_archive
self._path = root / folder_in_archive
self._metadata_path = root / basename / 'metadata.csv'
self._metadata_path = root / basename / "metadata.csv"
if download:
if not os.path.isdir(self._path):
......@@ -59,7 +61,7 @@ class LJSPEECH(Dataset):
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive)
with open(self._metadata_path, "r", newline='') as metadata:
with open(self._metadata_path, "r", newline="") as metadata:
flist = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
self._flist = list(flist)
......
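To make the parse above concrete, a small sketch assuming LJSpeech's pipe-delimited metadata layout (fileid|transcript|normalized transcript); the sample row is invented and truncated:

import csv
import io

row = "LJ001-0001|Printing, in the only sense.|Printing, in the only sense."
reader = csv.reader(io.StringIO(row), delimiter="|", quoting=csv.QUOTE_NONE)
# QUOTE_NONE keeps literal quotation marks inside transcripts intact.
fileid, transcript, normalized_transcript = next(reader)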
import os
from typing import Tuple, Optional, Union
from pathlib import Path
from typing import Tuple, Optional, Union
import torchaudio
from torch.utils.data import Dataset
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
extract_archive,
)
......@@ -16,10 +15,8 @@ URL = "speech_commands_v0.02"
HASH_DIVIDER = "_nohash_"
EXCEPT_FOLDER = "_background_noise_"
_CHECKSUMS = {
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz":
"743935421bb51cccdb6bdd152e04c5c70274e935c82119ad7faeec31780d811d",
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz":
"af14739ee7dc311471de98f5f9d2c9191b18aedfe957f4a6ff791c709868ff58",
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz": "743935421bb51cccdb6bdd152e04c5c70274e935c82119ad7faeec31780d811d",
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz": "af14739ee7dc311471de98f5f9d2c9191b18aedfe957f4a6ff791c709868ff58",
}
......@@ -75,17 +72,17 @@ class SPEECHCOMMANDS(Dataset):
original paper can be found `here <https://arxiv.org/pdf/1804.03209.pdf>`_. (Default: ``None``)
"""
def __init__(self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
subset: Optional[str] = None,
) -> None:
def __init__(
self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
subset: Optional[str] = None,
) -> None:
assert subset is None or subset in ["training", "validation", "testing"], (
"When `subset` not None, it must take a value from "
+ "{'training', 'validation', 'testing'}."
"When `subset` not None, it must take a value from " + "{'training', 'validation', 'testing'}."
)
if url in [
......@@ -121,15 +118,14 @@ class SPEECHCOMMANDS(Dataset):
self._walker = _load_list(self._path, "testing_list.txt")
elif subset == "training":
excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt"))
walker = sorted(str(p) for p in Path(self._path).glob('*/*.wav'))
walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav"))
self._walker = [
w for w in walker
if HASH_DIVIDER in w
and EXCEPT_FOLDER not in w
and os.path.normpath(w) not in excludes
w
for w in walker
if HASH_DIVIDER in w and EXCEPT_FOLDER not in w and os.path.normpath(w) not in excludes
]
else:
walker = sorted(str(p) for p in Path(self._path).glob('*/*.wav'))
walker = sorted(str(p) for p in Path(self._path).glob("*/*.wav"))
self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
......
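A usage sketch for the subset routing above; the root path is illustrative, and the behavior follows the filter in this diff ("training" keeps every hashed *.wav outside the validation and testing lists and outside _background_noise_):

import torchaudio

train_set = torchaudio.datasets.SPEECHCOMMANDS("/tmp/data", download=True, subset="training")
test_set = torchaudio.datasets.SPEECHCOMMANDS("/tmp/data", download=True, subset="testing")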
import os
from typing import Tuple, Union
from pathlib import Path
from typing import Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
extract_archive,
)
......@@ -58,13 +57,14 @@ class TEDLIUM(Dataset):
Whether to download the dataset if it is not found at root path. (default: ``False``).
audio_ext (str, optional): extension for audio file (default: ``".sph"``)
"""
def __init__(
self,
root: Union[str, Path],
release: str = "release1",
subset: str = None,
download: bool = False,
audio_ext: str = ".sph"
audio_ext: str = ".sph",
) -> None:
self._ext_audio = audio_ext
if release in _RELEASE_CONFIGS.keys():
......@@ -75,14 +75,16 @@ class TEDLIUM(Dataset):
# Raise warning
raise RuntimeError(
"The release {} does not match any of the supported tedlium releases{} ".format(
release, _RELEASE_CONFIGS.keys(),
release,
_RELEASE_CONFIGS.keys(),
)
)
if subset not in _RELEASE_CONFIGS[release]["supported_subsets"]:
# Raise warning
raise RuntimeError(
"The subset {} does not match any of the supported tedlium subsets{} ".format(
subset, _RELEASE_CONFIGS[release]["supported_subsets"],
subset,
_RELEASE_CONFIGS[release]["supported_subsets"],
)
)
......
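Closing out the validation logic above, a usage sketch; the root path is illustrative and "train" is assumed to be among release1's supported_subsets:

import torchaudio

# Invalid release/subset values raise the RuntimeErrors shown above.
dataset = torchaudio.datasets.TEDLIUM("/tmp/data", release="release1", subset="train", download=True)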