Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
@@ -4,17 +4,16 @@ import os
 import tarfile
 import urllib
 import urllib.request
-import zipfile
 import warnings
+import zipfile
 from typing import Any, Iterable, List, Optional

 from torch.utils.model_zoo import tqdm


-def stream_url(url: str,
-               start_byte: Optional[int] = None,
-               block_size: int = 32 * 1024,
-               progress_bar: bool = True) -> Iterable:
+def stream_url(
+    url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
+) -> Iterable:
     """Stream url by chunk

     Args:
@@ -36,11 +35,11 @@ def stream_url(url: str,
         req.headers["Range"] = "bytes={}-".format(start_byte)

     with urllib.request.urlopen(req) as upointer, tqdm(
         unit="B",
         unit_scale=True,
         unit_divisor=1024,
         total=url_size,
         disable=not progress_bar,
     ) as pbar:

         num_bytes = 0
@@ -53,13 +52,15 @@ def stream_url(url: str,
             pbar.update(len(chunk))


-def download_url(url: str,
-                 download_folder: str,
-                 filename: Optional[str] = None,
-                 hash_value: Optional[str] = None,
-                 hash_type: str = "sha256",
-                 progress_bar: bool = True,
-                 resume: bool = False) -> None:
+def download_url(
+    url: str,
+    download_folder: str,
+    filename: Optional[str] = None,
+    hash_value: Optional[str] = None,
+    hash_type: str = "sha256",
+    progress_bar: bool = True,
+    resume: bool = False,
+) -> None:
     """Download file to disk.

     Args:
@@ -84,9 +85,7 @@ def download_url(url: str,
         local_size: Optional[int] = os.path.getsize(filepath)

     elif not resume and os.path.exists(filepath):
-        raise RuntimeError(
-            "{} already exists. Delete the file manually and retry.".format(filepath)
-        )
+        raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
     else:
         mode = "wb"
         local_size = None
@@ -95,11 +94,7 @@ def download_url(url: str,
         with open(filepath, "rb") as file_obj:
             if validate_file(file_obj, hash_value, hash_type):
                 return
-            raise RuntimeError(
-                "The hash of {} does not match. Delete the file manually and retry.".format(
-                    filepath
-                )
-            )
+            raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))

     with open(filepath, mode) as fpointer:
         for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
@@ -107,11 +102,7 @@ def download_url(url: str,
     with open(filepath, "rb") as file_obj:
         if hash_value and not validate_file(file_obj, hash_value, hash_type):
-            raise RuntimeError(
-                "The hash of {} does not match. Delete the file manually and retry.".format(
-                    filepath
-                )
-            )
+            raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))


 def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
...
 import os
 from typing import Tuple

-import torchaudio
 from torch import Tensor
-from torch.utils.data import Dataset
 from torch.hub import download_url_to_file
+from torch.utils.data import Dataset
+import torchaudio
 from torchaudio.datasets.utils import (
     extract_archive,
 )

 URL = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
 _CHECKSUMS = {
-    "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip":
-    "f96258be9fdc2cbff6559541aae7ea4f59df3fcaf5cf963aae5ca647357e359c"
+    "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip": "f96258be9fdc2cbff6559541aae7ea4f59df3fcaf5cf963aae5ca647357e359c"
 }
@@ -41,17 +38,15 @@ class VCTK_092(Dataset):
     """

     def __init__(
         self,
         root: str,
         mic_id: str = "mic2",
         download: bool = False,
         url: str = URL,
         audio_ext=".flac",
     ):
         if mic_id not in ["mic1", "mic2"]:
-            raise RuntimeError(
-                f'`mic_id` has to be either "mic1" or "mic2". Found: {mic_id}'
-            )
+            raise RuntimeError(f'`mic_id` has to be either "mic1" or "mic2". Found: {mic_id}')

         archive = os.path.join(root, "VCTK-Corpus-0.92.zip")
@@ -69,9 +64,7 @@ class VCTK_092(Dataset):
                 extract_archive(archive, self._path)

         if not os.path.isdir(self._path):
-            raise RuntimeError(
-                "Dataset not found. Please use `download=True` to download it."
-            )
+            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

         # Extracting speaker IDs from the folder structure
         self._speaker_ids = sorted(os.listdir(self._txt_dir))
@@ -91,9 +84,7 @@ class VCTK_092(Dataset):
             if speaker_id == "p280" and mic_id == "mic2":
                 continue
             utterance_dir = os.path.join(self._txt_dir, speaker_id)
-            for utterance_file in sorted(
-                f for f in os.listdir(utterance_dir) if f.endswith(".txt")
-            ):
+            for utterance_file in sorted(f for f in os.listdir(utterance_dir) if f.endswith(".txt")):
                 utterance_id = os.path.splitext(utterance_file)[0]
                 audio_path_mic = os.path.join(
                     self._audio_dir,
@@ -112,9 +103,7 @@ class VCTK_092(Dataset):
         return torchaudio.load(file_path)

     def _load_sample(self, speaker_id: str, utterance_id: str, mic_id: str) -> SampleType:
-        transcript_path = os.path.join(
-            self._txt_dir, speaker_id, f"{speaker_id}_{utterance_id}.txt"
-        )
+        transcript_path = os.path.join(self._txt_dir, speaker_id, f"{speaker_id}_{utterance_id}.txt")
         audio_path = os.path.join(
             self._audio_dir,
             speaker_id,
...
@@ -2,11 +2,10 @@ import os
 from pathlib import Path
 from typing import List, Tuple, Union

-import torchaudio
 from torch import Tensor
-from torch.utils.data import Dataset
 from torch.hub import download_url_to_file
+from torch.utils.data import Dataset
+import torchaudio
 from torchaudio.datasets.utils import (
     extract_archive,
 )
@@ -39,7 +38,7 @@ class YESNO(Dataset):
         root: Union[str, Path],
         url: str = _RELEASE_CONFIGS["release1"]["url"],
         folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
-        download: bool = False
+        download: bool = False,
     ) -> None:
         self._parse_filesystem(root, url, folder_in_archive, download)
@@ -58,9 +57,7 @@ class YESNO(Dataset):
                 extract_archive(archive)

         if not os.path.isdir(self._path):
-            raise RuntimeError(
-                "Dataset not found. Please use `download=True` to download it."
-            )
+            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

         self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*.wav"))
...
+from .filtering import (
+    allpass_biquad,
+    band_biquad,
+    bandpass_biquad,
+    bandreject_biquad,
+    bass_biquad,
+    biquad,
+    contrast,
+    dither,
+    dcshift,
+    deemph_biquad,
+    equalizer_biquad,
+    filtfilt,
+    flanger,
+    gain,
+    highpass_biquad,
+    lfilter,
+    lowpass_biquad,
+    overdrive,
+    phaser,
+    riaa_biquad,
+    treble_biquad,
+    vad,
+)
 from .functional import (
     amplitude_to_DB,
     compute_deltas,
@@ -23,75 +47,51 @@ from .functional import (
     pitch_shift,
     rnnt_loss,
 )
-from .filtering import (
-    allpass_biquad,
-    band_biquad,
-    bandpass_biquad,
-    bandreject_biquad,
-    bass_biquad,
-    biquad,
-    contrast,
-    dither,
-    dcshift,
-    deemph_biquad,
-    equalizer_biquad,
-    filtfilt,
-    flanger,
-    gain,
-    highpass_biquad,
-    lfilter,
-    lowpass_biquad,
-    overdrive,
-    phaser,
-    riaa_biquad,
-    treble_biquad,
-    vad,
-)

 __all__ = [
-    'amplitude_to_DB',
-    'compute_deltas',
-    'compute_kaldi_pitch',
-    'create_dct',
-    'melscale_fbanks',
-    'linear_fbanks',
-    'DB_to_amplitude',
-    'detect_pitch_frequency',
-    'griffinlim',
-    'mask_along_axis',
-    'mask_along_axis_iid',
-    'mu_law_encoding',
-    'mu_law_decoding',
-    'phase_vocoder',
-    'sliding_window_cmn',
-    'spectrogram',
-    'inverse_spectrogram',
-    'spectral_centroid',
-    'allpass_biquad',
-    'band_biquad',
-    'bandpass_biquad',
-    'bandreject_biquad',
-    'bass_biquad',
-    'biquad',
-    'contrast',
-    'dither',
-    'dcshift',
-    'deemph_biquad',
-    'equalizer_biquad',
-    'filtfilt',
-    'flanger',
-    'gain',
-    'highpass_biquad',
-    'lfilter',
-    'lowpass_biquad',
-    'overdrive',
-    'phaser',
-    'riaa_biquad',
-    'treble_biquad',
-    'vad',
-    'apply_codec',
-    'resample',
-    'edit_distance',
-    'pitch_shift',
-    'rnnt_loss',
+    "amplitude_to_DB",
+    "compute_deltas",
+    "compute_kaldi_pitch",
+    "create_dct",
+    "melscale_fbanks",
+    "linear_fbanks",
+    "DB_to_amplitude",
+    "detect_pitch_frequency",
+    "griffinlim",
+    "mask_along_axis",
+    "mask_along_axis_iid",
+    "mu_law_encoding",
+    "mu_law_decoding",
+    "phase_vocoder",
+    "sliding_window_cmn",
+    "spectrogram",
+    "inverse_spectrogram",
+    "spectral_centroid",
+    "allpass_biquad",
+    "band_biquad",
+    "bandpass_biquad",
+    "bandreject_biquad",
+    "bass_biquad",
+    "biquad",
+    "contrast",
+    "dither",
+    "dcshift",
+    "deemph_biquad",
+    "equalizer_biquad",
+    "filtfilt",
+    "flanger",
+    "gain",
+    "highpass_biquad",
+    "lfilter",
+    "lowpass_biquad",
+    "overdrive",
+    "phaser",
+    "riaa_biquad",
+    "treble_biquad",
+    "vad",
+    "apply_codec",
+    "resample",
+    "edit_distance",
+    "pitch_shift",
+    "rnnt_loss",
 ]
@@ -45,7 +45,7 @@ def _generate_wave_table(
         d = (torch.sin(point.to(torch.float64) / table_size * 2 * math.pi) + 1) / 2
     elif wave_type == "TRIANGLE":
         d = point.to(torch.float64) * 2 / table_size
-        value = torch.div(4 * point, table_size, rounding_mode='floor')
+        value = torch.div(4 * point, table_size, rounding_mode="floor")
         d[value == 0] = d[value == 0] + 0.5
         d[value == 1] = 1.5 - d[value == 1]
         d[value == 2] = 1.5 - d[value == 2]
@@ -64,9 +64,7 @@ def _generate_wave_table(
     return d


-def allpass_biquad(
-    waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707
-) -> Tensor:
+def allpass_biquad(waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707) -> Tensor:
     r"""Design two-pole all-pass filter. Similar to SoX implementation.

     Args:
@@ -191,9 +189,7 @@ def bandpass_biquad(
     return biquad(waveform, b0, b1, b2, a0, a1, a2)


-def bandreject_biquad(
-    waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707
-) -> Tensor:
+def bandreject_biquad(waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707) -> Tensor:
     r"""Design two-pole band-reject filter. Similar to SoX implementation.

     Args:
@@ -273,9 +269,7 @@ def bass_biquad(
     return biquad(waveform, b0 / a0, b1 / a0, b2 / a0, a0 / a0, a1 / a0, a2 / a0)


-def biquad(
-    waveform: Tensor, b0: float, b1: float, b2: float, a0: float, a1: float, a2: float
-) -> Tensor:
+def biquad(waveform: Tensor, b0: float, b1: float, b2: float, a0: float, a1: float, a2: float) -> Tensor:
     r"""Perform a biquad filter of input tensor. Initial conditions set to 0.
     https://en.wikipedia.org/wiki/Digital_biquad_filter

@@ -339,9 +333,7 @@ def contrast(waveform: Tensor, enhancement_amount: float = 75.0) -> Tensor:
     return output_waveform


-def dcshift(
-    waveform: Tensor, shift: float, limiter_gain: Optional[float] = None
-) -> Tensor:
+def dcshift(waveform: Tensor, shift: float, limiter_gain: Optional[float] = None) -> Tensor:
     r"""Apply a DC shift to the audio. Similar to SoX implementation.
     This can be useful to remove a DC offset
     (caused perhaps by a hardware problem in the recording chain) from the audio
@@ -367,25 +359,13 @@ def dcshift(
     if limiter_gain is not None and shift > 0:
         mask = waveform > limiter_threshold
-        temp = (
-            (waveform[mask] - limiter_threshold)
-            * limiter_gain
-            / (1 - limiter_threshold)
-        )
-        output_waveform[mask] = (temp + limiter_threshold + shift).clamp(
-            max=limiter_threshold
-        )
+        temp = (waveform[mask] - limiter_threshold) * limiter_gain / (1 - limiter_threshold)
+        output_waveform[mask] = (temp + limiter_threshold + shift).clamp(max=limiter_threshold)
         output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1)
     elif limiter_gain is not None and shift < 0:
         mask = waveform < -limiter_threshold
-        temp = (
-            (waveform[mask] + limiter_threshold)
-            * limiter_gain
-            / (1 - limiter_threshold)
-        )
-        output_waveform[mask] = (temp - limiter_threshold + shift).clamp(
-            min=-limiter_threshold
-        )
+        temp = (waveform[mask] + limiter_threshold) * limiter_gain / (1 - limiter_threshold)
+        output_waveform[mask] = (temp - limiter_threshold + shift).clamp(min=-limiter_threshold)
         output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1)
     else:
         output_waveform = (waveform + shift).clamp(min=-1, max=1)
@@ -461,9 +441,7 @@ def _add_noise_shaping(dithered_waveform: Tensor, waveform: Tensor) -> Tensor:
     return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])


-def _apply_probability_distribution(
-    waveform: Tensor, density_function: str = "TPDF"
-) -> Tensor:
+def _apply_probability_distribution(waveform: Tensor, density_function: str = "TPDF") -> Tensor:
     r"""Apply a probability distribution function on a waveform.

     Triangular probability density function (TPDF) dither noise has a
@@ -561,9 +539,7 @@ def _apply_probability_distribution(
         signal_scaled_dis = signal_scaled + gaussian
     else:
         # dtype needed for https://github.com/pytorch/pytorch/issues/32358
-        TPDF = torch.bartlett_window(
-            time_size + 1, dtype=signal_scaled.dtype, device=signal_scaled.device
-        )
+        TPDF = torch.bartlett_window(time_size + 1, dtype=signal_scaled.dtype, device=signal_scaled.device)
         TPDF = TPDF.repeat((channel_size + 1), 1)
         signal_scaled_dis = signal_scaled + TPDF
@@ -574,9 +550,7 @@ def _apply_probability_distribution(
     return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])


-def dither(
-    waveform: Tensor, density_function: str = "TPDF", noise_shaping: bool = False
-) -> Tensor:
+def dither(waveform: Tensor, density_function: str = "TPDF", noise_shaping: bool = False) -> Tensor:
     r"""Dither increases the perceived dynamic range of audio stored at a
     particular bit-depth by eliminating nonlinear truncation distortion
     (i.e. adding minimally perceived noise to mask distortion caused by quantization).
@@ -594,9 +568,7 @@ def dither(
     Returns:
         Tensor: waveform dithered
     """
-    dithered = _apply_probability_distribution(
-        waveform, density_function=density_function
-    )
+    dithered = _apply_probability_distribution(waveform, density_function=density_function)

     if noise_shaping:
         return _add_noise_shaping(dithered, waveform)
@@ -643,7 +615,10 @@ def equalizer_biquad(
 def filtfilt(
-    waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True,
+    waveform: Tensor,
+    a_coeffs: Tensor,
+    b_coeffs: Tensor,
+    clamp: bool = True,
 ) -> Tensor:
     r"""Apply an IIR filter forward and backward to a waveform.
@@ -667,7 +642,11 @@ def filtfilt(
     """
     forward_filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False, batching=True)
     backward_filtered = lfilter(
-        forward_filtered.flip(-1), a_coeffs, b_coeffs, clamp=clamp, batching=True,
+        forward_filtered.flip(-1),
+        a_coeffs,
+        b_coeffs,
+        clamp=clamp,
+        batching=True,
     ).flip(-1)

     return backward_filtered
@@ -757,9 +736,7 @@ def flanger(
     delay_buf_length = int((delay_min + delay_depth) * sample_rate + 0.5)
     delay_buf_length = delay_buf_length + 2

-    delay_bufs = torch.zeros(
-        waveform.shape[0], n_channels, delay_buf_length, dtype=dtype, device=device
-    )
+    delay_bufs = torch.zeros(waveform.shape[0], n_channels, delay_buf_length, dtype=dtype, device=device)
     delay_last = torch.zeros(waveform.shape[0], n_channels, dtype=dtype, device=device)

     lfo_length = int(sample_rate / speed)
@@ -787,9 +764,7 @@ def flanger(
         delay_buf_pos = (delay_buf_pos + delay_buf_length - 1) % delay_buf_length

-        cur_channel_phase = (channel_idxs * lfo_length * channel_phase + 0.5).to(
-            torch.int64
-        )
+        cur_channel_phase = (channel_idxs * lfo_length * channel_phase + 0.5).to(torch.int64)
         delay_tensor = lfo[(lfo_pos + cur_channel_phase) % lfo_length]
         frac_delay = torch.frac(delay_tensor)
         delay_tensor = torch.floor(delay_tensor)
@@ -800,24 +775,18 @@ def flanger(
         delay_bufs[:, :, delay_buf_pos] = temp + delay_last * feedback_gain

-        delayed_0 = delay_bufs[
-            :, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length
-        ]
+        delayed_0 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]
         int_delay = int_delay + 1
-        delayed_1 = delay_bufs[
-            :, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length
-        ]
+        delayed_1 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]
         int_delay = int_delay + 1

         if interpolation == "linear":
             delayed = delayed_0 + (delayed_1 - delayed_0) * frac_delay
         else:
-            delayed_2 = delay_bufs[
-                :, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length
-            ]
+            delayed_2 = delay_bufs[:, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length]

             int_delay = int_delay + 1
@@ -854,9 +823,7 @@ def gain(waveform: Tensor, gain_db: float = 1.0) -> Tensor:
     return waveform * ratio


-def highpass_biquad(
-    waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707
-) -> Tensor:
+def highpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
     r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.

     Args:
@@ -889,9 +856,7 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
     n_order = a_coeffs_flipped.size(1)
     a_coeffs_flipped = a_coeffs_flipped.unsqueeze(2)
     for i_sample, o0 in enumerate(input_signal_windows.permute(2, 0, 1)):
-        windowed_output_signal = padded_output_waveform[
-            :, :, i_sample:i_sample + n_order
-        ]
+        windowed_output_signal = padded_output_waveform[:, :, i_sample : i_sample + n_order]
         o0 -= (windowed_output_signal.transpose(0, 1) @ a_coeffs_flipped)[..., 0].t()
         padded_output_waveform[:, :, i_sample + n_order - 1] = o0
@@ -899,7 +864,7 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
 try:
     _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
 except RuntimeError as err:
-    assert str(err) == 'No such operator torchaudio::_lfilter_core_loop'
+    assert str(err) == "No such operator torchaudio::_lfilter_core_loop"
     _lfilter_core_cpu_loop = _lfilter_core_generic_loop
@@ -929,40 +894,32 @@ def _lfilter_core(
     b_coeffs_flipped = b_coeffs.flip(1)

     # calculate windowed_input_signal in parallel using convolution
-    input_signal_windows = torch.nn.functional.conv1d(
-        padded_waveform,
-        b_coeffs_flipped.unsqueeze(1),
-        groups=n_channel
-    )
+    input_signal_windows = torch.nn.functional.conv1d(padded_waveform, b_coeffs_flipped.unsqueeze(1), groups=n_channel)

     input_signal_windows.div_(a_coeffs[:, :1])
     a_coeffs_flipped.div_(a_coeffs[:, :1])

-    if input_signal_windows.device == torch.device('cpu') and\
-            a_coeffs_flipped.device == torch.device('cpu') and\
-            padded_output_waveform.device == torch.device('cpu'):
+    if (
+        input_signal_windows.device == torch.device("cpu")
+        and a_coeffs_flipped.device == torch.device("cpu")
+        and padded_output_waveform.device == torch.device("cpu")
+    ):
         _lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
     else:
         _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)

-    output = padded_output_waveform[:, :, n_order - 1:]
+    output = padded_output_waveform[:, :, n_order - 1 :]
     return output


 try:
     _lfilter = torch.ops.torchaudio._lfilter
 except RuntimeError as err:
-    assert str(err) == 'No such operator torchaudio::_lfilter'
+    assert str(err) == "No such operator torchaudio::_lfilter"
     _lfilter = _lfilter_core


-def lfilter(
-    waveform: Tensor,
-    a_coeffs: Tensor,
-    b_coeffs: Tensor,
-    clamp: bool = True,
-    batching: bool = True
-) -> Tensor:
+def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
     r"""Perform an IIR filter by evaluating difference equation.

     Note:
@@ -1016,9 +973,7 @@ def lfilter(
     return output


-def lowpass_biquad(
-    waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707
-) -> Tensor:
+def lowpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
     r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.

     Args:
@@ -1048,11 +1003,7 @@ def lowpass_biquad(
 def _overdrive_core_loop_generic(
-    waveform: Tensor,
-    temp: Tensor,
-    last_in: Tensor,
-    last_out: Tensor,
-    output_waveform: Tensor
+    waveform: Tensor, temp: Tensor, last_in: Tensor, last_out: Tensor, output_waveform: Tensor
 ):
     for i in range(waveform.shape[-1]):
         last_out = temp[:, i] - last_in + 0.995 * last_out
@@ -1063,7 +1014,7 @@ def _overdrive_core_loop_generic(
 try:
     _overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop
 except RuntimeError as err:
-    assert str(err) == 'No such operator torchaudio::_overdrive_core_loop'
+    assert str(err) == "No such operator torchaudio::_overdrive_core_loop"
     _overdrive_core_loop_cpu = _overdrive_core_loop_generic
@@ -1110,7 +1061,7 @@ def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor:
     output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device)

     # Uses CPU optimized loop function if available for CPU device
-    if device == torch.device('cpu'):
+    if device == torch.device("cpu"):
         _overdrive_core_loop_cpu(waveform, temp, last_in, last_out, output_waveform)
     else:
         _overdrive_core_loop_generic(waveform, temp, last_in, last_out, output_waveform)
@@ -1164,9 +1115,7 @@ def phaser(
     waveform = waveform.view(-1, actual_shape[-1])

     delay_buf_len = int((delay_ms * 0.001 * sample_rate) + 0.5)
-    delay_buf = torch.zeros(
-        waveform.shape[0], delay_buf_len, dtype=dtype, device=device
-    )
+    delay_buf = torch.zeros(waveform.shape[0], delay_buf_len, dtype=dtype, device=device)

     mod_buf_len = int(sample_rate / mod_speed + 0.5)
@@ -1203,9 +1152,7 @@ def phaser(
         delay_buf_list[delay_pos] = temp * decay
         output_waveform_pre_gain_list.append(temp)

-    output_waveform = torch.stack(output_waveform_pre_gain_list, dim=1).to(
-        dtype=dtype, device=device
-    )
+    output_waveform = torch.stack(output_waveform_pre_gain_list, dim=1).to(dtype=dtype, device=device)
     output_waveform.mul_(gain_out)

     return output_waveform.clamp(min=-1, max=1).view(actual_shape)
@@ -1344,9 +1291,7 @@ def _measure(
     dftBuf = torch.zeros(dft_len_ws)

-    _index_ns = torch.tensor(
-        [index_ns] + [(index_ns + i) % samplesLen_ns for i in range(1, measure_len_ws)]
-    )
+    _index_ns = torch.tensor([index_ns] + [(index_ns + i) % samplesLen_ns for i in range(1, measure_len_ws)])
     dftBuf[:measure_len_ws] = samples[_index_ns] * spectrum_window[:measure_len_ws]

     # memset(c->dftBuf + i, 0, (p->dft_len_ws - i) * sizeof(*c->dftBuf));
@@ -1358,9 +1303,7 @@ def _measure(
     # memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
     _dftBuf[:spectrum_start].zero_()

-    mult: float = (
-        boot_count / (1.0 + boot_count) if boot_count >= 0 else measure_smooth_time_mult
-    )
+    mult: float = boot_count / (1.0 + boot_count) if boot_count >= 0 else measure_smooth_time_mult

     _d = _dftBuf[spectrum_start:spectrum_end].abs()
     spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
@@ -1387,17 +1330,13 @@ def _measure(
     _cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1)
     _cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window
-    _cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_()
+    _cepstrum_Buf[spectrum_end : dft_len_ws >> 1].zero_()

     # lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
     _cepstrum_Buf = torch.fft.rfft(_cepstrum_Buf)

-    result: float = float(
-        torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2))
-    )
-    result = (
-        math.log(result / (cepstrum_end - cepstrum_start)) if result > 0 else -math.inf
-    )
+    result: float = float(torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2)))
+    result = math.log(result / (cepstrum_end - cepstrum_start)) if result > 0 else -math.inf

     return max(0, 21 + result)
@@ -1489,9 +1428,7 @@ def vad(
             " and https://github.com/pytorch/audio/issues/1468."
         )

-    measure_duration: float = (
-        2.0 / measure_freq if measure_duration is None else measure_duration
-    )
+    measure_duration: float = 2.0 / measure_freq if measure_duration is None else measure_duration

     measure_len_ws = int(sample_rate * measure_duration + 0.5)
     measure_len_ns = measure_len_ws
@@ -1506,9 +1443,7 @@ def vad(
     gap_len = int(allowed_gap * measure_freq + 0.5)

     fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + 0.5)
-    samplesLen_ns = (
-        fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns
-    )
+    samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns

     spectrum_window = torch.zeros(measure_len_ws)
     for i in range(measure_len_ws):
@@ -1526,9 +1461,7 @@ def vad(
     for i in range(spectrum_end - spectrum_start):
         cepstrum_window[i] = 2.0 / math.sqrt(float(spectrum_end) - spectrum_start)
     # lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start));
-    cepstrum_window *= torch.hann_window(
-        spectrum_end - spectrum_start, dtype=torch.float
-    )
+    cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, dtype=torch.float)

     cepstrum_start = math.ceil(sample_rate * 0.5 / lp_lifter_freq)
     cepstrum_end = math.floor(sample_rate * 0.5 / hp_lifter_freq)
@@ -1567,9 +1500,7 @@ def vad(
             samples[i, samplesIndex_ns] = waveform[i, pos]
             # if (!p->measure_timer_ns) {
             if measure_timer_ns == 0:
-                index_ns: int = (
-                    samplesIndex_ns + samplesLen_ns - measure_len_ns
-                ) % samplesLen_ns
+                index_ns: int = (samplesIndex_ns + samplesLen_ns - measure_len_ns) % samplesLen_ns
                 meas: float = _measure(
                     measure_len_ws=measure_len_ws,
                     samples=samples[i],
@@ -1589,9 +1520,7 @@ def vad(
                     boot_count=boot_count,
                 )
                 measures[i, measures_index] = meas
-                mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (
-                    1.0 - trigger_meas_time_mult
-                )
+                mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1.0 - trigger_meas_time_mult)

                 has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
                 if has_triggered:
@@ -1602,9 +1531,7 @@ def vad(
                     j: int = 0
                     for j in range(n):
-                        if (measures[i, k] >= trigger_level) and (
-                            j <= jTrigger + gap_len
-                        ):
+                        if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
                             jZero = jTrigger = j
                         elif (measures[i, k] == 0) and (jTrigger >= jZero):
                             jZero = j
@@ -1631,6 +1558,6 @@ def vad(
             flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
             samplesIndex_ns = (samplesIndex_ns + flushedLen_ns) % samplesLen_ns

-    res = waveform[:, pos - samplesLen_ns + flushedLen_ns:]
+    res = waveform[:, pos - samplesLen_ns + flushedLen_ns :]
     # unpack batch
     return res.view(shape[:-1] + res.shape[-1:])
 # -*- coding: utf-8 -*-
-from collections.abc import Sequence
 import io
 import math
 import warnings
+from collections.abc import Sequence
 from typing import Optional, Tuple

 import torch
-import torchaudio
 from torch import Tensor
 from torchaudio._internal import module_utils as _mod_utils
+import torchaudio

 __all__ = [
     "spectrogram",
@@ -28,9 +28,9 @@ __all__ = [
     "mu_law_encoding",
     "mu_law_decoding",
     "phase_vocoder",
-    'mask_along_axis',
-    'mask_along_axis_iid',
-    'sliding_window_cmn',
+    "mask_along_axis",
+    "mask_along_axis_iid",
+    "sliding_window_cmn",
     "spectral_centroid",
     "apply_codec",
     "resample",
@@ -41,18 +41,18 @@ __all__ = [
 def spectrogram(
     waveform: Tensor,
     pad: int,
     window: Tensor,
     n_fft: int,
     hop_length: int,
     win_length: int,
     power: Optional[float],
     normalized: bool,
     center: bool = True,
     pad_mode: str = "reflect",
     onesided: bool = True,
     return_complex: Optional[bool] = None,
 ) -> Tensor:
     r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
     The spectrogram can be either magnitude-only or complex.
@@ -116,7 +116,7 @@ def spectrogram(
     spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])

     if normalized:
-        spec_f /= window.pow(2.).sum().sqrt()
+        spec_f /= window.pow(2.0).sum().sqrt()
     if power is not None:
         if power == 1.0:
             return spec_f.abs()
@@ -125,17 +125,17 @@ def spectrogram(
 def inverse_spectrogram(
     spectrogram: Tensor,
     length: Optional[int],
     pad: int,
     window: Tensor,
     n_fft: int,
     hop_length: int,
     win_length: int,
     normalized: bool,
     center: bool = True,
     pad_mode: str = "reflect",
     onesided: bool = True,
 ) -> Tensor:
     r"""Create an inverse spectrogram or a batch of inverse spectrograms from the provided
     complex-valued spectrogram.
@@ -166,7 +166,7 @@ def inverse_spectrogram(
         raise ValueError("Expected `spectrogram` to be complex dtype.")

     if normalized:
-        spectrogram = spectrogram * window.pow(2.).sum().sqrt()
+        spectrogram = spectrogram * window.pow(2.0).sum().sqrt()

     # pack batch
     shape = spectrogram.size()
@@ -203,20 +203,20 @@ def _get_complex_dtype(real_dtype: torch.dtype):
         return torch.cfloat
     if real_dtype == torch.half:
         return torch.complex32
-    raise ValueError(f'Unexpected dtype {real_dtype}')
+    raise ValueError(f"Unexpected dtype {real_dtype}")


 def griffinlim(
     specgram: Tensor,
     window: Tensor,
     n_fft: int,
     hop_length: int,
     win_length: int,
     power: float,
     n_iter: int,
     momentum: float,
     length: Optional[int],
-    rand_init: bool
+    rand_init: bool,
 ) -> Tensor:
     r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
@@ -244,8 +244,8 @@ def griffinlim(
     Returns:
         Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
     """
-    assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
-    assert momentum >= 0, 'momentum={} < 0'.format(momentum)
+    assert momentum < 1, "momentum={} > 1 can be unstable".format(momentum)
+    assert momentum >= 0, "momentum={} < 0".format(momentum)

     # pack batch
     shape = specgram.size()
@@ -255,24 +255,17 @@ def griffinlim(
     # initialize the phase
     if rand_init:
-        angles = torch.rand(
-            specgram.size(),
-            dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
+        angles = torch.rand(specgram.size(), dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
     else:
-        angles = torch.full(
-            specgram.size(), 1,
-            dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
+        angles = torch.full(specgram.size(), 1, dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)

     # And initialize the previous iterate to 0
-    tprev = torch.tensor(0., dtype=specgram.dtype, device=specgram.device)
+    tprev = torch.tensor(0.0, dtype=specgram.dtype, device=specgram.device)
     for _ in range(n_iter):
         # Invert with our current estimate of the phases
-        inverse = torch.istft(specgram * angles,
-                              n_fft=n_fft,
-                              hop_length=hop_length,
-                              win_length=win_length,
-                              window=window,
-                              length=length)
+        inverse = torch.istft(
+            specgram * angles, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=length
+        )

         # Rebuild the spectrogram
         rebuilt = torch.stft(
@@ -282,7 +275,7 @@ def griffinlim(
             win_length=win_length,
             window=window,
             center=True,
-            pad_mode='reflect',
+            pad_mode="reflect",
             normalized=False,
             onesided=True,
             return_complex=True,
@@ -298,12 +291,9 @@ def griffinlim(
         tprev = rebuilt

     # Return the final phase estimates
-    waveform = torch.istft(specgram * angles,
-                           n_fft=n_fft,
-                           hop_length=hop_length,
-                           win_length=win_length,
-                           window=window,
-                           length=length)
+    waveform = torch.istft(
+        specgram * angles, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=length
+    )

     # unpack batch
     waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
@@ -312,11 +302,7 @@ def griffinlim(
 def amplitude_to_DB(
-    x: Tensor,
-    multiplier: float,
-    amin: float,
-    db_multiplier: float,
-    top_db: Optional[float] = None
+    x: Tensor, multiplier: float, amin: float, db_multiplier: float, top_db: Optional[float] = None
 ) -> Tensor:
     r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.
@@ -354,11 +340,7 @@ def amplitude_to_DB(
     return x_db


-def DB_to_amplitude(
-    x: Tensor,
-    ref: float,
-    power: float
-) -> Tensor:
+def DB_to_amplitude(x: Tensor, ref: float, power: float) -> Tensor:
     r"""Turn a tensor from the decibel scale to the power/amplitude scale.

     Args:
@@ -383,7 +365,7 @@ def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
         mels (float): Frequency in Mels
     """

-    if mel_scale not in ['slaney', 'htk']:
+    if mel_scale not in ["slaney", "htk"]:
         raise ValueError('mel_scale should be one of "htk" or "slaney".')

     if mel_scale == "htk":
@@ -417,11 +399,11 @@ def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
         freqs (Tensor): Mels converted in Hz
     """

-    if mel_scale not in ['slaney', 'htk']:
+    if mel_scale not in ["slaney", "htk"]:
         raise ValueError('mel_scale should be one of "htk" or "slaney".')

     if mel_scale == "htk":
-        return 700.0 * (10.0**(mels / 2595.0) - 1.0)
+        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)

     # Fill in the linear scale
     f_min = 0.0
@@ -433,15 +415,15 @@ def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
     min_log_mel = (min_log_hz - f_min) / f_sp
     logstep = math.log(6.4) / 27.0

-    log_t = (mels >= min_log_mel)
+    log_t = mels >= min_log_mel
     freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

     return freqs


 def _create_triangular_filterbank(
     all_freqs: Tensor,
     f_pts: Tensor,
 ) -> Tensor:
     """Create a triangular filter bank.
...@@ -466,13 +448,13 @@ def _create_triangular_filterbank( ...@@ -466,13 +448,13 @@ def _create_triangular_filterbank(
def melscale_fbanks( def melscale_fbanks(
n_freqs: int, n_freqs: int,
f_min: float, f_min: float,
f_max: float, f_max: float,
n_mels: int, n_mels: int,
sample_rate: int, sample_rate: int,
norm: Optional[str] = None, norm: Optional[str] = None,
mel_scale: str = "htk", mel_scale: str = "htk",
) -> Tensor: ) -> Tensor:
r"""Create a frequency bin conversion matrix. r"""Create a frequency bin conversion matrix.
...@@ -520,10 +502,10 @@ def melscale_fbanks( ...@@ -520,10 +502,10 @@ def melscale_fbanks(
if norm is not None and norm == "slaney": if norm is not None and norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel # Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels]) enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
fb *= enorm.unsqueeze(0) fb *= enorm.unsqueeze(0)
if (fb.max(dim=0).values == 0.).any(): if (fb.max(dim=0).values == 0.0).any():
warnings.warn( warnings.warn(
"At least one mel filterbank has all zero values. " "At least one mel filterbank has all zero values. "
f"The value for `n_mels` ({n_mels}) may be set too high. " f"The value for `n_mels` ({n_mels}) may be set too high. "
...@@ -534,11 +516,11 @@ def melscale_fbanks( ...@@ -534,11 +516,11 @@ def melscale_fbanks(
def linear_fbanks( def linear_fbanks(
n_freqs: int, n_freqs: int,
f_min: float, f_min: float,
f_max: float, f_max: float,
n_filter: int, n_filter: int,
sample_rate: int, sample_rate: int,
) -> Tensor: ) -> Tensor:
r"""Creates a linear triangular filterbank. r"""Creates a linear triangular filterbank.
...@@ -575,11 +557,7 @@ def linear_fbanks( ...@@ -575,11 +557,7 @@ def linear_fbanks(
return fb return fb
def create_dct( def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> Tensor:
n_mfcc: int,
n_mels: int,
norm: Optional[str]
) -> Tensor:
r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``), r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
normalized depending on norm. normalized depending on norm.
...@@ -605,10 +583,7 @@ def create_dct( ...@@ -605,10 +583,7 @@ def create_dct(
return dct.t() return dct.t()
def mu_law_encoding( def mu_law_encoding(x: Tensor, quantization_channels: int) -> Tensor:
x: Tensor,
quantization_channels: int
) -> Tensor:
r"""Encode signal based on mu-law companding. For more info see the r"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...@@ -624,8 +599,10 @@ def mu_law_encoding( ...@@ -624,8 +599,10 @@ def mu_law_encoding(
""" """
mu = quantization_channels - 1.0 mu = quantization_channels - 1.0
if not x.is_floating_point(): if not x.is_floating_point():
warnings.warn("The input Tensor must be of floating type. \ warnings.warn(
This will be an error in the v0.12 release.") "The input Tensor must be of floating type. \
This will be an error in the v0.12 release."
)
x = x.to(torch.float) x = x.to(torch.float)
mu = torch.tensor(mu, dtype=x.dtype) mu = torch.tensor(mu, dtype=x.dtype)
x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
...@@ -633,10 +610,7 @@ def mu_law_encoding( ...@@ -633,10 +610,7 @@ def mu_law_encoding(
return x_mu return x_mu
def mu_law_decoding( def mu_law_decoding(x_mu: Tensor, quantization_channels: int) -> Tensor:
x_mu: Tensor,
quantization_channels: int
) -> Tensor:
r"""Decode mu-law encoded signal. For more info see the r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...@@ -659,11 +633,7 @@ def mu_law_decoding( ...@@ -659,11 +633,7 @@ def mu_law_decoding(
return x return x
def phase_vocoder( def phase_vocoder(complex_specgrams: Tensor, rate: float, phase_advance: Tensor) -> Tensor:
complex_specgrams: Tensor,
rate: float,
phase_advance: Tensor
) -> Tensor:
r"""Given a STFT tensor, speed up in time without modifying pitch by a r"""Given a STFT tensor, speed up in time without modifying pitch by a
factor of ``rate``. factor of ``rate``.
...@@ -699,12 +669,7 @@ def phase_vocoder( ...@@ -699,12 +669,7 @@ def phase_vocoder(
# Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32 # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
# Note torch.real is a view so it does not incur any memory copy. # Note torch.real is a view so it does not incur any memory copy.
real_dtype = torch.real(complex_specgrams).dtype real_dtype = torch.real(complex_specgrams).dtype
time_steps = torch.arange( time_steps = torch.arange(0, complex_specgrams.size(-1), rate, device=complex_specgrams.device, dtype=real_dtype)
0,
complex_specgrams.size(-1),
rate,
device=complex_specgrams.device,
dtype=real_dtype)
alphas = time_steps % 1.0 alphas = time_steps % 1.0
phase_0 = complex_specgrams[..., :1].angle() phase_0 = complex_specgrams[..., :1].angle()
...@@ -739,12 +704,7 @@ def phase_vocoder( ...@@ -739,12 +704,7 @@ def phase_vocoder(
return complex_specgrams_stretch return complex_specgrams_stretch
def mask_along_axis_iid( def mask_along_axis_iid(specgrams: Tensor, mask_param: int, mask_value: float, axis: int) -> Tensor:
specgrams: Tensor,
mask_param: int,
mask_value: float,
axis: int
) -> Tensor:
r""" r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
...@@ -760,7 +720,7 @@ def mask_along_axis_iid( ...@@ -760,7 +720,7 @@ def mask_along_axis_iid(
""" """
if axis not in [2, 3]: if axis not in [2, 3]:
raise ValueError('Only Frequency and Time masking are supported') raise ValueError("Only Frequency and Time masking are supported")
device = specgrams.device device = specgrams.device
dtype = specgrams.dtype dtype = specgrams.dtype
...@@ -781,12 +741,7 @@ def mask_along_axis_iid( ...@@ -781,12 +741,7 @@ def mask_along_axis_iid(
return specgrams return specgrams
def mask_along_axis( def mask_along_axis(specgram: Tensor, mask_param: int, mask_value: float, axis: int) -> Tensor:
specgram: Tensor,
mask_param: int,
mask_value: float,
axis: int
) -> Tensor:
r""" r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
...@@ -802,7 +757,7 @@ def mask_along_axis( ...@@ -802,7 +757,7 @@ def mask_along_axis(
Tensor: Masked spectrogram of dimensions `(channel, freq, time)` Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
""" """
if axis not in [1, 2]: if axis not in [1, 2]:
raise ValueError('Only Frequency and Time masking are supported') raise ValueError("Only Frequency and Time masking are supported")
# pack batch # pack batch
shape = specgram.size() shape = specgram.size()
...@@ -827,11 +782,7 @@ def mask_along_axis( ...@@ -827,11 +782,7 @@ def mask_along_axis(
return specgram return specgram
def compute_deltas( def compute_deltas(specgram: Tensor, win_length: int = 5, mode: str = "replicate") -> Tensor:
specgram: Tensor,
win_length: int = 5,
mode: str = "replicate"
) -> Tensor:
r"""Compute delta coefficients of a tensor, usually a spectrogram: r"""Compute delta coefficients of a tensor, usually a spectrogram:
.. math:: .. math::
...@@ -880,12 +831,7 @@ def compute_deltas( ...@@ -880,12 +831,7 @@ def compute_deltas(
return output return output
def _compute_nccf( def _compute_nccf(waveform: Tensor, sample_rate: int, frame_time: float, freq_low: int) -> Tensor:
waveform: Tensor,
sample_rate: int,
frame_time: float,
freq_low: int
) -> Tensor:
r""" r"""
Compute Normalized Cross-Correlation Function (NCCF). Compute Normalized Cross-Correlation Function (NCCF).
...@@ -932,25 +878,17 @@ def _compute_nccf( ...@@ -932,25 +878,17 @@ def _compute_nccf(
return nccf return nccf
def _combine_max( def _combine_max(a: Tuple[Tensor, Tensor], b: Tuple[Tensor, Tensor], thresh: float = 0.99) -> Tuple[Tensor, Tensor]:
a: Tuple[Tensor, Tensor],
b: Tuple[Tensor, Tensor],
thresh: float = 0.99
) -> Tuple[Tensor, Tensor]:
""" """
Take value from first if bigger than a multiplicative factor of the second, elementwise. Take value from first if bigger than a multiplicative factor of the second, elementwise.
""" """
mask = (a[0] > thresh * b[0]) mask = a[0] > thresh * b[0]
values = mask * a[0] + ~mask * b[0] values = mask * a[0] + ~mask * b[0]
indices = mask * a[1] + ~mask * b[1] indices = mask * a[1] + ~mask * b[1]
return values, indices return values, indices
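Concretely, ``_combine_max`` keeps the value and index from ``a`` wherever ``a[0] > thresh * b[0]`` holds and falls back to ``b`` elsewhere; a tiny illustration with made-up numbers:

import torch

a = (torch.tensor([1.0, 0.2]), torch.tensor([3, 4]))
b = (torch.tensor([0.9, 0.9]), torch.tensor([7, 8]))
values, indices = _combine_max(a, b, thresh=0.99)
# The mask is [True, False], so values == tensor([1.0, 0.9]) and indices == tensor([3, 8]).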
def _find_max_per_frame( def _find_max_per_frame(nccf: Tensor, sample_rate: int, freq_high: int) -> Tensor:
nccf: Tensor,
sample_rate: int,
freq_high: int
) -> Tensor:
r""" r"""
For each frame, take the highest value of NCCF, For each frame, take the highest value of NCCF,
apply centered median smoothing, and convert to frequency. apply centered median smoothing, and convert to frequency.
...@@ -979,10 +917,7 @@ def _find_max_per_frame( ...@@ -979,10 +917,7 @@ def _find_max_per_frame(
return indices return indices
def _median_smoothing( def _median_smoothing(indices: Tensor, win_length: int) -> Tensor:
indices: Tensor,
win_length: int
) -> Tensor:
r""" r"""
Apply median smoothing to the 1D tensor over the given window. Apply median smoothing to the 1D tensor over the given window.
""" """
...@@ -991,9 +926,7 @@ def _median_smoothing( ...@@ -991,9 +926,7 @@ def _median_smoothing(
pad_length = (win_length - 1) // 2 pad_length = (win_length - 1) // 2
# "replicate" padding in any dimension # "replicate" padding in any dimension
indices = torch.nn.functional.pad( indices = torch.nn.functional.pad(indices, (pad_length, 0), mode="constant", value=0.0)
indices, (pad_length, 0), mode="constant", value=0.
)
indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1) indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
roll = indices.unfold(-1, win_length, 1) roll = indices.unfold(-1, win_length, 1)
...@@ -1003,12 +936,12 @@ def _median_smoothing( ...@@ -1003,12 +936,12 @@ def _median_smoothing(
def detect_pitch_frequency( def detect_pitch_frequency(
waveform: Tensor, waveform: Tensor,
sample_rate: int, sample_rate: int,
frame_time: float = 10 ** (-2), frame_time: float = 10 ** (-2),
win_length: int = 30, win_length: int = 30,
freq_low: int = 85, freq_low: int = 85,
freq_high: int = 3400, freq_high: int = 3400,
) -> Tensor: ) -> Tensor:
r"""Detect pitch frequency. r"""Detect pitch frequency.
...@@ -1075,8 +1008,7 @@ def sliding_window_cmn( ...@@ -1075,8 +1008,7 @@ def sliding_window_cmn(
last_window_start = last_window_end = -1 last_window_start = last_window_end = -1
cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device) cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device) cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cmn_specgram = torch.zeros( cmn_specgram = torch.zeros(num_channels, num_frames, num_feats, dtype=dtype, device=device)
num_channels, num_frames, num_feats, dtype=dtype, device=device)
for t in range(num_frames): for t in range(num_frames):
window_start = 0 window_start = 0
window_end = 0 window_end = 0
...@@ -1093,12 +1025,12 @@ def sliding_window_cmn( ...@@ -1093,12 +1025,12 @@ def sliding_window_cmn(
if window_end > t: if window_end > t:
window_end = max(t + 1, min_cmn_window) window_end = max(t + 1, min_cmn_window)
if window_end > num_frames: if window_end > num_frames:
window_start -= (window_end - num_frames) window_start -= window_end - num_frames
window_end = num_frames window_end = num_frames
if window_start < 0: if window_start < 0:
window_start = 0 window_start = 0
if last_window_start == -1: if last_window_start == -1:
input_part = specgram[:, window_start: window_end - window_start, :] input_part = specgram[:, window_start : window_end - window_start, :]
cur_sum += torch.sum(input_part, 1) cur_sum += torch.sum(input_part, 1)
if norm_vars: if norm_vars:
cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :] cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
...@@ -1107,24 +1039,23 @@ def sliding_window_cmn( ...@@ -1107,24 +1039,23 @@ def sliding_window_cmn(
frame_to_remove = specgram[:, last_window_start, :] frame_to_remove = specgram[:, last_window_start, :]
cur_sum -= frame_to_remove cur_sum -= frame_to_remove
if norm_vars: if norm_vars:
cur_sumsq -= (frame_to_remove ** 2) cur_sumsq -= frame_to_remove ** 2
if window_end > last_window_end: if window_end > last_window_end:
frame_to_add = specgram[:, last_window_end, :] frame_to_add = specgram[:, last_window_end, :]
cur_sum += frame_to_add cur_sum += frame_to_add
if norm_vars: if norm_vars:
cur_sumsq += (frame_to_add ** 2) cur_sumsq += frame_to_add ** 2
window_frames = window_end - window_start window_frames = window_end - window_start
last_window_start = window_start last_window_start = window_start
last_window_end = window_end last_window_end = window_end
cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
if norm_vars: if norm_vars:
if window_frames == 1: if window_frames == 1:
cmn_specgram[:, t, :] = torch.zeros( cmn_specgram[:, t, :] = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
num_channels, num_feats, dtype=dtype, device=device)
else: else:
variance = cur_sumsq variance = cur_sumsq
variance = variance / window_frames variance = variance / window_frames
variance -= ((cur_sum ** 2) / (window_frames ** 2)) variance -= (cur_sum ** 2) / (window_frames ** 2)
variance = torch.pow(variance, -0.5) variance = torch.pow(variance, -0.5)
cmn_specgram[:, t, :] *= variance cmn_specgram[:, t, :] *= variance
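The variance branch above is the running-moment identity Var(x) = E[x^2] - E[x]^2 computed from the window sums, then inverted as a per-feature scale; a small sanity sketch (the tensors here are stand-ins, not library code):

import torch

x = torch.randn(2, 7, 13)                       # (channel, frames in window, feats)
cur_sum = x.sum(dim=1)
cur_sumsq = (x ** 2).sum(dim=1)
window_frames = x.shape[1]
variance = cur_sumsq / window_frames - (cur_sum ** 2) / (window_frames ** 2)
# Matches the biased variance computed directly:
assert torch.allclose(variance, x.var(dim=1, unbiased=False), atol=1e-5)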
...@@ -1135,13 +1066,13 @@ def sliding_window_cmn( ...@@ -1135,13 +1066,13 @@ def sliding_window_cmn(
def spectral_centroid( def spectral_centroid(
waveform: Tensor, waveform: Tensor,
sample_rate: int, sample_rate: int,
pad: int, pad: int,
window: Tensor, window: Tensor,
n_fft: int, n_fft: int,
hop_length: int, hop_length: int,
win_length: int, win_length: int,
) -> Tensor: ) -> Tensor:
r""" r"""
Compute the spectral centroid for each channel along the time axis. Compute the spectral centroid for each channel along the time axis.
...@@ -1161,10 +1092,17 @@ def spectral_centroid( ...@@ -1161,10 +1092,17 @@ def spectral_centroid(
Returns: Returns:
Tensor: Dimension `(..., time)` Tensor: Dimension `(..., time)`
""" """
specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length, specgram = spectrogram(
win_length=win_length, power=1., normalized=False) waveform,
freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2, pad=pad,
device=specgram.device).reshape((-1, 1)) window=window,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
power=1.0,
normalized=False,
)
freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2, device=specgram.device).reshape((-1, 1))
freq_dim = -2 freq_dim = -2
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim) return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
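The returned value is just a magnitude-weighted mean of the FFT bin frequencies per frame; a usage sketch, assuming the public entry point is ``torchaudio.functional.spectral_centroid``:

import torch
import torchaudio.functional as F

sample_rate, n_fft = 16000, 400
waveform = torch.randn(1, sample_rate)
window = torch.hann_window(n_fft)
centroid = F.spectral_centroid(
    waveform, sample_rate, pad=0, window=window,
    n_fft=n_fft, hop_length=200, win_length=n_fft,
)
print(centroid.shape)                           # (1, frames): one centroid in Hz per frame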
...@@ -1201,42 +1139,37 @@ def apply_codec( ...@@ -1201,42 +1139,37 @@ def apply_codec(
If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`. If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
""" """
bytes = io.BytesIO() bytes = io.BytesIO()
torchaudio.backend.sox_io_backend.save(bytes, torchaudio.backend.sox_io_backend.save(
waveform, bytes, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample
sample_rate, )
channels_first,
compression,
format,
encoding,
bits_per_sample
)
bytes.seek(0) bytes.seek(0)
augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file( augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format) bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format
)
return augmented return augmented
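A hedged usage sketch: it assumes the function is exposed as ``torchaudio.functional.apply_codec``, that the sox_io backend is available, and that the installed SoX build supports the chosen format/encoding:

import torch
import torchaudio.functional as F

waveform = torch.rand(1, 16000) * 2 - 1          # (channel, time), one second
# Round-trip through 8-bit mu-law PCM in a wav container to simulate a lossy chain.
degraded = F.apply_codec(waveform, 16000, format="wav", encoding="ULAW", bits_per_sample=8)
print(degraded.shape)                            # still (channel, time)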
@_mod_utils.requires_kaldi() @_mod_utils.requires_kaldi()
def compute_kaldi_pitch( def compute_kaldi_pitch(
waveform: torch.Tensor, waveform: torch.Tensor,
sample_rate: float, sample_rate: float,
frame_length: float = 25.0, frame_length: float = 25.0,
frame_shift: float = 10.0, frame_shift: float = 10.0,
min_f0: float = 50, min_f0: float = 50,
max_f0: float = 400, max_f0: float = 400,
soft_min_f0: float = 10.0, soft_min_f0: float = 10.0,
penalty_factor: float = 0.1, penalty_factor: float = 0.1,
lowpass_cutoff: float = 1000, lowpass_cutoff: float = 1000,
resample_frequency: float = 4000, resample_frequency: float = 4000,
delta_pitch: float = 0.005, delta_pitch: float = 0.005,
nccf_ballast: float = 7000, nccf_ballast: float = 7000,
lowpass_filter_width: int = 1, lowpass_filter_width: int = 1,
upsample_filter_width: int = 5, upsample_filter_width: int = 5,
max_frames_latency: int = 0, max_frames_latency: int = 0,
frames_per_chunk: int = 0, frames_per_chunk: int = 0,
simulate_first_pass_online: bool = False, simulate_first_pass_online: bool = False,
recompute_frame: int = 500, recompute_frame: int = 500,
snip_edges: bool = True, snip_edges: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
"""Extract pitch based on method described in *A pitch extraction algorithm tuned """Extract pitch based on method described in *A pitch extraction algorithm tuned
for automatic speech recognition* [:footcite:`6854049`]. for automatic speech recognition* [:footcite:`6854049`].
...@@ -1302,11 +1235,24 @@ def compute_kaldi_pitch( ...@@ -1302,11 +1235,24 @@ def compute_kaldi_pitch(
shape = waveform.shape shape = waveform.shape
waveform = waveform.reshape(-1, shape[-1]) waveform = waveform.reshape(-1, shape[-1])
result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch( result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
waveform, sample_rate, frame_length, frame_shift, waveform,
min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff, sample_rate,
resample_frequency, delta_pitch, nccf_ballast, frame_length,
lowpass_filter_width, upsample_filter_width, max_frames_latency, frame_shift,
frames_per_chunk, simulate_first_pass_online, recompute_frame, min_f0,
max_f0,
soft_min_f0,
penalty_factor,
lowpass_cutoff,
resample_frequency,
delta_pitch,
nccf_ballast,
lowpass_filter_width,
upsample_filter_width,
max_frames_latency,
frames_per_chunk,
simulate_first_pass_online,
recompute_frame,
snip_edges, snip_edges,
) )
result = result.reshape(shape[:-1] + result.shape[-2:]) result = result.reshape(shape[:-1] + result.shape[-2:])
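A usage sketch, assuming this is exposed as ``torchaudio.functional.compute_kaldi_pitch`` and that torchaudio was built with Kaldi support (which the decorator above enforces); per frame, the last dimension packs an NCCF value and a pitch estimate (see the upstream docstring for the exact ordering):

import torch
import torchaudio.functional as F

sample_rate = 16000
waveform = torch.randn(1, sample_rate)           # one second of (fake) speech
feats = F.compute_kaldi_pitch(waveform, sample_rate)
print(feats.shape)                               # (1, frames, 2)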
...@@ -1314,15 +1260,16 @@ def compute_kaldi_pitch( ...@@ -1314,15 +1260,16 @@ def compute_kaldi_pitch(
def _get_sinc_resample_kernel( def _get_sinc_resample_kernel(
orig_freq: int, orig_freq: int,
new_freq: int, new_freq: int,
gcd: int, gcd: int,
lowpass_filter_width: int, lowpass_filter_width: int,
rolloff: float, rolloff: float,
resampling_method: str, resampling_method: str,
beta: Optional[float], beta: Optional[float],
device: torch.device = torch.device("cpu"), device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None): dtype: Optional[torch.dtype] = None,
):
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq): if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
raise Exception( raise Exception(
...@@ -1334,8 +1281,8 @@ def _get_sinc_resample_kernel( ...@@ -1334,8 +1281,8 @@ def _get_sinc_resample_kernel(
"For more information, please refer to https://github.com/pytorch/audio/issues/1487." "For more information, please refer to https://github.com/pytorch/audio/issues/1487."
) )
if resampling_method not in ['sinc_interpolation', 'kaiser_window']: if resampling_method not in ["sinc_interpolation", "kaiser_window"]:
raise ValueError('Invalid resampling method: {}'.format(resampling_method)) raise ValueError("Invalid resampling method: {}".format(resampling_method))
orig_freq = int(orig_freq) // gcd orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd new_freq = int(new_freq) // gcd
...@@ -1381,7 +1328,7 @@ def _get_sinc_resample_kernel( ...@@ -1381,7 +1328,7 @@ def _get_sinc_resample_kernel(
# we do not use built in torch windows here as we need to evaluate the window # we do not use built in torch windows here as we need to evaluate the window
# at specific positions, not over a regular grid. # at specific positions, not over a regular grid.
if resampling_method == "sinc_interpolation": if resampling_method == "sinc_interpolation":
window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2 window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
else: else:
# kaiser_window # kaiser_window
if beta is None: if beta is None:
...@@ -1389,7 +1336,7 @@ def _get_sinc_resample_kernel( ...@@ -1389,7 +1336,7 @@ def _get_sinc_resample_kernel(
beta_tensor = torch.tensor(float(beta)) beta_tensor = torch.tensor(float(beta))
window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor) window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
t *= math.pi t *= math.pi
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t) kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t)
kernel.mul_(window) kernel.mul_(window)
kernels.append(kernel) kernels.append(kernel)
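Side note on the sinc_interpolation branch: cos^2(pi * t / (2 * lowpass_filter_width)) is algebraically the same Hann taper 0.5 * (1 + cos(pi * t / lowpass_filter_width)), just evaluated at arbitrary (non-grid) positions; a quick numerical check (W stands in for lowpass_filter_width):

import math
import torch

W = 6.0
t = torch.linspace(-W, W, steps=25)
hann_like = torch.cos(t * math.pi / W / 2) ** 2
closed_form = 0.5 * (1 + torch.cos(math.pi * t / W))
assert torch.allclose(hann_like, closed_form, atol=1e-6)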
...@@ -1401,12 +1348,12 @@ def _get_sinc_resample_kernel( ...@@ -1401,12 +1348,12 @@ def _get_sinc_resample_kernel(
def _apply_sinc_resample_kernel( def _apply_sinc_resample_kernel(
waveform: Tensor, waveform: Tensor,
orig_freq: int, orig_freq: int,
new_freq: int, new_freq: int,
gcd: int, gcd: int,
kernel: Tensor, kernel: Tensor,
width: int, width: int,
): ):
orig_freq = int(orig_freq) // gcd orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd new_freq = int(new_freq) // gcd
...@@ -1428,13 +1375,13 @@ def _apply_sinc_resample_kernel( ...@@ -1428,13 +1375,13 @@ def _apply_sinc_resample_kernel(
def resample( def resample(
waveform: Tensor, waveform: Tensor,
orig_freq: int, orig_freq: int,
new_freq: int, new_freq: int,
lowpass_filter_width: int = 6, lowpass_filter_width: int = 6,
rolloff: float = 0.99, rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation", resampling_method: str = "sinc_interpolation",
beta: Optional[float] = None, beta: Optional[float] = None,
) -> Tensor: ) -> Tensor:
r"""Resamples the waveform at the new frequency using bandlimited interpolation. r"""Resamples the waveform at the new frequency using bandlimited interpolation.
...@@ -1467,8 +1414,17 @@ def resample( ...@@ -1467,8 +1414,17 @@ def resample(
gcd = math.gcd(int(orig_freq), int(new_freq)) gcd = math.gcd(int(orig_freq), int(new_freq))
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff, kernel, width = _get_sinc_resample_kernel(
resampling_method, beta, waveform.device, waveform.dtype) orig_freq,
new_freq,
gcd,
lowpass_filter_width,
rolloff,
resampling_method,
beta,
waveform.device,
waveform.dtype,
)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width) resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled return resampled
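A minimal usage sketch, assuming this is ``torchaudio.functional.resample`` with the defaults shown above; the keyword values in the second call are illustrative, not recommendations:

import torch
import torchaudio.functional as F

waveform = torch.randn(1, 48000)                 # one second at 48 kHz
resampled = F.resample(waveform, orig_freq=48000, new_freq=16000)
print(resampled.shape)                           # roughly (1, 16000)
# A larger lowpass_filter_width gives a sharper but slower filter; the Kaiser window
# trades main-lobe width against sidelobe level via beta.
sharper = F.resample(
    waveform, 48000, 16000,
    lowpass_filter_width=64, resampling_method="kaiser_window", beta=14.77,
)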
...@@ -1557,25 +1513,24 @@ def pitch_shift( ...@@ -1557,25 +1513,24 @@ def pitch_shift(
ori_len = shape[-1] ori_len = shape[-1]
rate = 2.0 ** (-float(n_steps) / bins_per_octave) rate = 2.0 ** (-float(n_steps) / bins_per_octave)
spec_f = torch.stft(input=waveform, spec_f = torch.stft(
n_fft=n_fft, input=waveform,
hop_length=hop_length, n_fft=n_fft,
win_length=win_length, hop_length=hop_length,
window=window, win_length=win_length,
center=True, window=window,
pad_mode='reflect', center=True,
normalized=False, pad_mode="reflect",
onesided=True, normalized=False,
return_complex=True) onesided=True,
return_complex=True,
)
phase_advance = torch.linspace(0, math.pi * hop_length, spec_f.shape[-2], device=spec_f.device)[..., None] phase_advance = torch.linspace(0, math.pi * hop_length, spec_f.shape[-2], device=spec_f.device)[..., None]
spec_stretch = phase_vocoder(spec_f, rate, phase_advance) spec_stretch = phase_vocoder(spec_f, rate, phase_advance)
len_stretch = int(round(ori_len / rate)) len_stretch = int(round(ori_len / rate))
waveform_stretch = torch.istft(spec_stretch, waveform_stretch = torch.istft(
n_fft=n_fft, spec_stretch, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=len_stretch
hop_length=hop_length, )
win_length=win_length,
window=window,
length=len_stretch)
waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate) waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
shift_len = waveform_shift.size()[-1] shift_len = waveform_shift.size()[-1]
if shift_len > ori_len: if shift_len > ori_len:
...@@ -1617,7 +1572,7 @@ def rnnt_loss( ...@@ -1617,7 +1572,7 @@ def rnnt_loss(
Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size `(batch)`, Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size `(batch)`,
otherwise scalar. otherwise scalar.
""" """
if reduction not in ['none', 'mean', 'sum']: if reduction not in ["none", "mean", "sum"]:
raise ValueError("reduction should be one of 'none', 'mean', or 'sum'") raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
if blank < 0: # reinterpret blank index if blank < 0. if blank < 0: # reinterpret blank index if blank < 0.
...@@ -1632,9 +1587,9 @@ def rnnt_loss( ...@@ -1632,9 +1587,9 @@ def rnnt_loss(
clamp=clamp, clamp=clamp,
) )
if reduction == 'mean': if reduction == "mean":
return costs.mean() return costs.mean()
elif reduction == 'sum': elif reduction == "sum":
return costs.sum() return costs.sum()
return costs return costs
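A small usage sketch, assuming the public entry point is ``torchaudio.functional.rnnt_loss``; the tensors are random joiner outputs in the usual RNN-T layout ``(batch, max_frames, max_targets + 1, num_classes)``, and the int32 dtypes for targets and lengths are an assumption, not taken from the code above:

import torch
import torchaudio.functional as F

batch, frames, target_len, num_classes = 2, 10, 5, 7
logits = torch.randn(batch, frames, target_len + 1, num_classes, requires_grad=True)
targets = torch.randint(1, num_classes, (batch, target_len), dtype=torch.int32)
logit_lengths = torch.full((batch,), frames, dtype=torch.int32)
target_lengths = torch.full((batch,), target_len, dtype=torch.int32)
loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths, blank=0, reduction="mean")
loss.backward()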
...@@ -7,23 +7,23 @@ import torch ...@@ -7,23 +7,23 @@ import torch
from torch import Tensor from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
if _mod_utils.is_module_available('kaldi_io', 'numpy'): if _mod_utils.is_module_available("kaldi_io", "numpy"):
import numpy as np
import kaldi_io import kaldi_io
import numpy as np
__all__ = [ __all__ = [
'read_vec_int_ark', "read_vec_int_ark",
'read_vec_flt_scp', "read_vec_flt_scp",
'read_vec_flt_ark', "read_vec_flt_ark",
'read_mat_scp', "read_mat_scp",
'read_mat_ark', "read_mat_ark",
] ]
def _convert_method_output_to_tensor(file_or_fd: Any, def _convert_method_output_to_tensor(
fn: Callable, file_or_fd: Any, fn: Callable, convert_contiguous: bool = False
convert_contiguous: bool = False) -> Iterable[Tuple[str, Tensor]]: ) -> Iterable[Tuple[str, Tensor]]:
r"""Takes a method invokes it. The output is converted to a tensor. r"""Takes a method invokes it. The output is converted to a tensor.
Args: Args:
...@@ -42,7 +42,7 @@ def _convert_method_output_to_tensor(file_or_fd: Any, ...@@ -42,7 +42,7 @@ def _convert_method_output_to_tensor(file_or_fd: Any,
yield key, torch.from_numpy(np_arr) yield key, torch.from_numpy(np_arr)
@_mod_utils.requires_module('kaldi_io', 'numpy') @_mod_utils.requires_module("kaldi_io", "numpy")
def read_vec_int_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: def read_vec_int_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,vector<int>) tuples, which reads from the ark file/stream. r"""Create generator of (key,vector<int>) tuples, which reads from the ark file/stream.
...@@ -62,7 +62,7 @@ def read_vec_int_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: ...@@ -62,7 +62,7 @@ def read_vec_int_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_int_ark, convert_contiguous=True) return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_int_ark, convert_contiguous=True)
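These readers are thin wrappers over ``kaldi_io`` generators that yield ``(key, Tensor)`` pairs; a usage sketch, assuming the ``kaldi_io`` and ``numpy`` packages are installed (the decorator enforces this) and that ``ali.ark`` is a hypothetical Kaldi archive of integer vectors:

import torchaudio.kaldi_io as kio

for key, vec in kio.read_vec_int_ark("ali.ark"):
    print(key, vec.shape, vec.dtype)             # utterance id, 1-D integer tensor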
@_mod_utils.requires_module('kaldi_io', 'numpy') @_mod_utils.requires_module("kaldi_io", "numpy")
def read_vec_flt_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: def read_vec_flt_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,vector<float32/float64>) tuples, read according to Kaldi scp. r"""Create generator of (key,vector<float32/float64>) tuples, read according to Kaldi scp.
...@@ -79,7 +79,7 @@ def read_vec_flt_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: ...@@ -79,7 +79,7 @@ def read_vec_flt_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_scp) return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_scp)
@_mod_utils.requires_module('kaldi_io', 'numpy') @_mod_utils.requires_module("kaldi_io", "numpy")
def read_vec_flt_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: def read_vec_flt_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,vector<float32/float64>) tuples, which reads from the ark file/stream. r"""Create generator of (key,vector<float32/float64>) tuples, which reads from the ark file/stream.
...@@ -96,7 +96,7 @@ def read_vec_flt_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: ...@@ -96,7 +96,7 @@ def read_vec_flt_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_ark) return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_ark)
@_mod_utils.requires_module('kaldi_io', 'numpy') @_mod_utils.requires_module("kaldi_io", "numpy")
def read_mat_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: def read_mat_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,matrix<float32/float64>) tuples, read according to Kaldi scp. r"""Create generator of (key,matrix<float32/float64>) tuples, read according to Kaldi scp.
...@@ -113,7 +113,7 @@ def read_mat_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: ...@@ -113,7 +113,7 @@ def read_mat_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_mat_scp) return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_mat_scp)
@_mod_utils.requires_module('kaldi_io', 'numpy') @_mod_utils.requires_module("kaldi_io", "numpy")
def read_mat_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: def read_mat_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,matrix<float32/float64>) tuples, which reads from the ark file/stream. r"""Create generator of (key,matrix<float32/float64>) tuples, which reads from the ark file/stream.
......
from .wav2letter import Wav2Letter
from .wavernn import WaveRNN
from .conv_tasnet import ConvTasNet from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech from .deepspeech import DeepSpeech
from .tacotron2 import Tacotron2 from .tacotron2 import Tacotron2
from .wav2letter import Wav2Letter
from .wav2vec2 import ( from .wav2vec2 import (
Wav2Vec2Model, Wav2Vec2Model,
wav2vec2_model, wav2vec2_model,
...@@ -13,19 +12,20 @@ from .wav2vec2 import ( ...@@ -13,19 +12,20 @@ from .wav2vec2 import (
hubert_large, hubert_large,
hubert_xlarge, hubert_xlarge,
) )
from .wavernn import WaveRNN
__all__ = [ __all__ = [
'Wav2Letter', "Wav2Letter",
'WaveRNN', "WaveRNN",
'ConvTasNet', "ConvTasNet",
'DeepSpeech', "DeepSpeech",
'Wav2Vec2Model', "Wav2Vec2Model",
'wav2vec2_model', "wav2vec2_model",
'wav2vec2_base', "wav2vec2_base",
'wav2vec2_large', "wav2vec2_large",
'wav2vec2_large_lv60k', "wav2vec2_large_lv60k",
'hubert_base', "hubert_base",
'hubert_large', "hubert_large",
'hubert_xlarge', "hubert_xlarge",
'Tacotron2', "Tacotron2",
] ]
...@@ -35,9 +35,7 @@ class ConvBlock(torch.nn.Module): ...@@ -35,9 +35,7 @@ class ConvBlock(torch.nn.Module):
super().__init__() super().__init__()
self.conv_layers = torch.nn.Sequential( self.conv_layers = torch.nn.Sequential(
torch.nn.Conv1d( torch.nn.Conv1d(in_channels=io_channels, out_channels=hidden_channels, kernel_size=1),
in_channels=io_channels, out_channels=hidden_channels, kernel_size=1
),
torch.nn.PReLU(), torch.nn.PReLU(),
torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08), torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
torch.nn.Conv1d( torch.nn.Conv1d(
...@@ -55,17 +53,11 @@ class ConvBlock(torch.nn.Module): ...@@ -55,17 +53,11 @@ class ConvBlock(torch.nn.Module):
self.res_out = ( self.res_out = (
None None
if no_residual if no_residual
else torch.nn.Conv1d( else torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
in_channels=hidden_channels, out_channels=io_channels, kernel_size=1
)
)
self.skip_out = torch.nn.Conv1d(
in_channels=hidden_channels, out_channels=io_channels, kernel_size=1
) )
self.skip_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
def forward( def forward(self, input: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
self, input: torch.Tensor
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
feature = self.conv_layers(input) feature = self.conv_layers(input)
if self.res_out is None: if self.res_out is None:
residual = None residual = None
...@@ -110,12 +102,8 @@ class MaskGenerator(torch.nn.Module): ...@@ -110,12 +102,8 @@ class MaskGenerator(torch.nn.Module):
self.input_dim = input_dim self.input_dim = input_dim
self.num_sources = num_sources self.num_sources = num_sources
self.input_norm = torch.nn.GroupNorm( self.input_norm = torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8)
num_groups=1, num_channels=input_dim, eps=1e-8 self.input_conv = torch.nn.Conv1d(in_channels=input_dim, out_channels=num_feats, kernel_size=1)
)
self.input_conv = torch.nn.Conv1d(
in_channels=input_dim, out_channels=num_feats, kernel_size=1
)
self.receptive_field = 0 self.receptive_field = 0
self.conv_layers = torch.nn.ModuleList([]) self.conv_layers = torch.nn.ModuleList([])
...@@ -133,12 +121,12 @@ class MaskGenerator(torch.nn.Module): ...@@ -133,12 +121,12 @@ class MaskGenerator(torch.nn.Module):
no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)), no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)),
) )
) )
self.receptive_field += ( self.receptive_field += kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi
kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi
)
self.output_prelu = torch.nn.PReLU() self.output_prelu = torch.nn.PReLU()
self.output_conv = torch.nn.Conv1d( self.output_conv = torch.nn.Conv1d(
in_channels=num_feats, out_channels=input_dim * num_sources, kernel_size=1, in_channels=num_feats,
out_channels=input_dim * num_sources,
kernel_size=1,
) )
if msk_activate == "sigmoid": if msk_activate == "sigmoid":
self.mask_activate = torch.nn.Sigmoid() self.mask_activate = torch.nn.Sigmoid()
...@@ -239,9 +227,7 @@ class ConvTasNet(torch.nn.Module): ...@@ -239,9 +227,7 @@ class ConvTasNet(torch.nn.Module):
bias=False, bias=False,
) )
def _align_num_frames_with_strides( def _align_num_frames_with_strides(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]:
self, input: torch.Tensor
) -> Tuple[torch.Tensor, int]:
"""Pad input Tensor so that the end of the input tensor corresponds with """Pad input Tensor so that the end of the input tensor corresponds with
1. (if kernel size is odd) the center of the last convolution kernel 1. (if kernel size is odd) the center of the last convolution kernel
...@@ -294,9 +280,7 @@ class ConvTasNet(torch.nn.Module): ...@@ -294,9 +280,7 @@ class ConvTasNet(torch.nn.Module):
Tensor: 3D Tensor with shape [batch, channel==num_sources, frames] Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
""" """
if input.ndim != 3 or input.shape[1] != 1: if input.ndim != 3 or input.shape[1] != 1:
raise ValueError( raise ValueError(f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}")
f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}"
)
# B: batch size # B: batch size
# L: input frame length # L: input frame length
...@@ -309,13 +293,9 @@ class ConvTasNet(torch.nn.Module): ...@@ -309,13 +293,9 @@ class ConvTasNet(torch.nn.Module):
batch_size, num_padded_frames = padded.shape[0], padded.shape[2] batch_size, num_padded_frames = padded.shape[0], padded.shape[2]
feats = self.encoder(padded) # B, F, M feats = self.encoder(padded) # B, F, M
masked = self.mask_generator(feats) * feats.unsqueeze(1) # B, S, F, M masked = self.mask_generator(feats) * feats.unsqueeze(1) # B, S, F, M
masked = masked.view( masked = masked.view(batch_size * self.num_sources, self.enc_num_feats, -1) # B*S, F, M
batch_size * self.num_sources, self.enc_num_feats, -1
) # B*S, F, M
decoded = self.decoder(masked) # B*S, 1, L' decoded = self.decoder(masked) # B*S, 1, L'
output = decoded.view( output = decoded.view(batch_size, self.num_sources, num_padded_frames) # B, S, L'
batch_size, self.num_sources, num_padded_frames
) # B, S, L'
if num_pads > 0: if num_pads > 0:
output = output[..., :-num_pads] # B, S, L output = output[..., :-num_pads] # B, S, L
return output return output
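To make the reshapes above concrete: the mask generator produces one mask per source over the encoded features, the source axis is folded into the batch axis before decoding, and then unfolded again; a shape-only sketch with illustrative sizes (this is not the model code itself):

import torch

B, S, Feat, M = 4, 2, 512, 300                   # batch, sources, encoder feats, encoded frames
feats = torch.randn(B, Feat, M)                  # encoder output
masks = torch.rand(B, S, Feat, M)                # one mask per source
masked = (masks * feats.unsqueeze(1)).view(B * S, Feat, M)   # fold sources into batch
decoded = torch.randn(B * S, 1, 8000)            # stand-in for the decoder output
output = decoded.view(B, S, -1)                  # (batch, sources, time)
print(masked.shape, output.shape)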
...@@ -10,11 +10,7 @@ class FullyConnected(torch.nn.Module): ...@@ -10,11 +10,7 @@ class FullyConnected(torch.nn.Module):
n_hidden: Internal hidden unit size. n_hidden: Internal hidden unit size.
""" """
def __init__(self, def __init__(self, n_feature: int, n_hidden: int, dropout: float, relu_max_clip: int = 20) -> None:
n_feature: int,
n_hidden: int,
dropout: float,
relu_max_clip: int = 20) -> None:
super(FullyConnected, self).__init__() super(FullyConnected, self).__init__()
self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True) self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
self.relu_max_clip = relu_max_clip self.relu_max_clip = relu_max_clip
...@@ -52,9 +48,7 @@ class DeepSpeech(torch.nn.Module): ...@@ -52,9 +48,7 @@ class DeepSpeech(torch.nn.Module):
self.fc1 = FullyConnected(n_feature, n_hidden, dropout) self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
self.fc2 = FullyConnected(n_hidden, n_hidden, dropout) self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
self.fc3 = FullyConnected(n_hidden, n_hidden, dropout) self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
self.bi_rnn = torch.nn.RNN( self.bi_rnn = torch.nn.RNN(n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True)
n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True
)
self.fc4 = FullyConnected(n_hidden, n_hidden, dropout) self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
self.out = torch.nn.Linear(n_hidden, n_class) self.out = torch.nn.Linear(n_hidden, n_class)
...@@ -78,7 +72,7 @@ class DeepSpeech(torch.nn.Module): ...@@ -78,7 +72,7 @@ class DeepSpeech(torch.nn.Module):
# T x N x H # T x N x H
x, _ = self.bi_rnn(x) x, _ = self.bi_rnn(x)
# The fifth (non-recurrent) layer takes both the forward and backward units as inputs # The fifth (non-recurrent) layer takes both the forward and backward units as inputs
x = x[:, :, :self.n_hidden] + x[:, :, self.n_hidden:] x = x[:, :, : self.n_hidden] + x[:, :, self.n_hidden :]
# T x N x H # T x N x H
x = self.fc4(x) x = self.fc4(x)
# T x N x H # T x N x H
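The slice-and-add a few lines up relies on ``nn.RNN`` with ``bidirectional=True`` concatenating the forward and backward hidden states along the feature axis, so the first ``n_hidden`` features come from the forward pass and the last ``n_hidden`` from the backward one; a small check:

import torch

n_hidden, T, N = 16, 5, 3
rnn = torch.nn.RNN(n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True)
x = torch.randn(T, N, n_hidden)                  # (time, batch, features)
y, _ = rnn(x)                                    # (time, batch, 2 * n_hidden)
combined = y[:, :, :n_hidden] + y[:, :, n_hidden:]
print(y.shape, combined.shape)                   # (5, 3, 32) (5, 3, 16)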
......
...@@ -29,8 +29,8 @@ import warnings ...@@ -29,8 +29,8 @@ import warnings
from typing import Tuple, List, Optional, Union from typing import Tuple, List, Optional, Union
import torch import torch
from torch import nn
from torch import Tensor from torch import Tensor
from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
...@@ -39,9 +39,7 @@ __all__ = [ ...@@ -39,9 +39,7 @@ __all__ = [
] ]
def _get_linear_layer( def _get_linear_layer(in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear") -> torch.nn.Linear:
in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear"
) -> torch.nn.Linear:
r"""Linear layer with xavier uniform initialization. r"""Linear layer with xavier uniform initialization.
Args: Args:
...@@ -55,9 +53,7 @@ def _get_linear_layer( ...@@ -55,9 +53,7 @@ def _get_linear_layer(
(torch.nn.Linear): The corresponding linear layer. (torch.nn.Linear): The corresponding linear layer.
""" """
linear = torch.nn.Linear(in_dim, out_dim, bias=bias) linear = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_( torch.nn.init.xavier_uniform_(linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
)
return linear return linear
...@@ -101,9 +97,7 @@ def _get_conv1d_layer( ...@@ -101,9 +97,7 @@ def _get_conv1d_layer(
bias=bias, bias=bias,
) )
torch.nn.init.xavier_uniform_( torch.nn.init.xavier_uniform_(conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
)
return conv1d return conv1d
...@@ -194,9 +188,7 @@ class _Attention(nn.Module): ...@@ -194,9 +188,7 @@ class _Attention(nn.Module):
attention_location_kernel_size: int, attention_location_kernel_size: int,
) -> None: ) -> None:
super().__init__() super().__init__()
self.query_layer = _get_linear_layer( self.query_layer = _get_linear_layer(attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh")
attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh"
)
self.memory_layer = _get_linear_layer( self.memory_layer = _get_linear_layer(
encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh" encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh"
) )
...@@ -208,9 +200,7 @@ class _Attention(nn.Module): ...@@ -208,9 +200,7 @@ class _Attention(nn.Module):
) )
self.score_mask_value = -float("inf") self.score_mask_value = -float("inf")
def _get_alignment_energies( def _get_alignment_energies(self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor) -> Tensor:
self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor
) -> Tensor:
r"""Get the alignment vector. r"""Get the alignment vector.
Args: Args:
...@@ -226,9 +216,7 @@ class _Attention(nn.Module): ...@@ -226,9 +216,7 @@ class _Attention(nn.Module):
processed_query = self.query_layer(query.unsqueeze(1)) processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_weights_cat) processed_attention_weights = self.location_layer(attention_weights_cat)
energies = self.v( energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))
torch.tanh(processed_query + processed_attention_weights + processed_memory)
)
alignment = energies.squeeze(2) alignment = energies.squeeze(2)
return alignment return alignment
...@@ -256,9 +244,7 @@ class _Attention(nn.Module): ...@@ -256,9 +244,7 @@ class _Attention(nn.Module):
attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``). attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``). attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
""" """
alignment = self._get_alignment_energies( alignment = self._get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat)
attention_hidden_state, processed_memory, attention_weights_cat
)
alignment = alignment.masked_fill(mask, self.score_mask_value) alignment = alignment.masked_fill(mask, self.score_mask_value)
...@@ -281,10 +267,7 @@ class _Prenet(nn.Module): ...@@ -281,10 +267,7 @@ class _Prenet(nn.Module):
super().__init__() super().__init__()
in_sizes = [in_dim] + out_sizes[:-1] in_sizes = [in_dim] + out_sizes[:-1]
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [_get_linear_layer(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_sizes, out_sizes)]
_get_linear_layer(in_size, out_size, bias=False)
for (in_size, out_size) in zip(in_sizes, out_sizes)
]
) )
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
...@@ -488,9 +471,7 @@ class _Decoder(nn.Module): ...@@ -488,9 +471,7 @@ class _Decoder(nn.Module):
self.prenet = _Prenet(n_mels * n_frames_per_step, [prenet_dim, prenet_dim]) self.prenet = _Prenet(n_mels * n_frames_per_step, [prenet_dim, prenet_dim])
self.attention_rnn = nn.LSTMCell( self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)
prenet_dim + encoder_embedding_dim, attention_rnn_dim
)
self.attention_layer = _Attention( self.attention_layer = _Attention(
attention_rnn_dim, attention_rnn_dim,
...@@ -500,13 +481,9 @@ class _Decoder(nn.Module): ...@@ -500,13 +481,9 @@ class _Decoder(nn.Module):
attention_location_kernel_size, attention_location_kernel_size,
) )
self.decoder_rnn = nn.LSTMCell( self.decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True)
attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True
)
self.linear_projection = _get_linear_layer( self.linear_projection = _get_linear_layer(decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step)
decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step
)
self.gate_layer = _get_linear_layer( self.gate_layer = _get_linear_layer(
decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid" decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid"
...@@ -526,9 +503,7 @@ class _Decoder(nn.Module): ...@@ -526,9 +503,7 @@ class _Decoder(nn.Module):
n_batch = memory.size(0) n_batch = memory.size(0)
dtype = memory.dtype dtype = memory.dtype
device = memory.device device = memory.device
decoder_input = torch.zeros( decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device
)
return decoder_input return decoder_input
def _initialize_decoder_states( def _initialize_decoder_states(
...@@ -557,27 +532,15 @@ class _Decoder(nn.Module): ...@@ -557,27 +532,15 @@ class _Decoder(nn.Module):
dtype = memory.dtype dtype = memory.dtype
device = memory.device device = memory.device
attention_hidden = torch.zeros( attention_hidden = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
n_batch, self.attention_rnn_dim, dtype=dtype, device=device attention_cell = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
)
attention_cell = torch.zeros(
n_batch, self.attention_rnn_dim, dtype=dtype, device=device
)
decoder_hidden = torch.zeros( decoder_hidden = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
n_batch, self.decoder_rnn_dim, dtype=dtype, device=device decoder_cell = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
)
decoder_cell = torch.zeros(
n_batch, self.decoder_rnn_dim, dtype=dtype, device=device
)
attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device) attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
attention_weights_cum = torch.zeros( attention_weights_cum = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
n_batch, max_time, dtype=dtype, device=device attention_context = torch.zeros(n_batch, self.encoder_embedding_dim, dtype=dtype, device=device)
)
attention_context = torch.zeros(
n_batch, self.encoder_embedding_dim, dtype=dtype, device=device
)
processed_memory = self.attention_layer.memory_layer(memory) processed_memory = self.attention_layer.memory_layer(memory)
...@@ -688,16 +651,10 @@ class _Decoder(nn.Module): ...@@ -688,16 +651,10 @@ class _Decoder(nn.Module):
""" """
cell_input = torch.cat((decoder_input, attention_context), -1) cell_input = torch.cat((decoder_input, attention_context), -1)
attention_hidden, attention_cell = self.attention_rnn( attention_hidden, attention_cell = self.attention_rnn(cell_input, (attention_hidden, attention_cell))
cell_input, (attention_hidden, attention_cell) attention_hidden = F.dropout(attention_hidden, self.attention_dropout, self.training)
)
attention_hidden = F.dropout(
attention_hidden, self.attention_dropout, self.training
)
attention_weights_cat = torch.cat( attention_weights_cat = torch.cat((attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1)
(attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1
)
attention_context, attention_weights = self.attention_layer( attention_context, attention_weights = self.attention_layer(
attention_hidden, memory, processed_memory, attention_weights_cat, mask attention_hidden, memory, processed_memory, attention_weights_cat, mask
) )
...@@ -705,14 +662,10 @@ class _Decoder(nn.Module): ...@@ -705,14 +662,10 @@ class _Decoder(nn.Module):
attention_weights_cum += attention_weights attention_weights_cum += attention_weights
decoder_input = torch.cat((attention_hidden, attention_context), -1) decoder_input = torch.cat((attention_hidden, attention_context), -1)
decoder_hidden, decoder_cell = self.decoder_rnn( decoder_hidden, decoder_cell = self.decoder_rnn(decoder_input, (decoder_hidden, decoder_cell))
decoder_input, (decoder_hidden, decoder_cell)
)
decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training) decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training)
decoder_hidden_attention_context = torch.cat( decoder_hidden_attention_context = torch.cat((decoder_hidden, attention_context), dim=1)
(decoder_hidden, attention_context), dim=1
)
decoder_output = self.linear_projection(decoder_hidden_attention_context) decoder_output = self.linear_projection(decoder_hidden_attention_context)
gate_prediction = self.gate_layer(decoder_hidden_attention_context) gate_prediction = self.gate_layer(decoder_hidden_attention_context)
...@@ -819,15 +772,11 @@ class _Decoder(nn.Module): ...@@ -819,15 +772,11 @@ class _Decoder(nn.Module):
n_batch = memory.size(0) n_batch = memory.size(0)
dtype = memory.dtype dtype = memory.dtype
device = memory.device device = memory.device
decoder_input = torch.zeros( decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device
)
return decoder_input return decoder_input
@torch.jit.export @torch.jit.export
def infer(self, def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
memory: Tensor,
memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Decoder inference """Decoder inference
Args: Args:
...@@ -905,16 +854,14 @@ class _Decoder(nn.Module): ...@@ -905,16 +854,14 @@ class _Decoder(nn.Module):
if len(mel_specgrams) == self.decoder_max_step: if len(mel_specgrams) == self.decoder_max_step:
warnings.warn( warnings.warn(
"Reached max decoder steps. The generated spectrogram might not cover " "Reached max decoder steps. The generated spectrogram might not cover " "the whole transcript."
"the whole transcript.") )
mel_specgrams = torch.cat(mel_specgrams, dim=0) mel_specgrams = torch.cat(mel_specgrams, dim=0)
gate_outputs = torch.cat(gate_outputs, dim=0) gate_outputs = torch.cat(gate_outputs, dim=0)
alignments = torch.cat(alignments, dim=0) alignments = torch.cat(alignments, dim=0)
mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs( mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(mel_specgrams, gate_outputs, alignments)
mel_specgrams, gate_outputs, alignments
)
return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
...@@ -984,9 +931,7 @@ class Tacotron2(nn.Module): ...@@ -984,9 +931,7 @@ class Tacotron2(nn.Module):
self.n_frames_per_step = n_frames_per_step self.n_frames_per_step = n_frames_per_step
self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim) self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim)
torch.nn.init.xavier_uniform_(self.embedding.weight) torch.nn.init.xavier_uniform_(self.embedding.weight)
self.encoder = _Encoder( self.encoder = _Encoder(encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size)
encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size
)
self.decoder = _Decoder( self.decoder = _Decoder(
n_mels, n_mels,
n_frames_per_step, n_frames_per_step,
...@@ -1003,9 +948,7 @@ class Tacotron2(nn.Module): ...@@ -1003,9 +948,7 @@ class Tacotron2(nn.Module):
prenet_dim, prenet_dim,
gate_threshold, gate_threshold,
) )
self.postnet = _Postnet( self.postnet = _Postnet(n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution)
n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution
)
def forward( def forward(
self, self,
...@@ -1094,9 +1037,7 @@ class Tacotron2(nn.Module): ...@@ -1094,9 +1037,7 @@ class Tacotron2(nn.Module):
embedded_inputs = self.embedding(tokens).transpose(1, 2) embedded_inputs = self.embedding(tokens).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, lengths) encoder_outputs = self.encoder(embedded_inputs, lengths)
mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer( mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(encoder_outputs, lengths)
encoder_outputs, lengths
)
mel_outputs_postnet = self.postnet(mel_specgram) mel_outputs_postnet = self.postnet(mel_specgram)
mel_outputs_postnet = mel_specgram + mel_outputs_postnet mel_outputs_postnet = mel_specgram + mel_outputs_postnet
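A hedged inference sketch: the no-argument constructor, the random token ids, and the three-value return are assumptions based on the signatures visible above, not a verified recipe (a real pipeline would map text to symbol ids first):

import torch
from torchaudio.models import Tacotron2

model = Tacotron2().eval()                       # assumes usable constructor defaults
tokens = torch.randint(0, 100, (1, 50))          # hypothetical symbol ids
lengths = torch.tensor([50])
with torch.no_grad():
    mel, mel_lengths, alignments = model.infer(tokens, lengths)
print(mel.shape)                                 # (1, n_mels, frames)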
......
...@@ -19,9 +19,7 @@ class Wav2Letter(nn.Module): ...@@ -19,9 +19,7 @@ class Wav2Letter(nn.Module):
num_features (int, optional): Number of input features that the network will receive (Default: ``1``). num_features (int, optional): Number of input features that the network will receive (Default: ``1``).
""" """
def __init__(self, num_classes: int = 40, def __init__(self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1) -> None:
input_type: str = "waveform",
num_features: int = 1) -> None:
super(Wav2Letter, self).__init__() super(Wav2Letter, self).__init__()
acoustic_num_features = 250 if input_type == "waveform" else num_features acoustic_num_features = 250 if input_type == "waveform" else num_features
...@@ -47,13 +45,13 @@ class Wav2Letter(nn.Module): ...@@ -47,13 +45,13 @@ class Wav2Letter(nn.Module):
nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0), nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0), nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True) nn.ReLU(inplace=True),
) )
if input_type == "waveform": if input_type == "waveform":
waveform_model = nn.Sequential( waveform_model = nn.Sequential(
nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45), nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45),
nn.ReLU(inplace=True) nn.ReLU(inplace=True),
) )
self.acoustic_model = nn.Sequential(waveform_model, acoustic_model) self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)
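A usage sketch: with ``input_type="waveform"`` the model expects ``(batch, num_features, time)`` input; the 16000-sample length below is an arbitrary example:

import torch
from torchaudio.models import Wav2Letter

model = Wav2Letter(num_classes=40, input_type="waveform", num_features=1)
waveform = torch.randn(4, 1, 16000)              # (batch, num_features, time)
scores = model(waveform)                         # per-frame class scores, (batch, num_classes, frames)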
......
from . import utils
from .model import ( from .model import (
Wav2Vec2Model, Wav2Vec2Model,
wav2vec2_model, wav2vec2_model,
...@@ -8,16 +9,15 @@ from .model import ( ...@@ -8,16 +9,15 @@ from .model import (
hubert_large, hubert_large,
hubert_xlarge, hubert_xlarge,
) )
from . import utils
__all__ = [ __all__ = [
'Wav2Vec2Model', "Wav2Vec2Model",
'wav2vec2_model', "wav2vec2_model",
'wav2vec2_base', "wav2vec2_base",
'wav2vec2_large', "wav2vec2_large",
'wav2vec2_large_lv60k', "wav2vec2_large_lv60k",
'hubert_base', "hubert_base",
'hubert_large', "hubert_large",
'hubert_xlarge', "hubert_xlarge",
'utils', "utils",
] ]
...@@ -10,24 +10,25 @@ _LG = logging.getLogger(__name__) ...@@ -10,24 +10,25 @@ _LG = logging.getLogger(__name__)
class LayerNorm(nn.LayerNorm): class LayerNorm(nn.LayerNorm):
"""Layer norm with transpose""" """Layer norm with transpose"""
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
x = input.transpose(-2, -1) x = input.transpose(-2, -1)
x = nn.functional.layer_norm( x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.transpose(-2, -1) x = x.transpose(-2, -1)
return x return x
class ConvLayerBlock(Module): class ConvLayerBlock(Module):
"""Convolution unit of FeatureExtractor""" """Convolution unit of FeatureExtractor"""
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
kernel_size: int, kernel_size: int,
stride: int, stride: int,
bias: bool, bias: bool,
layer_norm: Optional[Module], layer_norm: Optional[Module],
): ):
super().__init__() super().__init__()
self.kernel_size = kernel_size self.kernel_size = kernel_size
...@@ -42,9 +43,9 @@ class ConvLayerBlock(Module): ...@@ -42,9 +43,9 @@ class ConvLayerBlock(Module):
) )
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
length: Optional[Tensor], length: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
""" """
Args: Args:
...@@ -60,7 +61,7 @@ class ConvLayerBlock(Module): ...@@ -60,7 +61,7 @@ class ConvLayerBlock(Module):
x = nn.functional.gelu(x) x = nn.functional.gelu(x)
if length is not None: if length is not None:
length = torch.div(length - self.kernel_size, self.stride, rounding_mode='floor') + 1 length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
# When input length is 0, the resulting length can be negative. So fix it here. # When input length is 0, the resulting length can be negative. So fix it here.
length = torch.max(torch.zeros_like(length), length) length = torch.max(torch.zeros_like(length), length)
return x, length return x, length
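The length update above is the standard no-padding convolution output-length formula, L_out = floor((L_in - kernel_size) / stride) + 1, clamped at zero for degenerate inputs; a quick check against ``nn.Conv1d``:

import torch

kernel_size, stride = 10, 5
conv = torch.nn.Conv1d(1, 1, kernel_size, stride=stride, bias=False)
length = torch.tensor([100, 37])
out_length = torch.div(length - kernel_size, stride, rounding_mode="floor") + 1
x = torch.randn(1, 1, 100)
assert conv(x).shape[-1] == out_length[0].item()  # both give 19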
...@@ -73,17 +74,18 @@ class FeatureExtractor(Module): ...@@ -73,17 +74,18 @@ class FeatureExtractor(Module):
conv_layers (nn.ModuleList): conv_layers (nn.ModuleList):
convolution layers convolution layers
""" """
def __init__( def __init__(
self, self,
conv_layers: nn.ModuleList, conv_layers: nn.ModuleList,
): ):
super().__init__() super().__init__()
self.conv_layers = conv_layers self.conv_layers = conv_layers
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
length: Optional[Tensor], length: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
""" """
Args: Args:
...@@ -100,9 +102,7 @@ class FeatureExtractor(Module): ...@@ -100,9 +102,7 @@ class FeatureExtractor(Module):
Valid length of each output sample. shape: ``[batch, ]``. Valid length of each output sample. shape: ``[batch, ]``.
""" """
if x.ndim != 2: if x.ndim != 2:
raise ValueError( raise ValueError("Expected the input Tensor to be 2D (batch, time), " "but received {list(x.shape)}")
"Expected the input Tensor to be 2D (batch, time), "
"but received {list(x.shape)}")
x = x.unsqueeze(1) # (batch, channel==1, frame) x = x.unsqueeze(1) # (batch, channel==1, frame)
for layer in self.conv_layers: for layer in self.conv_layers:
...@@ -121,15 +121,19 @@ class FeatureProjection(Module): ...@@ -121,15 +121,19 @@ class FeatureProjection(Module):
out_features (int): Output feature dim. out_features (int): Output feature dim.
dropout (float): Dropout probability. dropout (float): Dropout probability.
""" """
def __init__( def __init__(
self, self,
in_features: int, in_features: int,
out_features: int, out_features: int,
dropout: float, dropout: float,
): ):
super().__init__() super().__init__()
self.layer_norm = nn.LayerNorm(in_features) self.layer_norm = nn.LayerNorm(in_features)
self.projection = nn.Linear(in_features, out_features,) self.projection = nn.Linear(
in_features,
out_features,
)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def forward(self, x): def forward(self, x):
...@@ -154,11 +158,12 @@ class ConvolutionalPositionalEmbedding(Module): ...@@ -154,11 +158,12 @@ class ConvolutionalPositionalEmbedding(Module):
kernel_size (int): The number of frames to be used. kernel_size (int): The number of frames to be used.
groups (int): The number of groups in feature dimensions. groups (int): The number of groups in feature dimensions.
""" """
def __init__( def __init__(
self, self,
embed_dim: int, embed_dim: int,
kernel_size: int, kernel_size: int,
groups: int, groups: int,
): ):
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
...@@ -178,11 +183,8 @@ class ConvolutionalPositionalEmbedding(Module): ...@@ -178,11 +183,8 @@ class ConvolutionalPositionalEmbedding(Module):
# normally we would do `if isinstance(...)` but this class is not accessible # normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly. # because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if ( if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
hook.__module__ == 'torch.nn.utils.weight_norm' and _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
hook.__class__.__name__ == 'WeightNorm'
):
_LG.warning('Removing weight_norm from %s', self.__class__.__name__)
torch.nn.utils.remove_weight_norm(self.conv) torch.nn.utils.remove_weight_norm(self.conv)
return self return self
...@@ -197,7 +199,7 @@ class ConvolutionalPositionalEmbedding(Module): ...@@ -197,7 +199,7 @@ class ConvolutionalPositionalEmbedding(Module):
x = x.transpose(-2, -1) x = x.transpose(-2, -1)
x = self.conv(x) x = self.conv(x)
if self.num_remove > 0: if self.num_remove > 0:
x = x[..., :-self.num_remove] x = x[..., : -self.num_remove]
x = torch.nn.functional.gelu(x) x = torch.nn.functional.gelu(x)
x = x.transpose(-2, -1) x = x.transpose(-2, -1)
return x return x
...@@ -212,11 +214,12 @@ class SelfAttention(Module): ...@@ -212,11 +214,12 @@ class SelfAttention(Module):
dropout (float, optional): dropout (float, optional):
Dropout probability on attn_output_weights. Default: ``0.0`` Dropout probability on attn_output_weights. Default: ``0.0``
""" """
def __init__( def __init__(
self, self,
embed_dim: int, embed_dim: int,
num_heads: int, num_heads: int,
dropout: float = 0.0, dropout: float = 0.0,
): ):
super().__init__() super().__init__()
head_dim = embed_dim // num_heads head_dim = embed_dim // num_heads
...@@ -236,9 +239,9 @@ class SelfAttention(Module): ...@@ -236,9 +239,9 @@ class SelfAttention(Module):
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
""" """
Args: Args:
...@@ -251,17 +254,13 @@ class SelfAttention(Module): ...@@ -251,17 +254,13 @@ class SelfAttention(Module):
""" """
if x.ndim != 3 or x.shape[2] != self.embed_dim: if x.ndim != 3 or x.shape[2] != self.embed_dim:
raise ValueError( raise ValueError(
f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}."
f"Found {x.shape}."
) )
batch_size, length, embed_dim = x.size() batch_size, length, embed_dim = x.size()
if attention_mask is not None: if attention_mask is not None:
shape_ = (batch_size, 1, length, length) shape_ = (batch_size, 1, length, length)
if attention_mask.size() != shape_: if attention_mask.size() != shape_:
raise ValueError( raise ValueError(f"The expected attention mask shape is {shape_}. " f"Found {attention_mask.size()}.")
f"The expected attention mask shape is {shape_}. "
f"Found {attention_mask.size()}."
)
shape = (batch_size, length, self.num_heads, self.head_dim) shape = (batch_size, length, self.num_heads, self.head_dim)
q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
...@@ -283,14 +282,14 @@ class SelfAttention(Module): ...@@ -283,14 +282,14 @@ class SelfAttention(Module):
class FeedForward(Module): class FeedForward(Module):
"""Layer that follows attention layer in encoder layer. """Layer that follows attention layer in encoder layer."""
"""
def __init__( def __init__(
self, self,
io_features: int, io_features: int,
intermediate_features: int, intermediate_features: int,
intermediate_dropout: float, intermediate_dropout: float,
output_dropout: float, output_dropout: float,
): ):
super().__init__() super().__init__()
self.intermediate_dense = nn.Linear(io_features, intermediate_features) self.intermediate_dense = nn.Linear(io_features, intermediate_features)
...@@ -315,14 +314,14 @@ class FeedForward(Module): ...@@ -315,14 +314,14 @@ class FeedForward(Module):
class EncoderLayer(Module): class EncoderLayer(Module):
"""A layer unit in encoder. Combines multihead self attention and feed forward. """A layer unit in encoder. Combines multihead self attention and feed forward."""
"""
def __init__( def __init__(
self, self,
attention: Module, attention: Module,
dropout: float, dropout: float,
layer_norm_first: bool, layer_norm_first: bool,
feed_forward: Module, feed_forward: Module,
): ):
super().__init__() super().__init__()
self.attention = attention self.attention = attention
...@@ -333,9 +332,9 @@ class EncoderLayer(Module): ...@@ -333,9 +332,9 @@ class EncoderLayer(Module):
self.final_layer_norm = nn.LayerNorm(attention.embed_dim) self.final_layer_norm = nn.LayerNorm(attention.embed_dim)
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
): ):
""" """
Args: Args:
...@@ -362,12 +361,12 @@ class EncoderLayer(Module): ...@@ -362,12 +361,12 @@ class EncoderLayer(Module):
class Transformer(Module): class Transformer(Module):
def __init__( def __init__(
self, self,
pos_conv_embed: Module, pos_conv_embed: Module,
dropout: float, dropout: float,
layers: Module, layers: Module,
layer_norm_first: bool, layer_norm_first: bool,
layer_drop: float, layer_drop: float,
): ):
super().__init__() super().__init__()
self.pos_conv_embed = pos_conv_embed self.pos_conv_embed = pos_conv_embed
...@@ -387,9 +386,9 @@ class Transformer(Module): ...@@ -387,9 +386,9 @@ class Transformer(Module):
return x return x
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
): ):
x = self._preprocess(x) x = self._preprocess(x)
for layer in self.layers: for layer in self.layers:
...@@ -402,14 +401,14 @@ class Transformer(Module): ...@@ -402,14 +401,14 @@ class Transformer(Module):
return x return x
def get_intermediate_outputs( def get_intermediate_outputs(
self, self,
x: Tensor, x: Tensor,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
num_layers: Optional[int] = None, num_layers: Optional[int] = None,
) -> List[Tensor]: ) -> List[Tensor]:
if num_layers is not None: if num_layers is not None:
if not 0 < num_layers <= len(self.layers): if not 0 < num_layers <= len(self.layers):
raise ValueError(f'`num_layers` must be between [1, {len(self.layers)}]') raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")
ret: List[Tensor] = [] ret: List[Tensor] = []
x = self._preprocess(x) x = self._preprocess(x)
...@@ -423,18 +422,18 @@ class Transformer(Module): ...@@ -423,18 +422,18 @@ class Transformer(Module):
class Encoder(Module): class Encoder(Module):
def __init__( def __init__(
self, self,
feature_projection: Module, feature_projection: Module,
transformer: Module, transformer: Module,
): ):
super().__init__() super().__init__()
self.feature_projection = feature_projection self.feature_projection = feature_projection
self.transformer = transformer self.transformer = transformer
def _preprocess( def _preprocess(
self, self,
features: Tensor, features: Tensor,
lengths: Optional[Tensor] = None, lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
x = self.feature_projection(features) x = self.feature_projection(features)
...@@ -450,30 +449,29 @@ class Encoder(Module): ...@@ -450,30 +449,29 @@ class Encoder(Module):
return x, mask return x, mask
def forward( def forward(
self, self,
features: Tensor, features: Tensor,
lengths: Optional[Tensor] = None, lengths: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
x, mask = self._preprocess(features, lengths) x, mask = self._preprocess(features, lengths)
x = self.transformer(x, attention_mask=mask) x = self.transformer(x, attention_mask=mask)
return x return x
def extract_features( def extract_features(
self, self,
features: Tensor, features: Tensor,
lengths: Optional[Tensor] = None, lengths: Optional[Tensor] = None,
num_layers: Optional[int] = None, num_layers: Optional[int] = None,
) -> List[Tensor]: ) -> List[Tensor]:
x, masks = self._preprocess(features, lengths) x, masks = self._preprocess(features, lengths)
return self.transformer.get_intermediate_outputs( return self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)
x, attention_mask=masks, num_layers=num_layers)
################################################################################ ################################################################################
def _get_feature_extractor( def _get_feature_extractor(
norm_mode: str, norm_mode: str,
shapes: List[Tuple[int, int, int]], shapes: List[Tuple[int, int, int]],
bias: bool, bias: bool,
) -> FeatureExtractor: ) -> FeatureExtractor:
""" """
Args: Args:
...@@ -545,19 +543,19 @@ def _get_feature_extractor( ...@@ -545,19 +543,19 @@ def _get_feature_extractor(
def _get_encoder( def _get_encoder(
in_features: int, in_features: int,
embed_dim: int, embed_dim: int,
dropout_input: float, dropout_input: float,
pos_conv_kernel: int, pos_conv_kernel: int,
pos_conv_groups: int, pos_conv_groups: int,
num_layers: int, num_layers: int,
num_heads: int, num_heads: int,
attention_dropout: float, attention_dropout: float,
ff_interm_features: int, ff_interm_features: int,
ff_interm_dropout: float, ff_interm_dropout: float,
dropout: float, dropout: float,
layer_norm_first: bool, layer_norm_first: bool,
layer_drop: float, layer_drop: float,
) -> Encoder: ) -> Encoder:
""" """
Args: Args:
......
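As a rough sketch of how these private helpers fit together (the public wav2vec2_model factory below does the equivalent wiring); the module path and the numeric values, which mirror the base configuration, are assumptions for illustration only:

from torchaudio.models.wav2vec2 import components  # assumed internal path

# (out_channels, kernel_size, stride) per conv block of the base extractor
conv_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
feature_extractor = components._get_feature_extractor("group_norm", conv_config, bias=False)
encoder = components._get_encoder(
    in_features=conv_config[-1][0],  # 512, the last conv's output channels
    embed_dim=768,
    dropout_input=0.1,
    pos_conv_kernel=128,
    pos_conv_groups=16,
    num_layers=12,
    num_heads=12,
    attention_dropout=0.1,
    ff_interm_features=3072,
    ff_interm_dropout=0.1,
    dropout=0.1,
    layer_norm_first=False,
    layer_drop=0.1,
)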
...@@ -26,11 +26,12 @@ class Wav2Vec2Model(Module): ...@@ -26,11 +26,12 @@ class Wav2Vec2Model(Module):
aux (torch.nn.Module or None, optional): aux (torch.nn.Module or None, optional):
Auxiliary module. If provided, the output from encoder is passed to this module. Auxiliary module. If provided, the output from encoder is passed to this module.
""" # noqa: E501 """ # noqa: E501
def __init__( def __init__(
self, self,
feature_extractor: Module, feature_extractor: Module,
encoder: Module, encoder: Module,
aux: Optional[Module] = None, aux: Optional[Module] = None,
): ):
super().__init__() super().__init__()
self.feature_extractor = feature_extractor self.feature_extractor = feature_extractor
...@@ -39,10 +40,10 @@ class Wav2Vec2Model(Module): ...@@ -39,10 +40,10 @@ class Wav2Vec2Model(Module):
@torch.jit.export @torch.jit.export
def extract_features( def extract_features(
self, self,
waveforms: Tensor, waveforms: Tensor,
lengths: Optional[Tensor] = None, lengths: Optional[Tensor] = None,
num_layers: Optional[int] = None, num_layers: Optional[int] = None,
) -> Tuple[List[Tensor], Optional[Tensor]]: ) -> Tuple[List[Tensor], Optional[Tensor]]:
"""Extract feature vectors from raw waveforms """Extract feature vectors from raw waveforms
...@@ -81,9 +82,9 @@ class Wav2Vec2Model(Module): ...@@ -81,9 +82,9 @@ class Wav2Vec2Model(Module):
return x, lengths return x, lengths
def forward( def forward(
self, self,
waveforms: Tensor, waveforms: Tensor,
lengths: Optional[Tensor] = None, lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
"""Compute the sequence of probability distribution over labels. """Compute the sequence of probability distribution over labels.
...@@ -117,22 +118,22 @@ class Wav2Vec2Model(Module): ...@@ -117,22 +118,22 @@ class Wav2Vec2Model(Module):
def wav2vec2_model( def wav2vec2_model(
extractor_mode: str, extractor_mode: str,
extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
extractor_conv_bias: bool, extractor_conv_bias: bool,
encoder_embed_dim: int, encoder_embed_dim: int,
encoder_projection_dropout: float, encoder_projection_dropout: float,
encoder_pos_conv_kernel: int, encoder_pos_conv_kernel: int,
encoder_pos_conv_groups: int, encoder_pos_conv_groups: int,
encoder_num_layers: int, encoder_num_layers: int,
encoder_num_heads: int, encoder_num_heads: int,
encoder_attention_dropout: float, encoder_attention_dropout: float,
encoder_ff_interm_features: int, encoder_ff_interm_features: int,
encoder_ff_interm_dropout: float, encoder_ff_interm_dropout: float,
encoder_dropout: float, encoder_dropout: float,
encoder_layer_norm_first: bool, encoder_layer_norm_first: bool,
encoder_layer_drop: float, encoder_layer_drop: float,
aux_num_out: Optional[int], aux_num_out: Optional[int],
) -> Wav2Vec2Model: ) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx # Overriding the signature so that the return type is correct on Sphinx
"""wav2vec2_model(extractor_mode: str, extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], extractor_conv_bias: bool, encoder_embed_dim: int, encoder_projection_dropout: float, encoder_pos_conv_kernel: int, encoder_pos_conv_groups: int, encoder_num_layers: int, encoder_num_heads: int, encoder_attention_dropout: float, encoder_ff_interm_features: int, encoder_ff_interm_dropout: float, encoder_dropout: float, encoder_layer_norm_first: bool, encoder_layer_drop: float, aux_num_out: Optional[int]) -> torchaudio.models.Wav2Vec2Model """wav2vec2_model(extractor_mode: str, extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], extractor_conv_bias: bool, encoder_embed_dim: int, encoder_projection_dropout: float, encoder_pos_conv_kernel: int, encoder_pos_conv_groups: int, encoder_num_layers: int, encoder_num_heads: int, encoder_attention_dropout: float, encoder_ff_interm_features: int, encoder_ff_interm_dropout: float, encoder_dropout: float, encoder_layer_norm_first: bool, encoder_layer_drop: float, aux_num_out: Optional[int]) -> torchaudio.models.Wav2Vec2Model
...@@ -262,7 +263,8 @@ def wav2vec2_model( ...@@ -262,7 +263,8 @@ def wav2vec2_model(
extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
feature_extractor = components._get_feature_extractor( feature_extractor = components._get_feature_extractor(
extractor_mode, extractor_conv_layer_config, extractor_conv_bias) extractor_mode, extractor_conv_layer_config, extractor_conv_bias
)
encoder = components._get_encoder( encoder = components._get_encoder(
in_features=extractor_conv_layer_config[-1][0], in_features=extractor_conv_layer_config[-1][0],
embed_dim=encoder_embed_dim, embed_dim=encoder_embed_dim,
...@@ -285,12 +287,12 @@ def wav2vec2_model( ...@@ -285,12 +287,12 @@ def wav2vec2_model(
def wav2vec2_base( def wav2vec2_base(
encoder_projection_dropout: float = 0.1, encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1, encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1, encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1, encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None, aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model: ) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx # Overriding the signature so that the return type is correct on Sphinx
"""wav2vec2_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model """wav2vec2_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
...@@ -336,12 +338,12 @@ def wav2vec2_base( ...@@ -336,12 +338,12 @@ def wav2vec2_base(
def wav2vec2_large( def wav2vec2_large(
encoder_projection_dropout: float = 0.1, encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1, encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.1, encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.1, encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None, aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model: ) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx # Overriding the signature so that the return type is correct on Sphinx
"""wav2vec2_large(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model """wav2vec2_large(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
...@@ -387,12 +389,12 @@ def wav2vec2_large( ...@@ -387,12 +389,12 @@ def wav2vec2_large(
def wav2vec2_large_lv60k( def wav2vec2_large_lv60k(
encoder_projection_dropout: float = 0.1, encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.0, encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1,
encoder_dropout: float = 0.0, encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.1, encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None, aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model: ) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx # Overriding the signature so that the return type is correct on Sphinx
"""wav2vec2_large_lv60k( encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model """wav2vec2_large_lv60k( encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
...@@ -438,12 +440,12 @@ def wav2vec2_large_lv60k( ...@@ -438,12 +440,12 @@ def wav2vec2_large_lv60k(
def hubert_base( def hubert_base(
encoder_projection_dropout: float = 0.1, encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1, encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.1, encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.05, encoder_layer_drop: float = 0.05,
aux_num_out: Optional[int] = None, aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model: ) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx # Overriding the signature so that the return type is correct on Sphinx
"""hubert_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.05, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model """hubert_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.05, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
...@@ -469,7 +471,7 @@ def hubert_base( ...@@ -469,7 +471,7 @@ def hubert_base(
The resulting model. The resulting model.
""" # noqa: E501 """ # noqa: E501
return wav2vec2_model( return wav2vec2_model(
extractor_mode='group_norm', extractor_mode="group_norm",
extractor_conv_layer_config=None, extractor_conv_layer_config=None,
extractor_conv_bias=False, extractor_conv_bias=False,
encoder_embed_dim=768, encoder_embed_dim=768,
...@@ -489,12 +491,12 @@ def hubert_base( ...@@ -489,12 +491,12 @@ def hubert_base(
def hubert_large( def hubert_large(
encoder_projection_dropout: float = 0.0, encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0, encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0, encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0, encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None, aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model: ) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx # Overriding the signature so that the return type is correct on Sphinx
"""hubert_large(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model """hubert_large(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
...@@ -520,7 +522,7 @@ def hubert_large( ...@@ -520,7 +522,7 @@ def hubert_large(
The resulting model. The resulting model.
""" # noqa: E501 """ # noqa: E501
return wav2vec2_model( return wav2vec2_model(
extractor_mode='layer_norm', extractor_mode="layer_norm",
extractor_conv_layer_config=None, extractor_conv_layer_config=None,
extractor_conv_bias=False, extractor_conv_bias=False,
encoder_embed_dim=1024, encoder_embed_dim=1024,
...@@ -540,12 +542,12 @@ def hubert_large( ...@@ -540,12 +542,12 @@ def hubert_large(
def hubert_xlarge( def hubert_xlarge(
encoder_projection_dropout: float = 0.0, encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0, encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0, encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0, encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None, aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model: ) -> Wav2Vec2Model:
# Overriding the signature so that the return type is correct on Sphinx # Overriding the signature so that the return type is correct on Sphinx
"""hubert_xlarge(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model """hubert_xlarge(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model
...@@ -571,7 +573,7 @@ def hubert_xlarge( ...@@ -571,7 +573,7 @@ def hubert_xlarge(
The resulting model. The resulting model.
""" # noqa: E501 """ # noqa: E501
return wav2vec2_model( return wav2vec2_model(
extractor_mode='layer_norm', extractor_mode="layer_norm",
extractor_conv_layer_config=None, extractor_conv_layer_config=None,
extractor_conv_bias=False, extractor_conv_bias=False,
encoder_embed_dim=1280, encoder_embed_dim=1280,
......
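For context, a hedged usage sketch of the public factories defined above; the waveforms and lengths are dummy data and aux_num_out=32 is an arbitrary label count, not a value from this diff:

import torch
from torchaudio.models import wav2vec2_base

model = wav2vec2_base(aux_num_out=32)   # untrained base model with a linear output head
waveforms = torch.randn(2, 16000)       # (batch, time)
lengths = torch.tensor([16000, 12000])
logits, out_lengths = model(waveforms, lengths)            # (batch, frames, 32), valid frame counts
features, feat_lengths = model.extract_features(waveforms, lengths, num_layers=4)  # list of 4 Tensors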
from .import_huggingface import import_huggingface_model
from .import_fairseq import import_fairseq_model from .import_fairseq import import_fairseq_model
from .import_huggingface import import_huggingface_model
__all__ = [ __all__ = [
'import_huggingface_model', "import_huggingface_model",
'import_fairseq_model', "import_fairseq_model",
] ]
...@@ -13,11 +13,11 @@ def _parse_config(w2v_model): ...@@ -13,11 +13,11 @@ def _parse_config(w2v_model):
encoder = w2v_model.encoder encoder = w2v_model.encoder
conv_layers = w2v_model.feature_extractor.conv_layers conv_layers = w2v_model.feature_extractor.conv_layers
extractor_mode = 'layer_norm' extractor_mode = "layer_norm"
if 'GroupNorm' in conv_layers[0][2].__class__.__name__: if "GroupNorm" in conv_layers[0][2].__class__.__name__:
extractor_mode = 'group_norm' extractor_mode = "group_norm"
else: else:
extractor_mode = 'layer_norm' extractor_mode = "layer_norm"
conv_layer_config = [(l[0].out_channels, l[0].kernel_size[0], l[0].stride[0]) for l in conv_layers] conv_layer_config = [(l[0].out_channels, l[0].kernel_size[0], l[0].stride[0]) for l in conv_layers]
...@@ -26,53 +26,52 @@ def _parse_config(w2v_model): ...@@ -26,53 +26,52 @@ def _parse_config(w2v_model):
elif all(l[0].bias is not None for l in conv_layers): elif all(l[0].bias is not None for l in conv_layers):
conv_bias = True conv_bias = True
else: else:
raise ValueError( raise ValueError("Either all the convolution layers should have a bias term or none of them should.")
'Either all the convolution layers should have a bias term or none of them should.')
config = { config = {
'extractor_mode': extractor_mode, "extractor_mode": extractor_mode,
'extractor_conv_layer_config': conv_layer_config, "extractor_conv_layer_config": conv_layer_config,
'extractor_conv_bias': conv_bias, "extractor_conv_bias": conv_bias,
'encoder_embed_dim': w2v_model.post_extract_proj.out_features, "encoder_embed_dim": w2v_model.post_extract_proj.out_features,
'encoder_projection_dropout': w2v_model.dropout_input.p, "encoder_projection_dropout": w2v_model.dropout_input.p,
'encoder_pos_conv_kernel': encoder.pos_conv[0].kernel_size[0], "encoder_pos_conv_kernel": encoder.pos_conv[0].kernel_size[0],
'encoder_pos_conv_groups': encoder.pos_conv[0].groups, "encoder_pos_conv_groups": encoder.pos_conv[0].groups,
'encoder_num_layers': len(encoder.layers), "encoder_num_layers": len(encoder.layers),
'encoder_num_heads': encoder.layers[0].self_attn.num_heads, "encoder_num_heads": encoder.layers[0].self_attn.num_heads,
'encoder_attention_dropout': encoder.layers[0].self_attn.dropout_module.p, "encoder_attention_dropout": encoder.layers[0].self_attn.dropout_module.p,
'encoder_ff_interm_features': encoder.layers[0].fc1.out_features, "encoder_ff_interm_features": encoder.layers[0].fc1.out_features,
'encoder_ff_interm_dropout': encoder.layers[0].dropout2.p, "encoder_ff_interm_dropout": encoder.layers[0].dropout2.p,
'encoder_dropout': encoder.layers[0].dropout3.p, "encoder_dropout": encoder.layers[0].dropout3.p,
'encoder_layer_norm_first': encoder.layer_norm_first, "encoder_layer_norm_first": encoder.layer_norm_first,
'encoder_layer_drop': encoder.layerdrop, "encoder_layer_drop": encoder.layerdrop,
} }
return config return config
def _map_key(key): def _map_key(key):
key_ = key key_ = key
if key.startswith('w2v_model.'): if key.startswith("w2v_model."):
key = key.replace('w2v_model.', '') key = key.replace("w2v_model.", "")
if re.match(r'(mask_emb|quantizer|project_q|final_proj|mask_emb)', key): if re.match(r"(mask_emb|quantizer|project_q|final_proj|mask_emb)", key):
return None return None
# Feature Extractor # Feature Extractor
# Group norm when "extractor_mode" is "default". # Group norm when "extractor_mode" is "default".
# (Only the first layer) # (Only the first layer)
# "conv_layers.0.2.weight" -> "conv_layers.0.layer_norm.weight" # "conv_layers.0.2.weight" -> "conv_layers.0.layer_norm.weight"
# "conv_layers.0.2.bias" -> "conv_layers.0.layer_norm.bias" # "conv_layers.0.2.bias" -> "conv_layers.0.layer_norm.bias"
match = re.match(r'feature_extractor\.conv_layers\.0\.2\.(weight|bias)', key) match = re.match(r"feature_extractor\.conv_layers\.0\.2\.(weight|bias)", key)
if match: if match:
return f"feature_extractor.conv_layers.0.layer_norm.{match.group(1)}" return f"feature_extractor.conv_layers.0.layer_norm.{match.group(1)}"
# Convolutions # Convolutions
# "conv_layers.X.0.weight" -> "conv_layers.X.conv.weight" # "conv_layers.X.0.weight" -> "conv_layers.X.conv.weight"
# "conv_layers.X.0.bias" -> "conv_layers.X.conv.bias" # "conv_layers.X.0.bias" -> "conv_layers.X.conv.bias"
match = re.match(r'feature_extractor\.conv_layers\.(\d+)\.0\.(weight|bias)', key) match = re.match(r"feature_extractor\.conv_layers\.(\d+)\.0\.(weight|bias)", key)
if match: if match:
return f"feature_extractor.conv_layers.{match.group(1)}.conv.{match.group(2)}" return f"feature_extractor.conv_layers.{match.group(1)}.conv.{match.group(2)}"
# Layer norm when "extractor_mode" is "layer_norm". # Layer norm when "extractor_mode" is "layer_norm".
# "conv_layers.X.2.1.weight" -> "conv_layers.X.layer_norm.weight" # "conv_layers.X.2.1.weight" -> "conv_layers.X.layer_norm.weight"
# "conv_layers.X.2.1.bias" -> "conv_layers.X.layer_norm.bias" # "conv_layers.X.2.1.bias" -> "conv_layers.X.layer_norm.bias"
match = re.match(r'feature_extractor\.conv_layers\.(\d+)\.2\.1\.(weight|bias)', key) match = re.match(r"feature_extractor\.conv_layers\.(\d+)\.2\.1\.(weight|bias)", key)
if match: if match:
return f"feature_extractor.conv_layers.{match.group(1)}.layer_norm.{match.group(2)}" return f"feature_extractor.conv_layers.{match.group(1)}.layer_norm.{match.group(2)}"
match = re.match(r"post_extract_proj\.(weight|bias)", key) match = re.match(r"post_extract_proj\.(weight|bias)", key)
...@@ -111,9 +110,9 @@ def _map_key(key): ...@@ -111,9 +110,9 @@ def _map_key(key):
if match: if match:
return f"aux.{match.group(1)}" return f"aux.{match.group(1)}"
# HuBERT Extension # HuBERT Extension
if key in ['label_embs_concat']: if key in ["label_embs_concat"]:
return key return key
raise ValueError(f'Unexpected key: {key_}') raise ValueError(f"Unexpected key: {key_}")
def _convert_state_dict(state_dict): def _convert_state_dict(state_dict):
...@@ -179,16 +178,15 @@ def import_fairseq_model(original: Module) -> Wav2Vec2Model: ...@@ -179,16 +178,15 @@ def import_fairseq_model(original: Module) -> Wav2Vec2Model:
.. _fairseq: https://github.com/pytorch/fairseq .. _fairseq: https://github.com/pytorch/fairseq
""" """
class_ = original.__class__.__name__ class_ = original.__class__.__name__
if class_ == 'Wav2Vec2Model': if class_ == "Wav2Vec2Model":
return _import_wav2vec2_pretraining(original) return _import_wav2vec2_pretraining(original)
if class_ == 'Wav2VecEncoder': if class_ == "Wav2VecEncoder":
return _import_wav2vec2_finetuning(original) return _import_wav2vec2_finetuning(original)
if class_ == 'HubertModel': if class_ == "HubertModel":
return _import_hubert_pretraining(original) return _import_hubert_pretraining(original)
if class_ == 'HubertEncoder': if class_ == "HubertEncoder":
return _import_hubert_finetuning(original) return _import_hubert_finetuning(original)
raise ValueError( raise ValueError(f"Expected an instance of `Wav2Vec2Model` or `Wav2VecEncoder`. Found: {class_}")
f'Expected an instance of `Wav2Vec2Model` or `Wav2VecEncoder`. Found: {class_}')
def _import_wav2vec2_finetuning(original: Module) -> Wav2Vec2Model: def _import_wav2vec2_finetuning(original: Module) -> Wav2Vec2Model:
......
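A hedged sketch of the intended workflow for import_fairseq_model; the checkpoint file name is a placeholder, and the loading call is fairseq's standard checkpoint_utils API rather than anything defined in this diff:

import fairseq
from torchaudio.models.wav2vec2.utils import import_fairseq_model

# "wav2vec_small.pt" is a placeholder for any fairseq wav2vec2/HuBERT checkpoint.
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(["wav2vec_small.pt"])
imported = import_fairseq_model(models[0])  # Wav2Vec2Model / Wav2VecEncoder / HubertModel / HubertEncoder
imported.eval()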
...@@ -11,40 +11,38 @@ _LG = logging.getLogger(__name__) ...@@ -11,40 +11,38 @@ _LG = logging.getLogger(__name__)
def _get_config(cfg): def _get_config(cfg):
config = { config = {
'extractor_mode': f'{cfg.feat_extract_norm}_norm', "extractor_mode": f"{cfg.feat_extract_norm}_norm",
'extractor_conv_layer_config': list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
'extractor_conv_bias': cfg.conv_bias, "extractor_conv_bias": cfg.conv_bias,
'encoder_embed_dim': cfg.hidden_size, "encoder_embed_dim": cfg.hidden_size,
'encoder_projection_dropout': cfg.feat_proj_dropout, "encoder_projection_dropout": cfg.feat_proj_dropout,
'encoder_pos_conv_kernel': cfg.num_conv_pos_embeddings, "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings,
'encoder_pos_conv_groups': cfg.num_conv_pos_embedding_groups, "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups,
'encoder_num_layers': cfg.num_hidden_layers, "encoder_num_layers": cfg.num_hidden_layers,
'encoder_num_heads': cfg.num_attention_heads, "encoder_num_heads": cfg.num_attention_heads,
'encoder_attention_dropout': cfg.attention_dropout, "encoder_attention_dropout": cfg.attention_dropout,
'encoder_ff_interm_features': cfg.intermediate_size, "encoder_ff_interm_features": cfg.intermediate_size,
'encoder_ff_interm_dropout': cfg.activation_dropout, "encoder_ff_interm_dropout": cfg.activation_dropout,
'encoder_dropout': cfg.hidden_dropout, "encoder_dropout": cfg.hidden_dropout,
'encoder_layer_norm_first': cfg.do_stable_layer_norm, "encoder_layer_norm_first": cfg.do_stable_layer_norm,
'encoder_layer_drop': cfg.layerdrop, "encoder_layer_drop": cfg.layerdrop,
} }
return config return config
def _build(config, original): def _build(config, original):
if original.__class__.__name__ == 'Wav2Vec2ForCTC': if original.__class__.__name__ == "Wav2Vec2ForCTC":
aux_num_out = original.config.vocab_size aux_num_out = original.config.vocab_size
wav2vec2 = original.wav2vec2 wav2vec2 = original.wav2vec2
else: else:
_LG.warning( _LG.warning("The model is not an instance of Wav2Vec2ForCTC. " '"lm_head" module is not imported.')
'The model is not an instance of Wav2Vec2ForCTC. '
'"lm_head" module is not imported.')
aux_num_out = None aux_num_out = None
wav2vec2 = original wav2vec2 = original
imported = wav2vec2_model(**config, aux_num_out=aux_num_out) imported = wav2vec2_model(**config, aux_num_out=aux_num_out)
imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict()) imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict()) imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict()) imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
if original.__class__.__name__ == 'Wav2Vec2ForCTC': if original.__class__.__name__ == "Wav2Vec2ForCTC":
imported.aux.load_state_dict(original.lm_head.state_dict()) imported.aux.load_state_dict(original.lm_head.state_dict())
return imported return imported
...@@ -71,10 +69,10 @@ def import_huggingface_model(original: Module) -> Wav2Vec2Model: ...@@ -71,10 +69,10 @@ def import_huggingface_model(original: Module) -> Wav2Vec2Model:
.. _Transformers: https://huggingface.co/transformers/ .. _Transformers: https://huggingface.co/transformers/
""" """
_LG.info('Importing model.') _LG.info("Importing model.")
_LG.info('Loading model configuration.') _LG.info("Loading model configuration.")
config = _get_config(original.config) config = _get_config(original.config)
_LG.debug(' - config: %s', config) _LG.debug(" - config: %s", config)
_LG.info('Building model.') _LG.info("Building model.")
imported = _build(config, original) imported = _build(config, original)
return imported return imported
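Similarly, a hedged sketch for the Hugging Face importer above; the model id is one example checkpoint, and any Wav2Vec2ForCTC or bare Wav2Vec2Model from Transformers should follow the same path:

import torch
from transformers import Wav2Vec2ForCTC
from torchaudio.models.wav2vec2.utils import import_huggingface_model

original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
imported = import_huggingface_model(original)  # lm_head is copied into the aux module
imported.eval()
logits, _ = imported(torch.randn(1, 16000))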
from typing import List, Tuple, Optional
import math import math
from typing import List, Tuple, Optional
import torch import torch
import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch import nn from torch import nn
import torch.nn.functional as F
__all__ = [ __all__ = [
"ResBlock", "ResBlock",
...@@ -35,7 +35,7 @@ class ResBlock(nn.Module): ...@@ -35,7 +35,7 @@ class ResBlock(nn.Module):
nn.BatchNorm1d(n_freq), nn.BatchNorm1d(n_freq),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
nn.BatchNorm1d(n_freq) nn.BatchNorm1d(n_freq),
) )
def forward(self, specgram: Tensor) -> Tensor: def forward(self, specgram: Tensor) -> Tensor:
...@@ -66,12 +66,9 @@ class MelResNet(nn.Module): ...@@ -66,12 +66,9 @@ class MelResNet(nn.Module):
>>> output = melresnet(input) # shape: (10, 128, 508) >>> output = melresnet(input) # shape: (10, 128, 508)
""" """
def __init__(self, def __init__(
n_res_block: int = 10, self, n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5
n_freq: int = 128, ) -> None:
n_hidden: int = 128,
n_output: int = 128,
kernel_size: int = 5) -> None:
super().__init__() super().__init__()
ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)] ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)]
...@@ -81,7 +78,7 @@ class MelResNet(nn.Module): ...@@ -81,7 +78,7 @@ class MelResNet(nn.Module):
nn.BatchNorm1d(n_hidden), nn.BatchNorm1d(n_hidden),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
*ResBlocks, *ResBlocks,
nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1) nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1),
) )
def forward(self, specgram: Tensor) -> Tensor: def forward(self, specgram: Tensor) -> Tensor:
...@@ -110,9 +107,7 @@ class Stretch2d(nn.Module): ...@@ -110,9 +107,7 @@ class Stretch2d(nn.Module):
>>> output = stretch2d(input) # shape: (10, 500, 5120) >>> output = stretch2d(input) # shape: (10, 500, 5120)
""" """
def __init__(self, def __init__(self, time_scale: int, freq_scale: int) -> None:
time_scale: int,
freq_scale: int) -> None:
super().__init__() super().__init__()
self.freq_scale = freq_scale self.freq_scale = freq_scale
...@@ -148,13 +143,15 @@ class UpsampleNetwork(nn.Module): ...@@ -148,13 +143,15 @@ class UpsampleNetwork(nn.Module):
>>> output = upsamplenetwork(input) # shape: (10, 1536, 128), (10, 1536, 128) >>> output = upsamplenetwork(input) # shape: (10, 1536, 128), (10, 1536, 128)
""" """
def __init__(self, def __init__(
upsample_scales: List[int], self,
n_res_block: int = 10, upsample_scales: List[int],
n_freq: int = 128, n_res_block: int = 10,
n_hidden: int = 128, n_freq: int = 128,
n_output: int = 128, n_hidden: int = 128,
kernel_size: int = 5) -> None: n_output: int = 128,
kernel_size: int = 5,
) -> None:
super().__init__() super().__init__()
total_scale = 1 total_scale = 1
...@@ -169,12 +166,10 @@ class UpsampleNetwork(nn.Module): ...@@ -169,12 +166,10 @@ class UpsampleNetwork(nn.Module):
up_layers = [] up_layers = []
for scale in upsample_scales: for scale in upsample_scales:
stretch = Stretch2d(scale, 1) stretch = Stretch2d(scale, 1)
conv = nn.Conv2d(in_channels=1, conv = nn.Conv2d(
out_channels=1, in_channels=1, out_channels=1, kernel_size=(1, scale * 2 + 1), padding=(0, scale), bias=False
kernel_size=(1, scale * 2 + 1), )
padding=(0, scale), torch.nn.init.constant_(conv.weight, 1.0 / (scale * 2 + 1))
bias=False)
torch.nn.init.constant_(conv.weight, 1. / (scale * 2 + 1))
up_layers.append(stretch) up_layers.append(stretch)
up_layers.append(conv) up_layers.append(conv)
self.upsample_layers = nn.Sequential(*up_layers) self.upsample_layers = nn.Sequential(*up_layers)
...@@ -197,7 +192,7 @@ class UpsampleNetwork(nn.Module): ...@@ -197,7 +192,7 @@ class UpsampleNetwork(nn.Module):
specgram = specgram.unsqueeze(1) specgram = specgram.unsqueeze(1)
upsampling_output = self.upsample_layers(specgram) upsampling_output = self.upsample_layers(specgram)
upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent] upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent : -self.indent]
return upsampling_output, resnet_output return upsampling_output, resnet_output
...@@ -230,17 +225,19 @@ class WaveRNN(nn.Module): ...@@ -230,17 +225,19 @@ class WaveRNN(nn.Module):
>>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes) >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
""" """
def __init__(self, def __init__(
upsample_scales: List[int], self,
n_classes: int, upsample_scales: List[int],
hop_length: int, n_classes: int,
n_res_block: int = 10, hop_length: int,
n_rnn: int = 512, n_res_block: int = 10,
n_fc: int = 512, n_rnn: int = 512,
kernel_size: int = 5, n_fc: int = 512,
n_freq: int = 128, kernel_size: int = 5,
n_hidden: int = 128, n_freq: int = 128,
n_output: int = 128) -> None: n_hidden: int = 128,
n_output: int = 128,
) -> None:
super().__init__() super().__init__()
self.kernel_size = kernel_size self.kernel_size = kernel_size
...@@ -257,12 +254,7 @@ class WaveRNN(nn.Module): ...@@ -257,12 +254,7 @@ class WaveRNN(nn.Module):
if total_scale != self.hop_length: if total_scale != self.hop_length:
raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}") raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
self.upsample = UpsampleNetwork(upsample_scales, self.upsample = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
n_res_block,
n_freq,
n_hidden,
n_output,
kernel_size)
self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn) self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True) self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
...@@ -286,8 +278,8 @@ class WaveRNN(nn.Module): ...@@ -286,8 +278,8 @@ class WaveRNN(nn.Module):
Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes) Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
""" """
assert waveform.size(1) == 1, 'Require the input channel of waveform to be 1' assert waveform.size(1) == 1, "Require the input channel of waveform to be 1"
assert specgram.size(1) == 1, 'Require the input channel of specgram to be 1' assert specgram.size(1) == 1, "Require the input channel of specgram to be 1"
# remove channel dimension until the end # remove channel dimension until the end
waveform, specgram = waveform.squeeze(1), specgram.squeeze(1) waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)
...@@ -302,10 +294,10 @@ class WaveRNN(nn.Module): ...@@ -302,10 +294,10 @@ class WaveRNN(nn.Module):
aux = aux.transpose(1, 2) aux = aux.transpose(1, 2)
aux_idx = [self.n_aux * i for i in range(5)] aux_idx = [self.n_aux * i for i in range(5)]
a1 = aux[:, :, aux_idx[0]:aux_idx[1]] a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
a2 = aux[:, :, aux_idx[1]:aux_idx[2]] a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
a3 = aux[:, :, aux_idx[2]:aux_idx[3]] a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
a4 = aux[:, :, aux_idx[3]:aux_idx[4]] a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1) x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
x = self.fc(x) x = self.fc(x)
...@@ -375,7 +367,7 @@ class WaveRNN(nn.Module): ...@@ -375,7 +367,7 @@ class WaveRNN(nn.Module):
h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype) h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
x = torch.zeros((b_size, 1), device=device, dtype=dtype) x = torch.zeros((b_size, 1), device=device, dtype=dtype)
aux_split = [aux[:, self.n_aux * i: self.n_aux * (i + 1), :] for i in range(4)] aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)]
for i in range(seq_len): for i in range(seq_len):
......
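A minimal sketch matching the docstring examples in this file; the tensors are random stand-ins with shapes implied by the constructor defaults (kernel_size=5, n_freq=128):

import torch
from torchaudio.models import WaveRNN

wavernn = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200)  # 5*5*8 == hop_length
specgram = torch.randn(3, 1, 128, 10)             # (n_batch, n_channel, n_freq, n_time)
waveform = torch.randn(3, 1, (10 - 5 + 1) * 200)  # (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
output = wavernn(waveform, specgram)              # (3, 1, 1200, 512)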
from ._tts import (
Tacotron2TTSBundle,
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
)
from ._wav2vec2.impl import ( from ._wav2vec2.impl import (
Wav2Vec2Bundle, Wav2Vec2Bundle,
Wav2Vec2ASRBundle, Wav2Vec2ASRBundle,
...@@ -25,43 +32,36 @@ from ._wav2vec2.impl import ( ...@@ -25,43 +32,36 @@ from ._wav2vec2.impl import (
HUBERT_ASR_LARGE, HUBERT_ASR_LARGE,
HUBERT_ASR_XLARGE, HUBERT_ASR_XLARGE,
) )
from ._tts import (
Tacotron2TTSBundle,
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
TACOTRON2_WAVERNN_CHAR_LJSPEECH,
TACOTRON2_WAVERNN_PHONE_LJSPEECH,
)
__all__ = [ __all__ = [
'Wav2Vec2Bundle', "Wav2Vec2Bundle",
'Wav2Vec2ASRBundle', "Wav2Vec2ASRBundle",
'WAV2VEC2_BASE', "WAV2VEC2_BASE",
'WAV2VEC2_LARGE', "WAV2VEC2_LARGE",
'WAV2VEC2_LARGE_LV60K', "WAV2VEC2_LARGE_LV60K",
'WAV2VEC2_ASR_BASE_10M', "WAV2VEC2_ASR_BASE_10M",
'WAV2VEC2_ASR_BASE_100H', "WAV2VEC2_ASR_BASE_100H",
'WAV2VEC2_ASR_BASE_960H', "WAV2VEC2_ASR_BASE_960H",
'WAV2VEC2_ASR_LARGE_10M', "WAV2VEC2_ASR_LARGE_10M",
'WAV2VEC2_ASR_LARGE_100H', "WAV2VEC2_ASR_LARGE_100H",
'WAV2VEC2_ASR_LARGE_960H', "WAV2VEC2_ASR_LARGE_960H",
'WAV2VEC2_ASR_LARGE_LV60K_10M', "WAV2VEC2_ASR_LARGE_LV60K_10M",
'WAV2VEC2_ASR_LARGE_LV60K_100H', "WAV2VEC2_ASR_LARGE_LV60K_100H",
'WAV2VEC2_ASR_LARGE_LV60K_960H', "WAV2VEC2_ASR_LARGE_LV60K_960H",
'WAV2VEC2_XLSR53', "WAV2VEC2_XLSR53",
'VOXPOPULI_ASR_BASE_10K_EN', "VOXPOPULI_ASR_BASE_10K_EN",
'VOXPOPULI_ASR_BASE_10K_ES', "VOXPOPULI_ASR_BASE_10K_ES",
'VOXPOPULI_ASR_BASE_10K_DE', "VOXPOPULI_ASR_BASE_10K_DE",
'VOXPOPULI_ASR_BASE_10K_FR', "VOXPOPULI_ASR_BASE_10K_FR",
'VOXPOPULI_ASR_BASE_10K_IT', "VOXPOPULI_ASR_BASE_10K_IT",
'HUBERT_BASE', "HUBERT_BASE",
'HUBERT_LARGE', "HUBERT_LARGE",
'HUBERT_XLARGE', "HUBERT_XLARGE",
'HUBERT_ASR_LARGE', "HUBERT_ASR_LARGE",
'HUBERT_ASR_XLARGE', "HUBERT_ASR_XLARGE",
'Tacotron2TTSBundle', "Tacotron2TTSBundle",
'TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH', "TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
'TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH', "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
'TACOTRON2_WAVERNN_CHAR_LJSPEECH', "TACOTRON2_WAVERNN_CHAR_LJSPEECH",
'TACOTRON2_WAVERNN_PHONE_LJSPEECH', "TACOTRON2_WAVERNN_PHONE_LJSPEECH",
] ]
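As a hedged usage sketch of one of the re-exported ASR bundles; the audio path is a placeholder, and the bundle attributes (get_model, sample_rate, get_labels) follow the torchaudio.pipelines API of this release:

import torch
import torchaudio
from torchaudio.pipelines import WAV2VEC2_ASR_BASE_960H as bundle

model = bundle.get_model()
waveform, sample_rate = torchaudio.load("speech.wav")  # placeholder path
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
with torch.inference_mode():
    emissions, _ = model(waveform)  # per-frame label emissions for CTC decoding
labels = bundle.get_labels()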