Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
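
The diff below is mechanical reformatting: string literals are normalized to double quotes, signatures and calls are re-wrapped to a roughly 120-character line length, and imports are re-sorted. The commit only records the `arc lint` command above, so the underlying formatter is not stated; the output is consistent with what Black produces at that line length. A minimal sketch (assuming the `black` package is installed; not part of the commit) reproducing the kind of change seen in the first hunk:

    import black

    # Original layout of stream_url's signature, as on the removed side of the first hunk.
    before = (
        "def stream_url(url: str,\n"
        "               start_byte: Optional[int] = None,\n"
        "               block_size: int = 32 * 1024,\n"
        "               progress_bar: bool = True) -> Iterable:\n"
        "    ...\n"
    )

    # Assumption: Black at line length 120 approximates the arc lint output.
    # Black re-wraps the signature to the configured line length, matching the added side of the hunk.
    after = black.format_str(before, mode=black.Mode(line_length=120))
    print(after)
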
......@@ -4,17 +4,16 @@ import os
import tarfile
import urllib
import urllib.request
import zipfile
import warnings
import zipfile
from typing import Any, Iterable, List, Optional
from torch.utils.model_zoo import tqdm
def stream_url(url: str,
start_byte: Optional[int] = None,
block_size: int = 32 * 1024,
progress_bar: bool = True) -> Iterable:
def stream_url(
url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
"""Stream url by chunk
Args:
......@@ -36,11 +35,11 @@ def stream_url(url: str,
req.headers["Range"] = "bytes={}-".format(start_byte)
with urllib.request.urlopen(req) as upointer, tqdm(
unit="B",
unit_scale=True,
unit_divisor=1024,
total=url_size,
disable=not progress_bar,
unit="B",
unit_scale=True,
unit_divisor=1024,
total=url_size,
disable=not progress_bar,
) as pbar:
num_bytes = 0
......@@ -53,13 +52,15 @@ def stream_url(url: str,
pbar.update(len(chunk))
def download_url(url: str,
download_folder: str,
filename: Optional[str] = None,
hash_value: Optional[str] = None,
hash_type: str = "sha256",
progress_bar: bool = True,
resume: bool = False) -> None:
def download_url(
url: str,
download_folder: str,
filename: Optional[str] = None,
hash_value: Optional[str] = None,
hash_type: str = "sha256",
progress_bar: bool = True,
resume: bool = False,
) -> None:
"""Download file to disk.
Args:
......@@ -84,9 +85,7 @@ def download_url(url: str,
local_size: Optional[int] = os.path.getsize(filepath)
elif not resume and os.path.exists(filepath):
raise RuntimeError(
"{} already exists. Delete the file manually and retry.".format(filepath)
)
raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
else:
mode = "wb"
local_size = None
......@@ -95,11 +94,7 @@ def download_url(url: str,
with open(filepath, "rb") as file_obj:
if validate_file(file_obj, hash_value, hash_type):
return
raise RuntimeError(
"The hash of {} does not match. Delete the file manually and retry.".format(
filepath
)
)
raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
with open(filepath, mode) as fpointer:
for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
......@@ -107,11 +102,7 @@ def download_url(url: str,
with open(filepath, "rb") as file_obj:
if hash_value and not validate_file(file_obj, hash_value, hash_type):
raise RuntimeError(
"The hash of {} does not match. Delete the file manually and retry.".format(
filepath
)
)
raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
......
import os
from typing import Tuple
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torch.hub import download_url_to_file
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
extract_archive,
)
URL = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
_CHECKSUMS = {
"https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip":
"f96258be9fdc2cbff6559541aae7ea4f59df3fcaf5cf963aae5ca647357e359c"
"https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip": "f96258be9fdc2cbff6559541aae7ea4f59df3fcaf5cf963aae5ca647357e359c"
}
......@@ -41,17 +38,15 @@ class VCTK_092(Dataset):
"""
def __init__(
self,
root: str,
mic_id: str = "mic2",
download: bool = False,
url: str = URL,
audio_ext=".flac",
self,
root: str,
mic_id: str = "mic2",
download: bool = False,
url: str = URL,
audio_ext=".flac",
):
if mic_id not in ["mic1", "mic2"]:
raise RuntimeError(
f'`mic_id` has to be either "mic1" or "mic2". Found: {mic_id}'
)
raise RuntimeError(f'`mic_id` has to be either "mic1" or "mic2". Found: {mic_id}')
archive = os.path.join(root, "VCTK-Corpus-0.92.zip")
......@@ -69,9 +64,7 @@ class VCTK_092(Dataset):
extract_archive(archive, self._path)
if not os.path.isdir(self._path):
raise RuntimeError(
"Dataset not found. Please use `download=True` to download it."
)
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
# Extracting speaker IDs from the folder structure
self._speaker_ids = sorted(os.listdir(self._txt_dir))
......@@ -91,9 +84,7 @@ class VCTK_092(Dataset):
if speaker_id == "p280" and mic_id == "mic2":
continue
utterance_dir = os.path.join(self._txt_dir, speaker_id)
for utterance_file in sorted(
f for f in os.listdir(utterance_dir) if f.endswith(".txt")
):
for utterance_file in sorted(f for f in os.listdir(utterance_dir) if f.endswith(".txt")):
utterance_id = os.path.splitext(utterance_file)[0]
audio_path_mic = os.path.join(
self._audio_dir,
......@@ -112,9 +103,7 @@ class VCTK_092(Dataset):
return torchaudio.load(file_path)
def _load_sample(self, speaker_id: str, utterance_id: str, mic_id: str) -> SampleType:
transcript_path = os.path.join(
self._txt_dir, speaker_id, f"{speaker_id}_{utterance_id}.txt"
)
transcript_path = os.path.join(self._txt_dir, speaker_id, f"{speaker_id}_{utterance_id}.txt")
audio_path = os.path.join(
self._audio_dir,
speaker_id,
......
......@@ -2,11 +2,10 @@ import os
from pathlib import Path
from typing import List, Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torch.hub import download_url_to_file
import torchaudio
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
extract_archive,
)
......@@ -39,7 +38,7 @@ class YESNO(Dataset):
root: Union[str, Path],
url: str = _RELEASE_CONFIGS["release1"]["url"],
folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
download: bool = False
download: bool = False,
) -> None:
self._parse_filesystem(root, url, folder_in_archive, download)
......@@ -58,9 +57,7 @@ class YESNO(Dataset):
extract_archive(archive)
if not os.path.isdir(self._path):
raise RuntimeError(
"Dataset not found. Please use `download=True` to download it."
)
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*.wav"))
......
from .filtering import (
allpass_biquad,
band_biquad,
bandpass_biquad,
bandreject_biquad,
bass_biquad,
biquad,
contrast,
dither,
dcshift,
deemph_biquad,
equalizer_biquad,
filtfilt,
flanger,
gain,
highpass_biquad,
lfilter,
lowpass_biquad,
overdrive,
phaser,
riaa_biquad,
treble_biquad,
vad,
)
from .functional import (
amplitude_to_DB,
compute_deltas,
......@@ -23,75 +47,51 @@ from .functional import (
pitch_shift,
rnnt_loss,
)
from .filtering import (
allpass_biquad,
band_biquad,
bandpass_biquad,
bandreject_biquad,
bass_biquad,
biquad,
contrast,
dither,
dcshift,
deemph_biquad,
equalizer_biquad,
filtfilt,
flanger,
gain,
highpass_biquad,
lfilter,
lowpass_biquad,
overdrive,
phaser,
riaa_biquad,
treble_biquad,
vad,
)
__all__ = [
'amplitude_to_DB',
'compute_deltas',
'compute_kaldi_pitch',
'create_dct',
'melscale_fbanks',
'linear_fbanks',
'DB_to_amplitude',
'detect_pitch_frequency',
'griffinlim',
'mask_along_axis',
'mask_along_axis_iid',
'mu_law_encoding',
'mu_law_decoding',
'phase_vocoder',
'sliding_window_cmn',
'spectrogram',
'inverse_spectrogram',
'spectral_centroid',
'allpass_biquad',
'band_biquad',
'bandpass_biquad',
'bandreject_biquad',
'bass_biquad',
'biquad',
'contrast',
'dither',
'dcshift',
'deemph_biquad',
'equalizer_biquad',
'filtfilt',
'flanger',
'gain',
'highpass_biquad',
'lfilter',
'lowpass_biquad',
'overdrive',
'phaser',
'riaa_biquad',
'treble_biquad',
'vad',
'apply_codec',
'resample',
'edit_distance',
'pitch_shift',
'rnnt_loss',
"amplitude_to_DB",
"compute_deltas",
"compute_kaldi_pitch",
"create_dct",
"melscale_fbanks",
"linear_fbanks",
"DB_to_amplitude",
"detect_pitch_frequency",
"griffinlim",
"mask_along_axis",
"mask_along_axis_iid",
"mu_law_encoding",
"mu_law_decoding",
"phase_vocoder",
"sliding_window_cmn",
"spectrogram",
"inverse_spectrogram",
"spectral_centroid",
"allpass_biquad",
"band_biquad",
"bandpass_biquad",
"bandreject_biquad",
"bass_biquad",
"biquad",
"contrast",
"dither",
"dcshift",
"deemph_biquad",
"equalizer_biquad",
"filtfilt",
"flanger",
"gain",
"highpass_biquad",
"lfilter",
"lowpass_biquad",
"overdrive",
"phaser",
"riaa_biquad",
"treble_biquad",
"vad",
"apply_codec",
"resample",
"edit_distance",
"pitch_shift",
"rnnt_loss",
]
......@@ -45,7 +45,7 @@ def _generate_wave_table(
d = (torch.sin(point.to(torch.float64) / table_size * 2 * math.pi) + 1) / 2
elif wave_type == "TRIANGLE":
d = point.to(torch.float64) * 2 / table_size
value = torch.div(4 * point, table_size, rounding_mode='floor')
value = torch.div(4 * point, table_size, rounding_mode="floor")
d[value == 0] = d[value == 0] + 0.5
d[value == 1] = 1.5 - d[value == 1]
d[value == 2] = 1.5 - d[value == 2]
......@@ -64,9 +64,7 @@ def _generate_wave_table(
return d
def allpass_biquad(
waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707
) -> Tensor:
def allpass_biquad(waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707) -> Tensor:
r"""Design two-pole all-pass filter. Similar to SoX implementation.
Args:
......@@ -191,9 +189,7 @@ def bandpass_biquad(
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def bandreject_biquad(
waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707
) -> Tensor:
def bandreject_biquad(waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707) -> Tensor:
r"""Design two-pole band-reject filter. Similar to SoX implementation.
Args:
......@@ -273,9 +269,7 @@ def bass_biquad(
return biquad(waveform, b0 / a0, b1 / a0, b2 / a0, a0 / a0, a1 / a0, a2 / a0)
def biquad(
waveform: Tensor, b0: float, b1: float, b2: float, a0: float, a1: float, a2: float
) -> Tensor:
def biquad(waveform: Tensor, b0: float, b1: float, b2: float, a0: float, a1: float, a2: float) -> Tensor:
r"""Perform a biquad filter of input tensor. Initial conditions set to 0.
https://en.wikipedia.org/wiki/Digital_biquad_filter
......@@ -339,9 +333,7 @@ def contrast(waveform: Tensor, enhancement_amount: float = 75.0) -> Tensor:
return output_waveform
def dcshift(
waveform: Tensor, shift: float, limiter_gain: Optional[float] = None
) -> Tensor:
def dcshift(waveform: Tensor, shift: float, limiter_gain: Optional[float] = None) -> Tensor:
r"""Apply a DC shift to the audio. Similar to SoX implementation.
This can be useful to remove a DC offset
(caused perhaps by a hardware problem in the recording chain) from the audio
......@@ -367,25 +359,13 @@ def dcshift(
if limiter_gain is not None and shift > 0:
mask = waveform > limiter_threshold
temp = (
(waveform[mask] - limiter_threshold)
* limiter_gain
/ (1 - limiter_threshold)
)
output_waveform[mask] = (temp + limiter_threshold + shift).clamp(
max=limiter_threshold
)
temp = (waveform[mask] - limiter_threshold) * limiter_gain / (1 - limiter_threshold)
output_waveform[mask] = (temp + limiter_threshold + shift).clamp(max=limiter_threshold)
output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1)
elif limiter_gain is not None and shift < 0:
mask = waveform < -limiter_threshold
temp = (
(waveform[mask] + limiter_threshold)
* limiter_gain
/ (1 - limiter_threshold)
)
output_waveform[mask] = (temp - limiter_threshold + shift).clamp(
min=-limiter_threshold
)
temp = (waveform[mask] + limiter_threshold) * limiter_gain / (1 - limiter_threshold)
output_waveform[mask] = (temp - limiter_threshold + shift).clamp(min=-limiter_threshold)
output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1)
else:
output_waveform = (waveform + shift).clamp(min=-1, max=1)
......@@ -461,9 +441,7 @@ def _add_noise_shaping(dithered_waveform: Tensor, waveform: Tensor) -> Tensor:
return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])
def _apply_probability_distribution(
waveform: Tensor, density_function: str = "TPDF"
) -> Tensor:
def _apply_probability_distribution(waveform: Tensor, density_function: str = "TPDF") -> Tensor:
r"""Apply a probability distribution function on a waveform.
Triangular probability density function (TPDF) dither noise has a
......@@ -561,9 +539,7 @@ def _apply_probability_distribution(
signal_scaled_dis = signal_scaled + gaussian
else:
# dtype needed for https://github.com/pytorch/pytorch/issues/32358
TPDF = torch.bartlett_window(
time_size + 1, dtype=signal_scaled.dtype, device=signal_scaled.device
)
TPDF = torch.bartlett_window(time_size + 1, dtype=signal_scaled.dtype, device=signal_scaled.device)
TPDF = TPDF.repeat((channel_size + 1), 1)
signal_scaled_dis = signal_scaled + TPDF
......@@ -574,9 +550,7 @@ def _apply_probability_distribution(
return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
def dither(
waveform: Tensor, density_function: str = "TPDF", noise_shaping: bool = False
) -> Tensor:
def dither(waveform: Tensor, density_function: str = "TPDF", noise_shaping: bool = False) -> Tensor:
r"""Dither increases the perceived dynamic range of audio stored at a
particular bit-depth by eliminating nonlinear truncation distortion
(i.e. adding minimally perceived noise to mask distortion caused by quantization).
......@@ -594,9 +568,7 @@ def dither(
Returns:
Tensor: waveform dithered
"""
dithered = _apply_probability_distribution(
waveform, density_function=density_function
)
dithered = _apply_probability_distribution(waveform, density_function=density_function)
if noise_shaping:
return _add_noise_shaping(dithered, waveform)
......@@ -643,7 +615,10 @@ def equalizer_biquad(
def filtfilt(
waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True,
waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor,
clamp: bool = True,
) -> Tensor:
r"""Apply an IIR filter forward and backward to a waveform.
......@@ -667,7 +642,11 @@ def filtfilt(
"""
forward_filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False, batching=True)
backward_filtered = lfilter(
forward_filtered.flip(-1), a_coeffs, b_coeffs, clamp=clamp, batching=True,
forward_filtered.flip(-1),
a_coeffs,
b_coeffs,
clamp=clamp,
batching=True,
).flip(-1)
return backward_filtered
......@@ -757,9 +736,7 @@ def flanger(
delay_buf_length = int((delay_min + delay_depth) * sample_rate + 0.5)
delay_buf_length = delay_buf_length + 2
delay_bufs = torch.zeros(
waveform.shape[0], n_channels, delay_buf_length, dtype=dtype, device=device
)
delay_bufs = torch.zeros(waveform.shape[0], n_channels, delay_buf_length, dtype=dtype, device=device)
delay_last = torch.zeros(waveform.shape[0], n_channels, dtype=dtype, device=device)
lfo_length = int(sample_rate / speed)
......@@ -787,9 +764,7 @@ def flanger(
delay_buf_pos = (delay_buf_pos + delay_buf_length - 1) % delay_buf_length
cur_channel_phase = (channel_idxs * lfo_length * channel_phase + 0.5).to(
torch.int64
)
cur_channel_phase = (channel_idxs * lfo_length * channel_phase + 0.5).to(torch.int64)
delay_tensor = lfo[(lfo_pos + cur_channel_phase) % lfo_length]
frac_delay = torch.frac(delay_tensor)
delay_tensor = torch.floor(delay_tensor)
......@@ -800,24 +775,18 @@ def flanger(
delay_bufs[:, :, delay_buf_pos] = temp + delay_last * feedback_gain
delayed_0 = delay_bufs[
:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length
]
delayed_0 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]
int_delay = int_delay + 1
delayed_1 = delay_bufs[
:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length
]
delayed_1 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]
int_delay = int_delay + 1
if interpolation == "linear":
delayed = delayed_0 + (delayed_1 - delayed_0) * frac_delay
else:
delayed_2 = delay_bufs[
:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length
]
delayed_2 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]
int_delay = int_delay + 1
......@@ -854,9 +823,7 @@ def gain(waveform: Tensor, gain_db: float = 1.0) -> Tensor:
return waveform * ratio
def highpass_biquad(
waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707
) -> Tensor:
def highpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.
Args:
......@@ -889,9 +856,7 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
n_order = a_coeffs_flipped.size(1)
a_coeffs_flipped = a_coeffs_flipped.unsqueeze(2)
for i_sample, o0 in enumerate(input_signal_windows.permute(2, 0, 1)):
windowed_output_signal = padded_output_waveform[
:, :, i_sample:i_sample + n_order
]
windowed_output_signal = padded_output_waveform[:, :, i_sample : i_sample + n_order]
o0 -= (windowed_output_signal.transpose(0, 1) @ a_coeffs_flipped)[..., 0].t()
padded_output_waveform[:, :, i_sample + n_order - 1] = o0
......@@ -899,7 +864,7 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
try:
_lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
except RuntimeError as err:
assert str(err) == 'No such operator torchaudio::_lfilter_core_loop'
assert str(err) == "No such operator torchaudio::_lfilter_core_loop"
_lfilter_core_cpu_loop = _lfilter_core_generic_loop
......@@ -929,40 +894,32 @@ def _lfilter_core(
b_coeffs_flipped = b_coeffs.flip(1)
# calculate windowed_input_signal in parallel using convolution
input_signal_windows = torch.nn.functional.conv1d(
padded_waveform,
b_coeffs_flipped.unsqueeze(1),
groups=n_channel
)
input_signal_windows = torch.nn.functional.conv1d(padded_waveform, b_coeffs_flipped.unsqueeze(1), groups=n_channel)
input_signal_windows.div_(a_coeffs[:, :1])
a_coeffs_flipped.div_(a_coeffs[:, :1])
if input_signal_windows.device == torch.device('cpu') and\
a_coeffs_flipped.device == torch.device('cpu') and\
padded_output_waveform.device == torch.device('cpu'):
if (
input_signal_windows.device == torch.device("cpu")
and a_coeffs_flipped.device == torch.device("cpu")
and padded_output_waveform.device == torch.device("cpu")
):
_lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
else:
_lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
output = padded_output_waveform[:, :, n_order - 1:]
output = padded_output_waveform[:, :, n_order - 1 :]
return output
try:
_lfilter = torch.ops.torchaudio._lfilter
except RuntimeError as err:
assert str(err) == 'No such operator torchaudio::_lfilter'
assert str(err) == "No such operator torchaudio::_lfilter"
_lfilter = _lfilter_core
def lfilter(
waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor,
clamp: bool = True,
batching: bool = True
) -> Tensor:
def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation.
Note:
......@@ -1016,9 +973,7 @@ def lfilter(
return output
def lowpass_biquad(
waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707
) -> Tensor:
def lowpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
Args:
......@@ -1048,11 +1003,7 @@ def lowpass_biquad(
def _overdrive_core_loop_generic(
waveform: Tensor,
temp: Tensor,
last_in: Tensor,
last_out: Tensor,
output_waveform: Tensor
waveform: Tensor, temp: Tensor, last_in: Tensor, last_out: Tensor, output_waveform: Tensor
):
for i in range(waveform.shape[-1]):
last_out = temp[:, i] - last_in + 0.995 * last_out
......@@ -1063,7 +1014,7 @@ def _overdrive_core_loop_generic(
try:
_overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop
except RuntimeError as err:
assert str(err) == 'No such operator torchaudio::_overdrive_core_loop'
assert str(err) == "No such operator torchaudio::_overdrive_core_loop"
_overdrive_core_loop_cpu = _overdrive_core_loop_generic
......@@ -1110,7 +1061,7 @@ def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor:
output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device)
# Uses CPU optimized loop function if available for CPU device
if device == torch.device('cpu'):
if device == torch.device("cpu"):
_overdrive_core_loop_cpu(waveform, temp, last_in, last_out, output_waveform)
else:
_overdrive_core_loop_generic(waveform, temp, last_in, last_out, output_waveform)
......@@ -1164,9 +1115,7 @@ def phaser(
waveform = waveform.view(-1, actual_shape[-1])
delay_buf_len = int((delay_ms * 0.001 * sample_rate) + 0.5)
delay_buf = torch.zeros(
waveform.shape[0], delay_buf_len, dtype=dtype, device=device
)
delay_buf = torch.zeros(waveform.shape[0], delay_buf_len, dtype=dtype, device=device)
mod_buf_len = int(sample_rate / mod_speed + 0.5)
......@@ -1203,9 +1152,7 @@ def phaser(
delay_buf_list[delay_pos] = temp * decay
output_waveform_pre_gain_list.append(temp)
output_waveform = torch.stack(output_waveform_pre_gain_list, dim=1).to(
dtype=dtype, device=device
)
output_waveform = torch.stack(output_waveform_pre_gain_list, dim=1).to(dtype=dtype, device=device)
output_waveform.mul_(gain_out)
return output_waveform.clamp(min=-1, max=1).view(actual_shape)
......@@ -1344,9 +1291,7 @@ def _measure(
dftBuf = torch.zeros(dft_len_ws)
_index_ns = torch.tensor(
[index_ns] + [(index_ns + i) % samplesLen_ns for i in range(1, measure_len_ws)]
)
_index_ns = torch.tensor([index_ns] + [(index_ns + i) % samplesLen_ns for i in range(1, measure_len_ws)])
dftBuf[:measure_len_ws] = samples[_index_ns] * spectrum_window[:measure_len_ws]
# memset(c->dftBuf + i, 0, (p->dft_len_ws - i) * sizeof(*c->dftBuf));
......@@ -1358,9 +1303,7 @@ def _measure(
# memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
_dftBuf[:spectrum_start].zero_()
mult: float = (
boot_count / (1.0 + boot_count) if boot_count >= 0 else measure_smooth_time_mult
)
mult: float = boot_count / (1.0 + boot_count) if boot_count >= 0 else measure_smooth_time_mult
_d = _dftBuf[spectrum_start:spectrum_end].abs()
spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
......@@ -1387,17 +1330,13 @@ def _measure(
_cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1)
_cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window
_cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_()
_cepstrum_Buf[spectrum_end : dft_len_ws >> 1].zero_()
# lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
_cepstrum_Buf = torch.fft.rfft(_cepstrum_Buf)
result: float = float(
torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2))
)
result = (
math.log(result / (cepstrum_end - cepstrum_start)) if result > 0 else -math.inf
)
result: float = float(torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2)))
result = math.log(result / (cepstrum_end - cepstrum_start)) if result > 0 else -math.inf
return max(0, 21 + result)
......@@ -1489,9 +1428,7 @@ def vad(
" and https://github.com/pytorch/audio/issues/1468."
)
measure_duration: float = (
2.0 / measure_freq if measure_duration is None else measure_duration
)
measure_duration: float = 2.0 / measure_freq if measure_duration is None else measure_duration
measure_len_ws = int(sample_rate * measure_duration + 0.5)
measure_len_ns = measure_len_ws
......@@ -1506,9 +1443,7 @@ def vad(
gap_len = int(allowed_gap * measure_freq + 0.5)
fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + 0.5)
samplesLen_ns = (
fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns
)
samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns
spectrum_window = torch.zeros(measure_len_ws)
for i in range(measure_len_ws):
......@@ -1526,9 +1461,7 @@ def vad(
for i in range(spectrum_end - spectrum_start):
cepstrum_window[i] = 2.0 / math.sqrt(float(spectrum_end) - spectrum_start)
# lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start));
cepstrum_window *= torch.hann_window(
spectrum_end - spectrum_start, dtype=torch.float
)
cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, dtype=torch.float)
cepstrum_start = math.ceil(sample_rate * 0.5 / lp_lifter_freq)
cepstrum_end = math.floor(sample_rate * 0.5 / hp_lifter_freq)
......@@ -1567,9 +1500,7 @@ def vad(
samples[i, samplesIndex_ns] = waveform[i, pos]
# if (!p->measure_timer_ns) {
if measure_timer_ns == 0:
index_ns: int = (
samplesIndex_ns + samplesLen_ns - measure_len_ns
) % samplesLen_ns
index_ns: int = (samplesIndex_ns + samplesLen_ns - measure_len_ns) % samplesLen_ns
meas: float = _measure(
measure_len_ws=measure_len_ws,
samples=samples[i],
......@@ -1589,9 +1520,7 @@ def vad(
boot_count=boot_count,
)
measures[i, measures_index] = meas
mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (
1.0 - trigger_meas_time_mult
)
mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1.0 - trigger_meas_time_mult)
has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
if has_triggered:
......@@ -1602,9 +1531,7 @@ def vad(
j: int = 0
for j in range(n):
if (measures[i, k] >= trigger_level) and (
j <= jTrigger + gap_len
):
if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
jZero = jTrigger = j
elif (measures[i, k] == 0) and (jTrigger >= jZero):
jZero = j
......@@ -1631,6 +1558,6 @@ def vad(
flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
samplesIndex_ns = (samplesIndex_ns + flushedLen_ns) % samplesLen_ns
res = waveform[:, pos - samplesLen_ns + flushedLen_ns:]
res = waveform[:, pos - samplesLen_ns + flushedLen_ns :]
# unpack batch
return res.view(shape[:-1] + res.shape[-1:])
# -*- coding: utf-8 -*-
from collections.abc import Sequence
import io
import math
import warnings
from collections.abc import Sequence
from typing import Optional, Tuple
import torch
import torchaudio
from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils
import torchaudio
__all__ = [
"spectrogram",
......@@ -28,9 +28,9 @@ __all__ = [
"mu_law_encoding",
"mu_law_decoding",
"phase_vocoder",
'mask_along_axis',
'mask_along_axis_iid',
'sliding_window_cmn',
"mask_along_axis",
"mask_along_axis_iid",
"sliding_window_cmn",
"spectral_centroid",
"apply_codec",
"resample",
......@@ -41,18 +41,18 @@ __all__ = [
def spectrogram(
waveform: Tensor,
pad: int,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
power: Optional[float],
normalized: bool,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: Optional[bool] = None,
waveform: Tensor,
pad: int,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
power: Optional[float],
normalized: bool,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: Optional[bool] = None,
) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.
......@@ -116,7 +116,7 @@ def spectrogram(
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])
if normalized:
spec_f /= window.pow(2.).sum().sqrt()
spec_f /= window.pow(2.0).sum().sqrt()
if power is not None:
if power == 1.0:
return spec_f.abs()
......@@ -125,17 +125,17 @@ def spectrogram(
def inverse_spectrogram(
spectrogram: Tensor,
length: Optional[int],
pad: int,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
normalized: bool,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
spectrogram: Tensor,
length: Optional[int],
pad: int,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
normalized: bool,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
) -> Tensor:
r"""Create an inverse spectrogram or a batch of inverse spectrograms from the provided
complex-valued spectrogram.
......@@ -166,7 +166,7 @@ def inverse_spectrogram(
raise ValueError("Expected `spectrogram` to be complex dtype.")
if normalized:
spectrogram = spectrogram * window.pow(2.).sum().sqrt()
spectrogram = spectrogram * window.pow(2.0).sum().sqrt()
# pack batch
shape = spectrogram.size()
......@@ -203,20 +203,20 @@ def _get_complex_dtype(real_dtype: torch.dtype):
return torch.cfloat
if real_dtype == torch.half:
return torch.complex32
raise ValueError(f'Unexpected dtype {real_dtype}')
raise ValueError(f"Unexpected dtype {real_dtype}")
def griffinlim(
specgram: Tensor,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
power: float,
n_iter: int,
momentum: float,
length: Optional[int],
rand_init: bool
specgram: Tensor,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
power: float,
n_iter: int,
momentum: float,
length: Optional[int],
rand_init: bool,
) -> Tensor:
r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
......@@ -244,8 +244,8 @@ def griffinlim(
Returns:
Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
"""
assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
assert momentum >= 0, 'momentum={} < 0'.format(momentum)
assert momentum < 1, "momentum={} > 1 can be unstable".format(momentum)
assert momentum >= 0, "momentum={} < 0".format(momentum)
# pack batch
shape = specgram.size()
......@@ -255,24 +255,17 @@ def griffinlim(
# initialize the phase
if rand_init:
angles = torch.rand(
specgram.size(),
dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
angles = torch.rand(specgram.size(), dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
else:
angles = torch.full(
specgram.size(), 1,
dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
angles = torch.full(specgram.size(), 1, dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
# And initialize the previous iterate to 0
tprev = torch.tensor(0., dtype=specgram.dtype, device=specgram.device)
tprev = torch.tensor(0.0, dtype=specgram.dtype, device=specgram.device)
for _ in range(n_iter):
# Invert with our current estimate of the phases
inverse = torch.istft(specgram * angles,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
length=length)
inverse = torch.istft(
specgram * angles, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=length
)
# Rebuild the spectrogram
rebuilt = torch.stft(
......@@ -282,7 +275,7 @@ def griffinlim(
win_length=win_length,
window=window,
center=True,
pad_mode='reflect',
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
......@@ -298,12 +291,9 @@ def griffinlim(
tprev = rebuilt
# Return the final phase estimates
waveform = torch.istft(specgram * angles,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
length=length)
waveform = torch.istft(
specgram * angles, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=length
)
# unpack batch
waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
......@@ -312,11 +302,7 @@ def griffinlim(
def amplitude_to_DB(
x: Tensor,
multiplier: float,
amin: float,
db_multiplier: float,
top_db: Optional[float] = None
x: Tensor, multiplier: float, amin: float, db_multiplier: float, top_db: Optional[float] = None
) -> Tensor:
r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.
......@@ -354,11 +340,7 @@ def amplitude_to_DB(
return x_db
def DB_to_amplitude(
x: Tensor,
ref: float,
power: float
) -> Tensor:
def DB_to_amplitude(x: Tensor, ref: float, power: float) -> Tensor:
r"""Turn a tensor from the decibel scale to the power/amplitude scale.
Args:
......@@ -383,7 +365,7 @@ def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
mels (float): Frequency in Mels
"""
if mel_scale not in ['slaney', 'htk']:
if mel_scale not in ["slaney", "htk"]:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale == "htk":
......@@ -417,11 +399,11 @@ def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
freqs (Tensor): Mels converted in Hz
"""
if mel_scale not in ['slaney', 'htk']:
if mel_scale not in ["slaney", "htk"]:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale == "htk":
return 700.0 * (10.0**(mels / 2595.0) - 1.0)
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
# Fill in the linear scale
f_min = 0.0
......@@ -433,15 +415,15 @@ def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = math.log(6.4) / 27.0
log_t = (mels >= min_log_mel)
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
return freqs
def _create_triangular_filterbank(
all_freqs: Tensor,
f_pts: Tensor,
all_freqs: Tensor,
f_pts: Tensor,
) -> Tensor:
"""Create a triangular filter bank.
......@@ -466,13 +448,13 @@ def _create_triangular_filterbank(
def melscale_fbanks(
n_freqs: int,
f_min: float,
f_max: float,
n_mels: int,
sample_rate: int,
norm: Optional[str] = None,
mel_scale: str = "htk",
n_freqs: int,
f_min: float,
f_max: float,
n_mels: int,
sample_rate: int,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> Tensor:
r"""Create a frequency bin conversion matrix.
......@@ -520,10 +502,10 @@ def melscale_fbanks(
if norm is not None and norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
fb *= enorm.unsqueeze(0)
if (fb.max(dim=0).values == 0.).any():
if (fb.max(dim=0).values == 0.0).any():
warnings.warn(
"At least one mel filterbank has all zero values. "
f"The value for `n_mels` ({n_mels}) may be set too high. "
......@@ -534,11 +516,11 @@ def melscale_fbanks(
def linear_fbanks(
n_freqs: int,
f_min: float,
f_max: float,
n_filter: int,
sample_rate: int,
n_freqs: int,
f_min: float,
f_max: float,
n_filter: int,
sample_rate: int,
) -> Tensor:
r"""Creates a linear triangular filterbank.
......@@ -575,11 +557,7 @@ def linear_fbanks(
return fb
def create_dct(
n_mfcc: int,
n_mels: int,
norm: Optional[str]
) -> Tensor:
def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> Tensor:
r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
normalized depending on norm.
......@@ -605,10 +583,7 @@ def create_dct(
return dct.t()
def mu_law_encoding(
x: Tensor,
quantization_channels: int
) -> Tensor:
def mu_law_encoding(x: Tensor, quantization_channels: int) -> Tensor:
r"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......@@ -624,8 +599,10 @@ def mu_law_encoding(
"""
mu = quantization_channels - 1.0
if not x.is_floating_point():
warnings.warn("The input Tensor must be of floating type. \
This will be an error in the v0.12 release.")
warnings.warn(
"The input Tensor must be of floating type. \
This will be an error in the v0.12 release."
)
x = x.to(torch.float)
mu = torch.tensor(mu, dtype=x.dtype)
x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
......@@ -633,10 +610,7 @@ def mu_law_encoding(
return x_mu
def mu_law_decoding(
x_mu: Tensor,
quantization_channels: int
) -> Tensor:
def mu_law_decoding(x_mu: Tensor, quantization_channels: int) -> Tensor:
r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......@@ -659,11 +633,7 @@ def mu_law_decoding(
return x
def phase_vocoder(
complex_specgrams: Tensor,
rate: float,
phase_advance: Tensor
) -> Tensor:
def phase_vocoder(complex_specgrams: Tensor, rate: float, phase_advance: Tensor) -> Tensor:
r"""Given a STFT tensor, speed up in time without modifying pitch by a
factor of ``rate``.
......@@ -699,12 +669,7 @@ def phase_vocoder(
# Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
# Note torch.real is a view so it does not incur any memory copy.
real_dtype = torch.real(complex_specgrams).dtype
time_steps = torch.arange(
0,
complex_specgrams.size(-1),
rate,
device=complex_specgrams.device,
dtype=real_dtype)
time_steps = torch.arange(0, complex_specgrams.size(-1), rate, device=complex_specgrams.device, dtype=real_dtype)
alphas = time_steps % 1.0
phase_0 = complex_specgrams[..., :1].angle()
......@@ -739,12 +704,7 @@ def phase_vocoder(
return complex_specgrams_stretch
def mask_along_axis_iid(
specgrams: Tensor,
mask_param: int,
mask_value: float,
axis: int
) -> Tensor:
def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, axis: int) -> Tensor:
r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
......@@ -760,7 +720,7 @@ def mask_along_axis_iid(
"""
if axis not in [2, 3]:
raise ValueError('Only Frequency and Time masking are supported')
raise ValueError("Only Frequency and Time masking are supported")
device = specgrams.device
dtype = specgrams.dtype
......@@ -781,12 +741,7 @@ def mask_along_axis_iid(
return specgrams
def mask_along_axis(
specgram: Tensor,
mask_param: int,
mask_value: float,
axis: int
) -> Tensor:
def mask_along_axis(specgram: Tensor, mask_param: int, mask_value: float, axis: int) -> Tensor:
r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
......@@ -802,7 +757,7 @@ def mask_along_axis(
Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
"""
if axis not in [1, 2]:
raise ValueError('Only Frequency and Time masking are supported')
raise ValueError("Only Frequency and Time masking are supported")
# pack batch
shape = specgram.size()
......@@ -827,11 +782,7 @@ def mask_along_axis(
return specgram
def compute_deltas(
specgram: Tensor,
win_length: int = 5,
mode: str = "replicate"
) -> Tensor:
def compute_deltas(specgram: Tensor, win_length: int = 5, mode: str = "replicate") -> Tensor:
r"""Compute delta coefficients of a tensor, usually a spectrogram:
.. math::
......@@ -880,12 +831,7 @@ def compute_deltas(
return output
def _compute_nccf(
waveform: Tensor,
sample_rate: int,
frame_time: float,
freq_low: int
) -> Tensor:
def _compute_nccf(waveform: Tensor, sample_rate: int, frame_time: float, freq_low: int) -> Tensor:
r"""
Compute Normalized Cross-Correlation Function (NCCF).
......@@ -932,25 +878,17 @@ def _compute_nccf(
return nccf
def _combine_max(
a: Tuple[Tensor, Tensor],
b: Tuple[Tensor, Tensor],
thresh: float = 0.99
) -> Tuple[Tensor, Tensor]:
def _combine_max(a: Tuple[Tensor, Tensor], b: Tuple[Tensor, Tensor], thresh: float = 0.99) -> Tuple[Tensor, Tensor]:
"""
Take value from first if bigger than a multiplicative factor of the second, elementwise.
"""
mask = (a[0] > thresh * b[0])
mask = a[0] > thresh * b[0]
values = mask * a[0] + ~mask * b[0]
indices = mask * a[1] + ~mask * b[1]
return values, indices
def _find_max_per_frame(
nccf: Tensor,
sample_rate: int,
freq_high: int
) -> Tensor:
def _find_max_per_frame(nccf: Tensor, sample_rate: int, freq_high: int) -> Tensor:
r"""
For each frame, take the highest value of NCCF,
apply centered median smoothing, and convert to frequency.
......@@ -979,10 +917,7 @@ def _find_max_per_frame(
return indices
def _median_smoothing(
indices: Tensor,
win_length: int
) -> Tensor:
def _median_smoothing(indices: Tensor, win_length: int) -> Tensor:
r"""
Apply median smoothing to the 1D tensor over the given window.
"""
......@@ -991,9 +926,7 @@ def _median_smoothing(
pad_length = (win_length - 1) // 2
# "replicate" padding in any dimension
indices = torch.nn.functional.pad(
indices, (pad_length, 0), mode="constant", value=0.
)
indices = torch.nn.functional.pad(indices, (pad_length, 0), mode="constant", value=0.0)
indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
roll = indices.unfold(-1, win_length, 1)
......@@ -1003,12 +936,12 @@ def _median_smoothing(
def detect_pitch_frequency(
waveform: Tensor,
sample_rate: int,
frame_time: float = 10 ** (-2),
win_length: int = 30,
freq_low: int = 85,
freq_high: int = 3400,
waveform: Tensor,
sample_rate: int,
frame_time: float = 10 ** (-2),
win_length: int = 30,
freq_low: int = 85,
freq_high: int = 3400,
) -> Tensor:
r"""Detect pitch frequency.
......@@ -1075,8 +1008,7 @@ def sliding_window_cmn(
last_window_start = last_window_end = -1
cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cmn_specgram = torch.zeros(
num_channels, num_frames, num_feats, dtype=dtype, device=device)
cmn_specgram = torch.zeros(num_channels, num_frames, num_feats, dtype=dtype, device=device)
for t in range(num_frames):
window_start = 0
window_end = 0
......@@ -1093,12 +1025,12 @@ def sliding_window_cmn(
if window_end > t:
window_end = max(t + 1, min_cmn_window)
if window_end > num_frames:
window_start -= (window_end - num_frames)
window_start -= window_end - num_frames
window_end = num_frames
if window_start < 0:
window_start = 0
if last_window_start == -1:
input_part = specgram[:, window_start: window_end - window_start, :]
input_part = specgram[:, window_start : window_end - window_start, :]
cur_sum += torch.sum(input_part, 1)
if norm_vars:
cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
......@@ -1107,24 +1039,23 @@ def sliding_window_cmn(
frame_to_remove = specgram[:, last_window_start, :]
cur_sum -= frame_to_remove
if norm_vars:
cur_sumsq -= (frame_to_remove ** 2)
cur_sumsq -= frame_to_remove ** 2
if window_end > last_window_end:
frame_to_add = specgram[:, last_window_end, :]
cur_sum += frame_to_add
if norm_vars:
cur_sumsq += (frame_to_add ** 2)
cur_sumsq += frame_to_add ** 2
window_frames = window_end - window_start
last_window_start = window_start
last_window_end = window_end
cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
if norm_vars:
if window_frames == 1:
cmn_specgram[:, t, :] = torch.zeros(
num_channels, num_feats, dtype=dtype, device=device)
cmn_specgram[:, t, :] = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
else:
variance = cur_sumsq
variance = variance / window_frames
variance -= ((cur_sum ** 2) / (window_frames ** 2))
variance -= (cur_sum ** 2) / (window_frames ** 2)
variance = torch.pow(variance, -0.5)
cmn_specgram[:, t, :] *= variance
......@@ -1135,13 +1066,13 @@ def sliding_window_cmn(
def spectral_centroid(
waveform: Tensor,
sample_rate: int,
pad: int,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
waveform: Tensor,
sample_rate: int,
pad: int,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
) -> Tensor:
r"""
Compute the spectral centroid for each channel along the time axis.
......@@ -1161,10 +1092,17 @@ def spectral_centroid(
Returns:
Tensor: Dimension `(..., time)`
"""
specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, power=1., normalized=False)
freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2,
device=specgram.device).reshape((-1, 1))
specgram = spectrogram(
waveform,
pad=pad,
window=window,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
power=1.0,
normalized=False,
)
freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2, device=specgram.device).reshape((-1, 1))
freq_dim = -2
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
......@@ -1201,42 +1139,37 @@ def apply_codec(
If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
"""
bytes = io.BytesIO()
torchaudio.backend.sox_io_backend.save(bytes,
waveform,
sample_rate,
channels_first,
compression,
format,
encoding,
bits_per_sample
)
torchaudio.backend.sox_io_backend.save(
bytes, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
bytes.seek(0)
augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format)
bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format
)
return augmented
@_mod_utils.requires_kaldi()
def compute_kaldi_pitch(
waveform: torch.Tensor,
sample_rate: float,
frame_length: float = 25.0,
frame_shift: float = 10.0,
min_f0: float = 50,
max_f0: float = 400,
soft_min_f0: float = 10.0,
penalty_factor: float = 0.1,
lowpass_cutoff: float = 1000,
resample_frequency: float = 4000,
delta_pitch: float = 0.005,
nccf_ballast: float = 7000,
lowpass_filter_width: int = 1,
upsample_filter_width: int = 5,
max_frames_latency: int = 0,
frames_per_chunk: int = 0,
simulate_first_pass_online: bool = False,
recompute_frame: int = 500,
snip_edges: bool = True,
waveform: torch.Tensor,
sample_rate: float,
frame_length: float = 25.0,
frame_shift: float = 10.0,
min_f0: float = 50,
max_f0: float = 400,
soft_min_f0: float = 10.0,
penalty_factor: float = 0.1,
lowpass_cutoff: float = 1000,
resample_frequency: float = 4000,
delta_pitch: float = 0.005,
nccf_ballast: float = 7000,
lowpass_filter_width: int = 1,
upsample_filter_width: int = 5,
max_frames_latency: int = 0,
frames_per_chunk: int = 0,
simulate_first_pass_online: bool = False,
recompute_frame: int = 500,
snip_edges: bool = True,
) -> torch.Tensor:
"""Extract pitch based on method described in *A pitch extraction algorithm tuned
for automatic speech recognition* [:footcite:`6854049`].
......@@ -1302,11 +1235,24 @@ def compute_kaldi_pitch(
shape = waveform.shape
waveform = waveform.reshape(-1, shape[-1])
result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
waveform, sample_rate, frame_length, frame_shift,
min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff,
resample_frequency, delta_pitch, nccf_ballast,
lowpass_filter_width, upsample_filter_width, max_frames_latency,
frames_per_chunk, simulate_first_pass_online, recompute_frame,
waveform,
sample_rate,
frame_length,
frame_shift,
min_f0,
max_f0,
soft_min_f0,
penalty_factor,
lowpass_cutoff,
resample_frequency,
delta_pitch,
nccf_ballast,
lowpass_filter_width,
upsample_filter_width,
max_frames_latency,
frames_per_chunk,
simulate_first_pass_online,
recompute_frame,
snip_edges,
)
result = result.reshape(shape[:-1] + result.shape[-2:])
......@@ -1314,15 +1260,16 @@ def compute_kaldi_pitch(
def _get_sinc_resample_kernel(
orig_freq: int,
new_freq: int,
gcd: int,
lowpass_filter_width: int,
rolloff: float,
resampling_method: str,
beta: Optional[float],
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None):
orig_freq: int,
new_freq: int,
gcd: int,
lowpass_filter_width: int,
rolloff: float,
resampling_method: str,
beta: Optional[float],
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None,
):
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
raise Exception(
......@@ -1334,8 +1281,8 @@ def _get_sinc_resample_kernel(
"For more information, please refer to https://github.com/pytorch/audio/issues/1487."
)
if resampling_method not in ['sinc_interpolation', 'kaiser_window']:
raise ValueError('Invalid resampling method: {}'.format(resampling_method))
if resampling_method not in ["sinc_interpolation", "kaiser_window"]:
raise ValueError("Invalid resampling method: {}".format(resampling_method))
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd
......@@ -1381,7 +1328,7 @@ def _get_sinc_resample_kernel(
# we do not use built in torch windows here as we need to evaluate the window
# at specific positions, not over a regular grid.
if resampling_method == "sinc_interpolation":
window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2
window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
else:
# kaiser_window
if beta is None:
......@@ -1389,7 +1336,7 @@ def _get_sinc_resample_kernel(
beta_tensor = torch.tensor(float(beta))
window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
t *= math.pi
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t)
kernel.mul_(window)
kernels.append(kernel)
......@@ -1401,12 +1348,12 @@ def _get_sinc_resample_kernel(
def _apply_sinc_resample_kernel(
waveform: Tensor,
orig_freq: int,
new_freq: int,
gcd: int,
kernel: Tensor,
width: int,
waveform: Tensor,
orig_freq: int,
new_freq: int,
gcd: int,
kernel: Tensor,
width: int,
):
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd
......@@ -1428,13 +1375,13 @@ def _apply_sinc_resample_kernel(
def resample(
waveform: Tensor,
orig_freq: int,
new_freq: int,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
beta: Optional[float] = None,
waveform: Tensor,
orig_freq: int,
new_freq: int,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
beta: Optional[float] = None,
) -> Tensor:
r"""Resamples the waveform at the new frequency using bandlimited interpolation.
......@@ -1467,8 +1414,17 @@ def resample(
gcd = math.gcd(int(orig_freq), int(new_freq))
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff,
resampling_method, beta, waveform.device, waveform.dtype)
kernel, width = _get_sinc_resample_kernel(
orig_freq,
new_freq,
gcd,
lowpass_filter_width,
rolloff,
resampling_method,
beta,
waveform.device,
waveform.dtype,
)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled
......@@ -1557,25 +1513,24 @@ def pitch_shift(
ori_len = shape[-1]
rate = 2.0 ** (-float(n_steps) / bins_per_octave)
spec_f = torch.stft(input=waveform,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
center=True,
pad_mode='reflect',
normalized=False,
onesided=True,
return_complex=True)
spec_f = torch.stft(
input=waveform,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
center=True,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
phase_advance = torch.linspace(0, math.pi * hop_length, spec_f.shape[-2], device=spec_f.device)[..., None]
spec_stretch = phase_vocoder(spec_f, rate, phase_advance)
len_stretch = int(round(ori_len / rate))
waveform_stretch = torch.istft(spec_stretch,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
length=len_stretch)
waveform_stretch = torch.istft(
spec_stretch, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=len_stretch
)
waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
shift_len = waveform_shift.size()[-1]
if shift_len > ori_len:
......@@ -1617,7 +1572,7 @@ def rnnt_loss(
Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size `(batch)`,
otherwise scalar.
"""
if reduction not in ['none', 'mean', 'sum']:
if reduction not in ["none", "mean", "sum"]:
raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
if blank < 0: # reinterpret blank index if blank < 0.
......@@ -1632,9 +1587,9 @@ def rnnt_loss(
clamp=clamp,
)
if reduction == 'mean':
if reduction == "mean":
return costs.mean()
elif reduction == 'sum':
elif reduction == "sum":
return costs.sum()
return costs
......@@ -7,23 +7,23 @@ import torch
from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils
if _mod_utils.is_module_available('kaldi_io', 'numpy'):
import numpy as np
if _mod_utils.is_module_available("kaldi_io", "numpy"):
import kaldi_io
import numpy as np
__all__ = [
'read_vec_int_ark',
'read_vec_flt_scp',
'read_vec_flt_ark',
'read_mat_scp',
'read_mat_ark',
"read_vec_int_ark",
"read_vec_flt_scp",
"read_vec_flt_ark",
"read_mat_scp",
"read_mat_ark",
]
def _convert_method_output_to_tensor(file_or_fd: Any,
fn: Callable,
convert_contiguous: bool = False) -> Iterable[Tuple[str, Tensor]]:
def _convert_method_output_to_tensor(
file_or_fd: Any, fn: Callable, convert_contiguous: bool = False
) -> Iterable[Tuple[str, Tensor]]:
r"""Takes a method invokes it. The output is converted to a tensor.
Args:
......@@ -42,7 +42,7 @@ def _convert_method_output_to_tensor(file_or_fd: Any,
yield key, torch.from_numpy(np_arr)
@_mod_utils.requires_module('kaldi_io', 'numpy')
@_mod_utils.requires_module("kaldi_io", "numpy")
def read_vec_int_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,vector<int>) tuples, which reads from the ark file/stream.
......@@ -62,7 +62,7 @@ def read_vec_int_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_int_ark, convert_contiguous=True)
@_mod_utils.requires_module('kaldi_io', 'numpy')
@_mod_utils.requires_module("kaldi_io", "numpy")
def read_vec_flt_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,vector<float32/float64>) tuples, read according to Kaldi scp.
......@@ -79,7 +79,7 @@ def read_vec_flt_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_scp)
@_mod_utils.requires_module('kaldi_io', 'numpy')
@_mod_utils.requires_module("kaldi_io", "numpy")
def read_vec_flt_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,vector<float32/float64>) tuples, which reads from the ark file/stream.
......@@ -96,7 +96,7 @@ def read_vec_flt_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_ark)
@_mod_utils.requires_module('kaldi_io', 'numpy')
@_mod_utils.requires_module("kaldi_io", "numpy")
def read_mat_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,matrix<float32/float64>) tuples, read according to Kaldi scp.
......@@ -113,7 +113,7 @@ def read_mat_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_mat_scp)
@_mod_utils.requires_module('kaldi_io', 'numpy')
@_mod_utils.requires_module("kaldi_io", "numpy")
def read_mat_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,matrix<float32/float64>) tuples, which reads from the ark file/stream.
......
from .wav2letter import Wav2Letter
from .wavernn import WaveRNN
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
from .tacotron2 import Tacotron2
from .wav2letter import Wav2Letter
from .wav2vec2 import (
Wav2Vec2Model,
wav2vec2_model,
......@@ -13,19 +12,20 @@ from .wav2vec2 import (
hubert_large,
hubert_xlarge,
)
from .wavernn import WaveRNN
__all__ = [
'Wav2Letter',
'WaveRNN',
'ConvTasNet',
'DeepSpeech',
'Wav2Vec2Model',
'wav2vec2_model',
'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
'hubert_base',
'hubert_large',
'hubert_xlarge',
'Tacotron2',
"Wav2Letter",
"WaveRNN",
"ConvTasNet",
"DeepSpeech",
"Wav2Vec2Model",
"wav2vec2_model",
"wav2vec2_base",
"wav2vec2_large",
"wav2vec2_large_lv60k",
"hubert_base",
"hubert_large",
"hubert_xlarge",
"Tacotron2",
]
......@@ -35,9 +35,7 @@ class ConvBlock(torch.nn.Module):
super().__init__()
self.conv_layers = torch.nn.Sequential(
torch.nn.Conv1d(
in_channels=io_channels, out_channels=hidden_channels, kernel_size=1
),
torch.nn.Conv1d(in_channels=io_channels, out_channels=hidden_channels, kernel_size=1),
torch.nn.PReLU(),
torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
torch.nn.Conv1d(
......@@ -55,17 +53,11 @@ class ConvBlock(torch.nn.Module):
self.res_out = (
None
if no_residual
else torch.nn.Conv1d(
in_channels=hidden_channels, out_channels=io_channels, kernel_size=1
)
)
self.skip_out = torch.nn.Conv1d(
in_channels=hidden_channels, out_channels=io_channels, kernel_size=1
else torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
)
self.skip_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
def forward(
self, input: torch.Tensor
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
def forward(self, input: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
feature = self.conv_layers(input)
if self.res_out is None:
residual = None
......@@ -110,12 +102,8 @@ class MaskGenerator(torch.nn.Module):
self.input_dim = input_dim
self.num_sources = num_sources
self.input_norm = torch.nn.GroupNorm(
num_groups=1, num_channels=input_dim, eps=1e-8
)
self.input_conv = torch.nn.Conv1d(
in_channels=input_dim, out_channels=num_feats, kernel_size=1
)
self.input_norm = torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8)
self.input_conv = torch.nn.Conv1d(in_channels=input_dim, out_channels=num_feats, kernel_size=1)
self.receptive_field = 0
self.conv_layers = torch.nn.ModuleList([])
......@@ -133,12 +121,12 @@ class MaskGenerator(torch.nn.Module):
no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)),
)
)
self.receptive_field += (
kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi
)
self.receptive_field += kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi
self.output_prelu = torch.nn.PReLU()
self.output_conv = torch.nn.Conv1d(
in_channels=num_feats, out_channels=input_dim * num_sources, kernel_size=1,
in_channels=num_feats,
out_channels=input_dim * num_sources,
kernel_size=1,
)
if msk_activate == "sigmoid":
self.mask_activate = torch.nn.Sigmoid()
......@@ -239,9 +227,7 @@ class ConvTasNet(torch.nn.Module):
bias=False,
)
def _align_num_frames_with_strides(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, int]:
def _align_num_frames_with_strides(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]:
"""Pad input Tensor so that the end of the input tensor corresponds with
1. (if kernel size is odd) the center of the last convolution kernel
......@@ -294,9 +280,7 @@ class ConvTasNet(torch.nn.Module):
Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
"""
if input.ndim != 3 or input.shape[1] != 1:
raise ValueError(
f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}"
)
raise ValueError(f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}")
# B: batch size
# L: input frame length
......@@ -309,13 +293,9 @@ class ConvTasNet(torch.nn.Module):
batch_size, num_padded_frames = padded.shape[0], padded.shape[2]
feats = self.encoder(padded) # B, F, M
masked = self.mask_generator(feats) * feats.unsqueeze(1) # B, S, F, M
masked = masked.view(
batch_size * self.num_sources, self.enc_num_feats, -1
) # B*S, F, M
masked = masked.view(batch_size * self.num_sources, self.enc_num_feats, -1) # B*S, F, M
decoded = self.decoder(masked) # B*S, 1, L'
output = decoded.view(
batch_size, self.num_sources, num_padded_frames
) # B, S, L'
output = decoded.view(batch_size, self.num_sources, num_padded_frames) # B, S, L'
if num_pads > 0:
output = output[..., :-num_pads] # B, S, L
return output
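For context, a minimal sketch of driving the assembled separator end to end; the batch size, frame count, and ``num_sources`` value below are illustrative, not taken from the diff:

    import torch
    from torchaudio.models import ConvTasNet

    model = ConvTasNet(num_sources=2)
    mixture = torch.randn(3, 1, 32000)   # (batch, channel==1, frames), as forward() requires
    separated = model(mixture)           # (batch, num_sources, frames)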
......@@ -10,11 +10,7 @@ class FullyConnected(torch.nn.Module):
n_hidden: Internal hidden unit size.
"""
def __init__(self,
n_feature: int,
n_hidden: int,
dropout: float,
relu_max_clip: int = 20) -> None:
def __init__(self, n_feature: int, n_hidden: int, dropout: float, relu_max_clip: int = 20) -> None:
super(FullyConnected, self).__init__()
self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
self.relu_max_clip = relu_max_clip
......@@ -52,9 +48,7 @@ class DeepSpeech(torch.nn.Module):
self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
self.bi_rnn = torch.nn.RNN(
n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True
)
self.bi_rnn = torch.nn.RNN(n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True)
self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
self.out = torch.nn.Linear(n_hidden, n_class)
......@@ -78,7 +72,7 @@ class DeepSpeech(torch.nn.Module):
# T x N x H
x, _ = self.bi_rnn(x)
# The fifth (non-recurrent) layer takes both the forward and backward units as inputs
x = x[:, :, :self.n_hidden] + x[:, :, self.n_hidden:]
x = x[:, :, : self.n_hidden] + x[:, :, self.n_hidden :]
# T x N x H
x = self.fc4(x)
# T x N x H
......
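A minimal sketch of running DeepSpeech, assuming the (batch, channel, time, feature) input layout used by the model's forward(); all sizes are placeholders:

    import torch
    from torchaudio.models import DeepSpeech

    model = DeepSpeech(n_feature=64, n_class=40)
    x = torch.randn(4, 1, 100, 64)   # (batch, channel, time, feature)
    log_probs = model(x)             # per-frame log-probabilities over the 40 classes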
......@@ -29,8 +29,8 @@ import warnings
from typing import Tuple, List, Optional, Union
import torch
from torch import nn
from torch import Tensor
from torch import nn
from torch.nn import functional as F
......@@ -39,9 +39,7 @@ __all__ = [
]
def _get_linear_layer(
in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear"
) -> torch.nn.Linear:
def _get_linear_layer(in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear") -> torch.nn.Linear:
r"""Linear layer with xavier uniform initialization.
Args:
......@@ -55,9 +53,7 @@ def _get_linear_layer(
(torch.nn.Linear): The corresponding linear layer.
"""
linear = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_(
linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
)
torch.nn.init.xavier_uniform_(linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
return linear
......@@ -101,9 +97,7 @@ def _get_conv1d_layer(
bias=bias,
)
torch.nn.init.xavier_uniform_(
conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
)
torch.nn.init.xavier_uniform_(conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
return conv1d
......@@ -194,9 +188,7 @@ class _Attention(nn.Module):
attention_location_kernel_size: int,
) -> None:
super().__init__()
self.query_layer = _get_linear_layer(
attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh"
)
self.query_layer = _get_linear_layer(attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh")
self.memory_layer = _get_linear_layer(
encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh"
)
......@@ -208,9 +200,7 @@ class _Attention(nn.Module):
)
self.score_mask_value = -float("inf")
def _get_alignment_energies(
self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor
) -> Tensor:
def _get_alignment_energies(self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor) -> Tensor:
r"""Get the alignment vector.
Args:
......@@ -226,9 +216,7 @@ class _Attention(nn.Module):
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v(
torch.tanh(processed_query + processed_attention_weights + processed_memory)
)
energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))
alignment = energies.squeeze(2)
return alignment
......@@ -256,9 +244,7 @@ class _Attention(nn.Module):
attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
"""
alignment = self._get_alignment_energies(
attention_hidden_state, processed_memory, attention_weights_cat
)
alignment = self._get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat)
alignment = alignment.masked_fill(mask, self.score_mask_value)
......@@ -281,10 +267,7 @@ class _Prenet(nn.Module):
super().__init__()
in_sizes = [in_dim] + out_sizes[:-1]
self.layers = nn.ModuleList(
[
_get_linear_layer(in_size, out_size, bias=False)
for (in_size, out_size) in zip(in_sizes, out_sizes)
]
[_get_linear_layer(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_sizes, out_sizes)]
)
def forward(self, x: Tensor) -> Tensor:
......@@ -488,9 +471,7 @@ class _Decoder(nn.Module):
self.prenet = _Prenet(n_mels * n_frames_per_step, [prenet_dim, prenet_dim])
self.attention_rnn = nn.LSTMCell(
prenet_dim + encoder_embedding_dim, attention_rnn_dim
)
self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)
self.attention_layer = _Attention(
attention_rnn_dim,
......@@ -500,13 +481,9 @@ class _Decoder(nn.Module):
attention_location_kernel_size,
)
self.decoder_rnn = nn.LSTMCell(
attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True
)
self.decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True)
self.linear_projection = _get_linear_layer(
decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step
)
self.linear_projection = _get_linear_layer(decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step)
self.gate_layer = _get_linear_layer(
decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid"
......@@ -526,9 +503,7 @@ class _Decoder(nn.Module):
n_batch = memory.size(0)
dtype = memory.dtype
device = memory.device
decoder_input = torch.zeros(
n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device
)
decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
return decoder_input
def _initialize_decoder_states(
......@@ -557,27 +532,15 @@ class _Decoder(nn.Module):
dtype = memory.dtype
device = memory.device
attention_hidden = torch.zeros(
n_batch, self.attention_rnn_dim, dtype=dtype, device=device
)
attention_cell = torch.zeros(
n_batch, self.attention_rnn_dim, dtype=dtype, device=device
)
attention_hidden = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
attention_cell = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
decoder_hidden = torch.zeros(
n_batch, self.decoder_rnn_dim, dtype=dtype, device=device
)
decoder_cell = torch.zeros(
n_batch, self.decoder_rnn_dim, dtype=dtype, device=device
)
decoder_hidden = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
decoder_cell = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
attention_weights_cum = torch.zeros(
n_batch, max_time, dtype=dtype, device=device
)
attention_context = torch.zeros(
n_batch, self.encoder_embedding_dim, dtype=dtype, device=device
)
attention_weights_cum = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
attention_context = torch.zeros(n_batch, self.encoder_embedding_dim, dtype=dtype, device=device)
processed_memory = self.attention_layer.memory_layer(memory)
......@@ -688,16 +651,10 @@ class _Decoder(nn.Module):
"""
cell_input = torch.cat((decoder_input, attention_context), -1)
attention_hidden, attention_cell = self.attention_rnn(
cell_input, (attention_hidden, attention_cell)
)
attention_hidden = F.dropout(
attention_hidden, self.attention_dropout, self.training
)
attention_hidden, attention_cell = self.attention_rnn(cell_input, (attention_hidden, attention_cell))
attention_hidden = F.dropout(attention_hidden, self.attention_dropout, self.training)
attention_weights_cat = torch.cat(
(attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1
)
attention_weights_cat = torch.cat((attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1)
attention_context, attention_weights = self.attention_layer(
attention_hidden, memory, processed_memory, attention_weights_cat, mask
)
......@@ -705,14 +662,10 @@ class _Decoder(nn.Module):
attention_weights_cum += attention_weights
decoder_input = torch.cat((attention_hidden, attention_context), -1)
decoder_hidden, decoder_cell = self.decoder_rnn(
decoder_input, (decoder_hidden, decoder_cell)
)
decoder_hidden, decoder_cell = self.decoder_rnn(decoder_input, (decoder_hidden, decoder_cell))
decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training)
decoder_hidden_attention_context = torch.cat(
(decoder_hidden, attention_context), dim=1
)
decoder_hidden_attention_context = torch.cat((decoder_hidden, attention_context), dim=1)
decoder_output = self.linear_projection(decoder_hidden_attention_context)
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
......@@ -819,15 +772,11 @@ class _Decoder(nn.Module):
n_batch = memory.size(0)
dtype = memory.dtype
device = memory.device
decoder_input = torch.zeros(
n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device
)
decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
return decoder_input
@torch.jit.export
def infer(self,
memory: Tensor,
memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Decoder inference
Args:
......@@ -905,16 +854,14 @@ class _Decoder(nn.Module):
if len(mel_specgrams) == self.decoder_max_step:
warnings.warn(
"Reached max decoder steps. The generated spectrogram might not cover "
"the whole transcript.")
"Reached max decoder steps. The generated spectrogram might not cover " "the whole transcript."
)
mel_specgrams = torch.cat(mel_specgrams, dim=0)
gate_outputs = torch.cat(gate_outputs, dim=0)
alignments = torch.cat(alignments, dim=0)
mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(
mel_specgrams, gate_outputs, alignments
)
mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(mel_specgrams, gate_outputs, alignments)
return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
......@@ -984,9 +931,7 @@ class Tacotron2(nn.Module):
self.n_frames_per_step = n_frames_per_step
self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim)
torch.nn.init.xavier_uniform_(self.embedding.weight)
self.encoder = _Encoder(
encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size
)
self.encoder = _Encoder(encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size)
self.decoder = _Decoder(
n_mels,
n_frames_per_step,
......@@ -1003,9 +948,7 @@ class Tacotron2(nn.Module):
prenet_dim,
gate_threshold,
)
self.postnet = _Postnet(
n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution
)
self.postnet = _Postnet(n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution)
def forward(
self,
......@@ -1094,9 +1037,7 @@ class Tacotron2(nn.Module):
embedded_inputs = self.embedding(tokens).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, lengths)
mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(
encoder_outputs, lengths
)
mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(encoder_outputs, lengths)
mel_outputs_postnet = self.postnet(mel_specgram)
mel_outputs_postnet = mel_specgram + mel_outputs_postnet
......
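A hedged inference sketch for the assembled Tacotron2 model; the token ids are dummies and 148 is assumed to be the default ``n_symbol``:

    import torch
    from torchaudio.models import Tacotron2

    model = Tacotron2().eval()
    tokens = torch.randint(0, 148, (1, 50))   # dummy symbol ids
    lengths = torch.tensor([50])
    with torch.no_grad():
        mel, mel_lengths, alignments = model.infer(tokens, lengths)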
......@@ -19,9 +19,7 @@ class Wav2Letter(nn.Module):
num_features (int, optional): Number of input features that the network will receive (Default: ``1``).
"""
def __init__(self, num_classes: int = 40,
input_type: str = "waveform",
num_features: int = 1) -> None:
def __init__(self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1) -> None:
super(Wav2Letter, self).__init__()
acoustic_num_features = 250 if input_type == "waveform" else num_features
......@@ -47,13 +45,13 @@ class Wav2Letter(nn.Module):
nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True)
nn.ReLU(inplace=True),
)
if input_type == "waveform":
waveform_model = nn.Sequential(
nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45),
nn.ReLU(inplace=True)
nn.ReLU(inplace=True),
)
self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)
......
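A minimal sketch of the waveform variant of Wav2Letter; the shapes are illustrative only:

    import torch
    from torchaudio.models import Wav2Letter

    model = Wav2Letter(num_classes=40, input_type="waveform", num_features=1)
    waveform = torch.randn(2, 1, 16000)   # (batch, num_features, time)
    out = model(waveform)                 # (batch, num_classes, reduced time)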
from . import utils
from .model import (
Wav2Vec2Model,
wav2vec2_model,
......@@ -8,16 +9,15 @@ from .model import (
hubert_large,
hubert_xlarge,
)
from . import utils
__all__ = [
'Wav2Vec2Model',
'wav2vec2_model',
'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
'hubert_base',
'hubert_large',
'hubert_xlarge',
'utils',
"Wav2Vec2Model",
"wav2vec2_model",
"wav2vec2_base",
"wav2vec2_large",
"wav2vec2_large_lv60k",
"hubert_base",
"hubert_large",
"hubert_xlarge",
"utils",
]
......@@ -10,24 +10,25 @@ _LG = logging.getLogger(__name__)
class LayerNorm(nn.LayerNorm):
"""Layer norm with transpose"""
def forward(self, input: Tensor) -> Tensor:
x = input.transpose(-2, -1)
x = nn.functional.layer_norm(
x, self.normalized_shape, self.weight, self.bias, self.eps)
x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.transpose(-2, -1)
return x
class ConvLayerBlock(Module):
"""Convolution unit of FeatureExtractor"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
bias: bool,
layer_norm: Optional[Module],
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
bias: bool,
layer_norm: Optional[Module],
):
super().__init__()
self.kernel_size = kernel_size
......@@ -42,9 +43,9 @@ class ConvLayerBlock(Module):
)
def forward(
self,
x: Tensor,
length: Optional[Tensor],
self,
x: Tensor,
length: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
......@@ -60,7 +61,7 @@ class ConvLayerBlock(Module):
x = nn.functional.gelu(x)
if length is not None:
length = torch.div(length - self.kernel_size, self.stride, rounding_mode='floor') + 1
length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
# When input length is 0, the resulting length can be negative. So fix it here.
length = torch.max(torch.zeros_like(length), length)
return x, length
......@@ -73,17 +74,18 @@ class FeatureExtractor(Module):
conv_layers (nn.ModuleList):
convolution layers
"""
def __init__(
self,
conv_layers: nn.ModuleList,
self,
conv_layers: nn.ModuleList,
):
super().__init__()
self.conv_layers = conv_layers
def forward(
self,
x: Tensor,
length: Optional[Tensor],
self,
x: Tensor,
length: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
......@@ -100,9 +102,7 @@ class FeatureExtractor(Module):
Valid length of each output sample. shape: ``[batch, ]``.
"""
if x.ndim != 2:
raise ValueError(
"Expected the input Tensor to be 2D (batch, time), "
"but received {list(x.shape)}")
raise ValueError("Expected the input Tensor to be 2D (batch, time), " "but received {list(x.shape)}")
x = x.unsqueeze(1) # (batch, channel==1, frame)
for layer in self.conv_layers:
......@@ -121,15 +121,19 @@ class FeatureProjection(Module):
out_features (int): Output feature dim.
dropout (float): Dropout probability.
"""
def __init__(
self,
in_features: int,
out_features: int,
dropout: float,
self,
in_features: int,
out_features: int,
dropout: float,
):
super().__init__()
self.layer_norm = nn.LayerNorm(in_features)
self.projection = nn.Linear(in_features, out_features,)
self.projection = nn.Linear(
in_features,
out_features,
)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
......@@ -154,11 +158,12 @@ class ConvolutionalPositionalEmbedding(Module):
        kernel_size (int): The number of frames to be used.
groups (int): The number of groups in feature dimensions.
"""
def __init__(
self,
embed_dim: int,
kernel_size: int,
groups: int,
self,
embed_dim: int,
kernel_size: int,
groups: int,
):
super().__init__()
self.embed_dim = embed_dim
......@@ -178,11 +183,8 @@ class ConvolutionalPositionalEmbedding(Module):
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == 'torch.nn.utils.weight_norm' and
hook.__class__.__name__ == 'WeightNorm'
):
_LG.warning('Removing weight_norm from %s', self.__class__.__name__)
if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
_LG.warning("Removing weight_norm from %s", self.__class__.__name__)
torch.nn.utils.remove_weight_norm(self.conv)
return self
......@@ -197,7 +199,7 @@ class ConvolutionalPositionalEmbedding(Module):
x = x.transpose(-2, -1)
x = self.conv(x)
if self.num_remove > 0:
x = x[..., :-self.num_remove]
x = x[..., : -self.num_remove]
x = torch.nn.functional.gelu(x)
x = x.transpose(-2, -1)
return x
......@@ -212,11 +214,12 @@ class SelfAttention(Module):
dropout (float, optional):
            Dropout probability on attn_output_weights. Default: ``0.0``
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
):
super().__init__()
head_dim = embed_dim // num_heads
......@@ -236,9 +239,9 @@ class SelfAttention(Module):
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
def forward(
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
......@@ -251,17 +254,13 @@ class SelfAttention(Module):
"""
if x.ndim != 3 or x.shape[2] != self.embed_dim:
raise ValueError(
f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). "
f"Found {x.shape}."
f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}."
)
batch_size, length, embed_dim = x.size()
if attention_mask is not None:
shape_ = (batch_size, 1, length, length)
if attention_mask.size() != shape_:
raise ValueError(
f"The expected attention mask shape is {shape_}. "
f"Found {attention_mask.size()}."
)
raise ValueError(f"The expected attention mask shape is {shape_}. " f"Found {attention_mask.size()}.")
shape = (batch_size, length, self.num_heads, self.head_dim)
q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
......@@ -283,14 +282,14 @@ class SelfAttention(Module):
class FeedForward(Module):
"""Layer that follows attention layer in encoder layer.
"""
"""Layer that follows attention layer in encoder layer."""
def __init__(
self,
io_features: int,
intermediate_features: int,
intermediate_dropout: float,
output_dropout: float,
self,
io_features: int,
intermediate_features: int,
intermediate_dropout: float,
output_dropout: float,
):
super().__init__()
self.intermediate_dense = nn.Linear(io_features, intermediate_features)
......@@ -315,14 +314,14 @@ class FeedForward(Module):
class EncoderLayer(Module):
"""A layer unit in encoder. Combines multihead self attention and feed forward.
"""
"""A layer unit in encoder. Combines multihead self attention and feed forward."""
def __init__(
self,
attention: Module,
dropout: float,
layer_norm_first: bool,
feed_forward: Module,
self,
attention: Module,
dropout: float,
layer_norm_first: bool,
feed_forward: Module,
):
super().__init__()
self.attention = attention
......@@ -333,9 +332,9 @@ class EncoderLayer(Module):
self.final_layer_norm = nn.LayerNorm(attention.embed_dim)
def forward(
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
):
"""
Args:
......@@ -362,12 +361,12 @@ class EncoderLayer(Module):
class Transformer(Module):
def __init__(
self,
pos_conv_embed: Module,
dropout: float,
layers: Module,
layer_norm_first: bool,
layer_drop: float,
self,
pos_conv_embed: Module,
dropout: float,
layers: Module,
layer_norm_first: bool,
layer_drop: float,
):
super().__init__()
self.pos_conv_embed = pos_conv_embed
......@@ -387,9 +386,9 @@ class Transformer(Module):
return x
def forward(
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
):
x = self._preprocess(x)
for layer in self.layers:
......@@ -402,14 +401,14 @@ class Transformer(Module):
return x
def get_intermediate_outputs(
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
num_layers: Optional[int] = None,
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
num_layers: Optional[int] = None,
) -> List[Tensor]:
if num_layers is not None:
if not 0 < num_layers <= len(self.layers):
raise ValueError(f'`num_layers` must be between [1, {len(self.layers)}]')
raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")
ret: List[Tensor] = []
x = self._preprocess(x)
......@@ -423,18 +422,18 @@ class Transformer(Module):
class Encoder(Module):
def __init__(
self,
feature_projection: Module,
transformer: Module,
self,
feature_projection: Module,
transformer: Module,
):
super().__init__()
self.feature_projection = feature_projection
self.transformer = transformer
def _preprocess(
self,
features: Tensor,
lengths: Optional[Tensor] = None,
self,
features: Tensor,
lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
x = self.feature_projection(features)
......@@ -450,30 +449,29 @@ class Encoder(Module):
return x, mask
def forward(
self,
features: Tensor,
lengths: Optional[Tensor] = None,
self,
features: Tensor,
lengths: Optional[Tensor] = None,
) -> Tensor:
x, mask = self._preprocess(features, lengths)
x = self.transformer(x, attention_mask=mask)
return x
def extract_features(
self,
features: Tensor,
lengths: Optional[Tensor] = None,
num_layers: Optional[int] = None,
self,
features: Tensor,
lengths: Optional[Tensor] = None,
num_layers: Optional[int] = None,
) -> List[Tensor]:
x, masks = self._preprocess(features, lengths)
return self.transformer.get_intermediate_outputs(
x, attention_mask=masks, num_layers=num_layers)
return self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)
################################################################################
def _get_feature_extractor(
norm_mode: str,
shapes: List[Tuple[int, int, int]],
bias: bool,
norm_mode: str,
shapes: List[Tuple[int, int, int]],
bias: bool,
) -> FeatureExtractor:
"""
Args:
......@@ -545,19 +543,19 @@ def _get_feature_extractor(
def _get_encoder(
in_features: int,
embed_dim: int,
dropout_input: float,
pos_conv_kernel: int,
pos_conv_groups: int,
num_layers: int,
num_heads: int,
attention_dropout: float,
ff_interm_features: int,
ff_interm_dropout: float,
dropout: float,
layer_norm_first: bool,
layer_drop: float,
in_features: int,
embed_dim: int,
dropout_input: float,
pos_conv_kernel: int,
pos_conv_groups: int,
num_layers: int,
num_heads: int,
attention_dropout: float,
ff_interm_features: int,
ff_interm_dropout: float,
dropout: float,
layer_norm_first: bool,
layer_drop: float,
) -> Encoder:
"""
Args:
......
......@@ -26,11 +26,12 @@ class Wav2Vec2Model(Module):
aux (torch.nn.Module or None, optional):
Auxiliary module. If provided, the output from encoder is passed to this module.
""" # noqa: E501
def __init__(
self,
feature_extractor: Module,
encoder: Module,
aux: Optional[Module] = None,
self,
feature_extractor: Module,
encoder: Module,
aux: Optional[Module] = None,
):
super().__init__()
self.feature_extractor = feature_extractor
......@@ -39,10 +40,10 @@ class Wav2Vec2Model(Module):
@torch.jit.export
def extract_features(
self,
waveforms: Tensor,
lengths: Optional[Tensor] = None,
num_layers: Optional[int] = None,
self,
waveforms: Tensor,
lengths: Optional[Tensor] = None,
num_layers: Optional[int] = None,
) -> Tuple[List[Tensor], Optional[Tensor]]:
"""Extract feature vectors from raw waveforms
......@@ -81,9 +82,9 @@ class Wav2Vec2Model(Module):
return x, lengths
def forward(
self,
waveforms: Tensor,
lengths: Optional[Tensor] = None,
self,
waveforms: Tensor,
lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Compute the sequence of probability distribution over labels.
......@@ -117,22 +118,22 @@ class Wav2Vec2Model(Module):
def wav2vec2_model(
extractor_mode: str,
extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
extractor_conv_bias: bool,
encoder_embed_dim: int,
encoder_projection_dropout: float,
encoder_pos_conv_kernel: int,
encoder_pos_conv_groups: int,
encoder_num_layers: int,
encoder_num_heads: int,
encoder_attention_dropout: float,
encoder_ff_interm_features: int,
encoder_ff_interm_dropout: float,
encoder_dropout: float,
encoder_layer_norm_first: bool,
encoder_layer_drop: float,
aux_num_out: Optional[int],
extractor_mode: str,
extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
extractor_conv_bias: bool,
encoder_embed_dim: int,
encoder_projection_dropout: float,
encoder_pos_conv_kernel: int,
encoder_pos_conv_groups: int,
encoder_num_layers: int,
encoder_num_heads: int,
encoder_attention_dropout: float,
encoder_ff_interm_features: int,
encoder_ff_interm_dropout: float,
encoder_dropout: float,
encoder_layer_norm_first: bool,
encoder_layer_drop: float,
aux_num_out: Optional[int],
) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx
"""wav2vec2_model(extractor_mode: str, extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], extractor_conv_bias: bool, encoder_embed_dim: int, encoder_projection_dropout: float, encoder_pos_conv_kernel: int, encoder_pos_conv_groups: int, encoder_num_layers: int, encoder_num_heads: int, encoder_attention_dropout: float, encoder_ff_interm_features: int, encoder_ff_interm_dropout: float, encoder_dropout: float, encoder_layer_norm_first: bool, encoder_layer_drop: float, aux_num_out: Optional[int]) -> torchaudio.models.Wav2Vec2Model
......@@ -262,7 +263,8 @@ def wav2vec2_model(
extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
feature_extractor = components._get_feature_extractor(
extractor_mode, extractor_conv_layer_config, extractor_conv_bias)
extractor_mode, extractor_conv_layer_config, extractor_conv_bias
)
encoder = components._get_encoder(
in_features=extractor_conv_layer_config[-1][0],
embed_dim=encoder_embed_dim,
......@@ -285,12 +287,12 @@ def wav2vec2_model(
def wav2vec2_base(
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx
"""wav2vec2_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
......@@ -336,12 +338,12 @@ def wav2vec2_base(
def wav2vec2_large(
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx
"""wav2vec2_large(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
......@@ -387,12 +389,12 @@ def wav2vec2_large(
def wav2vec2_large_lv60k(
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx
"""wav2vec2_large_lv60k( encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
......@@ -438,12 +440,12 @@ def wav2vec2_large_lv60k(
def hubert_base(
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.05,
aux_num_out: Optional[int] = None,
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.05,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.05, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
......@@ -469,7 +471,7 @@ def hubert_base(
The resulting model.
""" # noqa: E501
return wav2vec2_model(
extractor_mode='group_norm',
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=768,
......@@ -489,12 +491,12 @@ def hubert_base(
def hubert_large(
encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_large(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
......@@ -520,7 +522,7 @@ def hubert_large(
The resulting model.
""" # noqa: E501
return wav2vec2_model(
extractor_mode='layer_norm',
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1024,
......@@ -540,12 +542,12 @@ def hubert_large(
def hubert_xlarge(
encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_xlarge(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
......@@ -571,7 +573,7 @@ def hubert_xlarge(
The resulting model.
""" # noqa: E501
return wav2vec2_model(
extractor_mode='layer_norm',
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1280,
......
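As a reference point for the factory functions above, a hedged usage sketch; the random waveform, lengths, and ``aux_num_out=32`` are placeholders, and a real checkpoint would normally be loaded through a pipeline bundle or an import utility:

    import torch
    from torchaudio.models import wav2vec2_base

    model = wav2vec2_base(aux_num_out=32)   # auxiliary head with 32 output labels, for illustration
    waveforms = torch.randn(2, 16000)       # (batch, time)
    lengths = torch.tensor([16000, 12000])
    emissions, out_lengths = model(waveforms, lengths)
    features, feat_lengths = model.extract_features(waveforms, lengths)  # list of per-layer Tensors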
from .import_huggingface import import_huggingface_model
from .import_fairseq import import_fairseq_model
from .import_huggingface import import_huggingface_model
__all__ = [
'import_huggingface_model',
'import_fairseq_model',
"import_huggingface_model",
"import_fairseq_model",
]
......@@ -13,11 +13,11 @@ def _parse_config(w2v_model):
encoder = w2v_model.encoder
conv_layers = w2v_model.feature_extractor.conv_layers
extractor_mode = 'layer_norm'
if 'GroupNorm' in conv_layers[0][2].__class__.__name__:
extractor_mode = 'group_norm'
extractor_mode = "layer_norm"
if "GroupNorm" in conv_layers[0][2].__class__.__name__:
extractor_mode = "group_norm"
else:
extractor_mode = 'layer_norm'
extractor_mode = "layer_norm"
conv_layer_config = [(l[0].out_channels, l[0].kernel_size[0], l[0].stride[0]) for l in conv_layers]
......@@ -26,53 +26,52 @@ def _parse_config(w2v_model):
elif all(l[0].bias is not None for l in conv_layers):
conv_bias = True
else:
raise ValueError(
'Either all the convolutions layers have bias term or none of them should.')
raise ValueError("Either all the convolutions layers have bias term or none of them should.")
config = {
'extractor_mode': extractor_mode,
'extractor_conv_layer_config': conv_layer_config,
'extractor_conv_bias': conv_bias,
'encoder_embed_dim': w2v_model.post_extract_proj.out_features,
'encoder_projection_dropout': w2v_model.dropout_input.p,
'encoder_pos_conv_kernel': encoder.pos_conv[0].kernel_size[0],
'encoder_pos_conv_groups': encoder.pos_conv[0].groups,
'encoder_num_layers': len(encoder.layers),
'encoder_num_heads': encoder.layers[0].self_attn.num_heads,
'encoder_attention_dropout': encoder.layers[0].self_attn.dropout_module.p,
'encoder_ff_interm_features': encoder.layers[0].fc1.out_features,
'encoder_ff_interm_dropout': encoder.layers[0].dropout2.p,
'encoder_dropout': encoder.layers[0].dropout3.p,
'encoder_layer_norm_first': encoder.layer_norm_first,
'encoder_layer_drop': encoder.layerdrop,
"extractor_mode": extractor_mode,
"extractor_conv_layer_config": conv_layer_config,
"extractor_conv_bias": conv_bias,
"encoder_embed_dim": w2v_model.post_extract_proj.out_features,
"encoder_projection_dropout": w2v_model.dropout_input.p,
"encoder_pos_conv_kernel": encoder.pos_conv[0].kernel_size[0],
"encoder_pos_conv_groups": encoder.pos_conv[0].groups,
"encoder_num_layers": len(encoder.layers),
"encoder_num_heads": encoder.layers[0].self_attn.num_heads,
"encoder_attention_dropout": encoder.layers[0].self_attn.dropout_module.p,
"encoder_ff_interm_features": encoder.layers[0].fc1.out_features,
"encoder_ff_interm_dropout": encoder.layers[0].dropout2.p,
"encoder_dropout": encoder.layers[0].dropout3.p,
"encoder_layer_norm_first": encoder.layer_norm_first,
"encoder_layer_drop": encoder.layerdrop,
}
return config
def _map_key(key):
key_ = key
if key.startswith('w2v_model.'):
key = key.replace('w2v_model.', '')
if re.match(r'(mask_emb|quantizer|project_q|final_proj|mask_emb)', key):
if key.startswith("w2v_model."):
key = key.replace("w2v_model.", "")
if re.match(r"(mask_emb|quantizer|project_q|final_proj|mask_emb)", key):
return None
# Feature Extractor
# Group norm when "extractor_mode" is "default".
# (Only the first layer)
# "conv_layers.0.2.weight" -> "conv_layers.0.layer_norm.weight"
# "conv_layers.0.2.bias" -> "conv_layers.0.layer_norm.bias"
match = re.match(r'feature_extractor\.conv_layers\.0\.2\.(weight|bias)', key)
match = re.match(r"feature_extractor\.conv_layers\.0\.2\.(weight|bias)", key)
if match:
return f"feature_extractor.conv_layers.0.layer_norm.{match.group(1)}"
# Convolutions
# "conv_layers.X.0.weight" -> "conv_layers.X.conv.weight"
# "conv_layers.X.0.bias" -> "conv_layers.X.conv.bias"
match = re.match(r'feature_extractor\.conv_layers\.(\d+)\.0\.(weight|bias)', key)
match = re.match(r"feature_extractor\.conv_layers\.(\d+)\.0\.(weight|bias)", key)
if match:
return f"feature_extractor.conv_layers.{match.group(1)}.conv.{match.group(2)}"
# Layer norm when "extractor_mode" is "layer_norm".
# "conv_layers.X.2.1.weight" -> "conv_layers.X.layer_norm.weight"
# "conv_layers.X.2.1.bias" -> "conv_layers.X.layer_norm.bias"
match = re.match(r'feature_extractor\.conv_layers\.(\d+)\.2\.1\.(weight|bias)', key)
match = re.match(r"feature_extractor\.conv_layers\.(\d+)\.2\.1\.(weight|bias)", key)
if match:
return f"feature_extractor.conv_layers.{match.group(1)}.layer_norm.{match.group(2)}"
match = re.match(r"post_extract_proj\.(weight|bias)", key)
......@@ -111,9 +110,9 @@ def _map_key(key):
if match:
return f"aux.{match.group(1)}"
# HuBERT Extension
if key in ['label_embs_concat']:
if key in ["label_embs_concat"]:
return key
raise ValueError(f'Unexpected key: {key_}')
raise ValueError(f"Unexpected key: {key_}")
def _convert_state_dict(state_dict):
......@@ -179,16 +178,15 @@ def import_fairseq_model(original: Module) -> Wav2Vec2Model:
.. _fairseq: https://github.com/pytorch/fairseq
"""
class_ = original.__class__.__name__
if class_ == 'Wav2Vec2Model':
if class_ == "Wav2Vec2Model":
return _import_wav2vec2_pretraining(original)
if class_ == 'Wav2VecEncoder':
if class_ == "Wav2VecEncoder":
return _import_wav2vec2_finetuning(original)
if class_ == 'HubertModel':
if class_ == "HubertModel":
return _import_hubert_pretraining(original)
if class_ == 'HubertEncoder':
if class_ == "HubertEncoder":
return _import_hubert_finetuning(original)
raise ValueError(
f'Expected an instance of `Wav2Vec2Model` or `Wav2VecEncoder`. Found: {class_}')
raise ValueError(f"Expected an instance of `Wav2Vec2Model` or `Wav2VecEncoder`. Found: {class_}")
def _import_wav2vec2_finetuning(original: Module) -> Wav2Vec2Model:
......
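A hedged sketch of the fairseq import path; the checkpoint path is a placeholder and fairseq must be installed:

    import fairseq
    from torchaudio.models.wav2vec2.utils import import_fairseq_model

    models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(["wav2vec_small.pt"])
    imported = import_fairseq_model(models[0])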
......@@ -11,40 +11,38 @@ _LG = logging.getLogger(__name__)
def _get_config(cfg):
config = {
'extractor_mode': f'{cfg.feat_extract_norm}_norm',
'extractor_conv_layer_config': list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
'extractor_conv_bias': cfg.conv_bias,
'encoder_embed_dim': cfg.hidden_size,
'encoder_projection_dropout': cfg.feat_proj_dropout,
'encoder_pos_conv_kernel': cfg.num_conv_pos_embeddings,
'encoder_pos_conv_groups': cfg.num_conv_pos_embedding_groups,
'encoder_num_layers': cfg.num_hidden_layers,
'encoder_num_heads': cfg.num_attention_heads,
'encoder_attention_dropout': cfg.attention_dropout,
'encoder_ff_interm_features': cfg.intermediate_size,
'encoder_ff_interm_dropout': cfg.activation_dropout,
'encoder_dropout': cfg.hidden_dropout,
'encoder_layer_norm_first': cfg.do_stable_layer_norm,
'encoder_layer_drop': cfg.layerdrop,
"extractor_mode": f"{cfg.feat_extract_norm}_norm",
"extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
"extractor_conv_bias": cfg.conv_bias,
"encoder_embed_dim": cfg.hidden_size,
"encoder_projection_dropout": cfg.feat_proj_dropout,
"encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings,
"encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups,
"encoder_num_layers": cfg.num_hidden_layers,
"encoder_num_heads": cfg.num_attention_heads,
"encoder_attention_dropout": cfg.attention_dropout,
"encoder_ff_interm_features": cfg.intermediate_size,
"encoder_ff_interm_dropout": cfg.activation_dropout,
"encoder_dropout": cfg.hidden_dropout,
"encoder_layer_norm_first": cfg.do_stable_layer_norm,
"encoder_layer_drop": cfg.layerdrop,
}
return config
def _build(config, original):
if original.__class__.__name__ == 'Wav2Vec2ForCTC':
if original.__class__.__name__ == "Wav2Vec2ForCTC":
aux_num_out = original.config.vocab_size
wav2vec2 = original.wav2vec2
else:
_LG.warning(
'The model is not an instance of Wav2Vec2ForCTC. '
'"lm_head" module is not imported.')
_LG.warning("The model is not an instance of Wav2Vec2ForCTC. " '"lm_head" module is not imported.')
aux_num_out = None
wav2vec2 = original
imported = wav2vec2_model(**config, aux_num_out=aux_num_out)
imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
if original.__class__.__name__ == 'Wav2Vec2ForCTC':
if original.__class__.__name__ == "Wav2Vec2ForCTC":
imported.aux.load_state_dict(original.lm_head.state_dict())
return imported
......@@ -71,10 +69,10 @@ def import_huggingface_model(original: Module) -> Wav2Vec2Model:
.. _Transformers: https://huggingface.co/transformers/
"""
_LG.info('Importing model.')
_LG.info('Loading model configuration.')
_LG.info("Importing model.")
_LG.info("Loading model configuration.")
config = _get_config(original.config)
_LG.debug(' - config: %s', config)
_LG.info('Building model.')
_LG.debug(" - config: %s", config)
_LG.info("Building model.")
imported = _build(config, original)
return imported
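Likewise, a hedged sketch of the Transformers import path; it requires the ``transformers`` package and the model id is illustrative:

    from transformers import Wav2Vec2ForCTC
    from torchaudio.models.wav2vec2.utils import import_huggingface_model

    original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    imported = import_huggingface_model(original)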
from typing import List, Tuple, Optional
import math
from typing import List, Tuple, Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
import torch.nn.functional as F
__all__ = [
"ResBlock",
......@@ -35,7 +35,7 @@ class ResBlock(nn.Module):
nn.BatchNorm1d(n_freq),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
nn.BatchNorm1d(n_freq)
nn.BatchNorm1d(n_freq),
)
def forward(self, specgram: Tensor) -> Tensor:
......@@ -66,12 +66,9 @@ class MelResNet(nn.Module):
>>> output = melresnet(input) # shape: (10, 128, 508)
"""
def __init__(self,
n_res_block: int = 10,
n_freq: int = 128,
n_hidden: int = 128,
n_output: int = 128,
kernel_size: int = 5) -> None:
def __init__(
self, n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5
) -> None:
super().__init__()
ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)]
......@@ -81,7 +78,7 @@ class MelResNet(nn.Module):
nn.BatchNorm1d(n_hidden),
nn.ReLU(inplace=True),
*ResBlocks,
nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1)
nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1),
)
def forward(self, specgram: Tensor) -> Tensor:
......@@ -110,9 +107,7 @@ class Stretch2d(nn.Module):
>>> output = stretch2d(input) # shape: (10, 500, 5120)
"""
def __init__(self,
time_scale: int,
freq_scale: int) -> None:
def __init__(self, time_scale: int, freq_scale: int) -> None:
super().__init__()
self.freq_scale = freq_scale
......@@ -148,13 +143,15 @@ class UpsampleNetwork(nn.Module):
>>> output = upsamplenetwork(input) # shape: (10, 1536, 128), (10, 1536, 128)
"""
def __init__(self,
upsample_scales: List[int],
n_res_block: int = 10,
n_freq: int = 128,
n_hidden: int = 128,
n_output: int = 128,
kernel_size: int = 5) -> None:
def __init__(
self,
upsample_scales: List[int],
n_res_block: int = 10,
n_freq: int = 128,
n_hidden: int = 128,
n_output: int = 128,
kernel_size: int = 5,
) -> None:
super().__init__()
total_scale = 1
......@@ -169,12 +166,10 @@ class UpsampleNetwork(nn.Module):
up_layers = []
for scale in upsample_scales:
stretch = Stretch2d(scale, 1)
conv = nn.Conv2d(in_channels=1,
out_channels=1,
kernel_size=(1, scale * 2 + 1),
padding=(0, scale),
bias=False)
torch.nn.init.constant_(conv.weight, 1. / (scale * 2 + 1))
conv = nn.Conv2d(
in_channels=1, out_channels=1, kernel_size=(1, scale * 2 + 1), padding=(0, scale), bias=False
)
torch.nn.init.constant_(conv.weight, 1.0 / (scale * 2 + 1))
up_layers.append(stretch)
up_layers.append(conv)
self.upsample_layers = nn.Sequential(*up_layers)
......@@ -197,7 +192,7 @@ class UpsampleNetwork(nn.Module):
specgram = specgram.unsqueeze(1)
upsampling_output = self.upsample_layers(specgram)
upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent]
upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent : -self.indent]
return upsampling_output, resnet_output
......@@ -230,17 +225,19 @@ class WaveRNN(nn.Module):
>>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
"""
def __init__(self,
upsample_scales: List[int],
n_classes: int,
hop_length: int,
n_res_block: int = 10,
n_rnn: int = 512,
n_fc: int = 512,
kernel_size: int = 5,
n_freq: int = 128,
n_hidden: int = 128,
n_output: int = 128) -> None:
def __init__(
self,
upsample_scales: List[int],
n_classes: int,
hop_length: int,
n_res_block: int = 10,
n_rnn: int = 512,
n_fc: int = 512,
kernel_size: int = 5,
n_freq: int = 128,
n_hidden: int = 128,
n_output: int = 128,
) -> None:
super().__init__()
self.kernel_size = kernel_size
......@@ -257,12 +254,7 @@ class WaveRNN(nn.Module):
if total_scale != self.hop_length:
raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
self.upsample = UpsampleNetwork(upsample_scales,
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
self.upsample = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
......@@ -286,8 +278,8 @@ class WaveRNN(nn.Module):
Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
"""
assert waveform.size(1) == 1, 'Require the input channel of waveform is 1'
assert specgram.size(1) == 1, 'Require the input channel of specgram is 1'
assert waveform.size(1) == 1, "Require the input channel of waveform is 1"
assert specgram.size(1) == 1, "Require the input channel of specgram is 1"
# remove channel dimension until the end
waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)
......@@ -302,10 +294,10 @@ class WaveRNN(nn.Module):
aux = aux.transpose(1, 2)
aux_idx = [self.n_aux * i for i in range(5)]
a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
x = self.fc(x)
......@@ -375,7 +367,7 @@ class WaveRNN(nn.Module):
h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
x = torch.zeros((b_size, 1), device=device, dtype=dtype)
aux_split = [aux[:, self.n_aux * i: self.n_aux * (i + 1), :] for i in range(4)]
aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)]
for i in range(seq_len):
......
from ._tts import (
Tacotron2TTSBundle,
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
)
from ._wav2vec2.impl import (
Wav2Vec2Bundle,
Wav2Vec2ASRBundle,
......@@ -25,43 +32,36 @@ from ._wav2vec2.impl import (
HUBERT_ASR_LARGE,
HUBERT_ASR_XLARGE,
)
from ._tts import (
Tacotron2TTSBundle,
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
)
__all__ = [
'Wav2Vec2Bundle',
'Wav2Vec2ASRBundle',
'WAV2VEC2_BASE',
'WAV2VEC2_LARGE',
'WAV2VEC2_LARGE_LV60K',
'WAV2VEC2_ASR_BASE_10M',
'WAV2VEC2_ASR_BASE_100H',
'WAV2VEC2_ASR_BASE_960H',
'WAV2VEC2_ASR_LARGE_10M',
'WAV2VEC2_ASR_LARGE_100H',
'WAV2VEC2_ASR_LARGE_960H',
'WAV2VEC2_ASR_LARGE_LV60K_10M',
'WAV2VEC2_ASR_LARGE_LV60K_100H',
'WAV2VEC2_ASR_LARGE_LV60K_960H',
'WAV2VEC2_XLSR53',
'VOXPOPULI_ASR_BASE_10K_EN',
'VOXPOPULI_ASR_BASE_10K_ES',
'VOXPOPULI_ASR_BASE_10K_DE',
'VOXPOPULI_ASR_BASE_10K_FR',
'VOXPOPULI_ASR_BASE_10K_IT',
'HUBERT_BASE',
'HUBERT_LARGE',
'HUBERT_XLARGE',
'HUBERT_ASR_LARGE',
'HUBERT_ASR_XLARGE',
'Tacotron2TTSBundle',
'TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH',
'TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH',
'TACOTRON2_WAVERNN_CHAR_LJSPEECH',
'TACOTRON2_WAVERNN_PHONE_LJSPEECH',
"Wav2Vec2Bundle",
"Wav2Vec2ASRBundle",
"WAV2VEC2_BASE",
"WAV2VEC2_LARGE",
"WAV2VEC2_LARGE_LV60K",
"WAV2VEC2_ASR_BASE_10M",
"WAV2VEC2_ASR_BASE_100H",
"WAV2VEC2_ASR_BASE_960H",
"WAV2VEC2_ASR_LARGE_10M",
"WAV2VEC2_ASR_LARGE_100H",
"WAV2VEC2_ASR_LARGE_960H",
"WAV2VEC2_ASR_LARGE_LV60K_10M",
"WAV2VEC2_ASR_LARGE_LV60K_100H",
"WAV2VEC2_ASR_LARGE_LV60K_960H",
"WAV2VEC2_XLSR53",
"VOXPOPULI_ASR_BASE_10K_EN",
"VOXPOPULI_ASR_BASE_10K_ES",
"VOXPOPULI_ASR_BASE_10K_DE",
"VOXPOPULI_ASR_BASE_10K_FR",
"VOXPOPULI_ASR_BASE_10K_IT",
"HUBERT_BASE",
"HUBERT_LARGE",
"HUBERT_XLARGE",
"HUBERT_ASR_LARGE",
"HUBERT_ASR_XLARGE",
"Tacotron2TTSBundle",
"TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
"TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
"TACOTRON2_WAVERNN_CHAR_LJSPEECH",
"TACOTRON2_WAVERNN_PHONE_LJSPEECH",
]
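For reference, a hedged sketch of using one of the bundles listed above; weights are downloaded on first use and the dummy waveform is only there to show the expected call shape:

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
    model = bundle.get_model()
    labels = bundle.get_labels()
    waveform = torch.randn(1, int(bundle.sample_rate))  # one second of dummy audio
    emissions, _ = model(waveform)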