"...text-generation-inference.git" did not exist on "c6e8b9442b1fcf7bbbe4be58fcd85047f69e4112"
Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`
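Note: `arc lint` runs whatever linters the repository configures in its `.arclint`; judging from the reformatted output below, the Python rules amount to black-style formatting plus import sorting. A rough out-of-tree equivalent (the exact tools and the 120-character line length are assumptions inferred from the diff, not stated in the commit) would be:

    usort format .             # sort and group imports in place (assumed import sorter)
    black --line-length 120 .  # reformat code in place (assumed line length)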

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
@@ -21,11 +21,7 @@ def _parse_args():
         description=__doc__,
         formatter_class=argparse.RawTextHelpFormatter,
     )
-    parser.add_argument(
-        'input_dir',
-        type=Path,
-        help='Directory where `*.trans.txt` files are searched.'
-    )
+    parser.add_argument("input_dir", type=Path, help="Directory where `*.trans.txt` files are searched.")
     return parser.parse_args()

@@ -34,22 +30,22 @@ def _parse_transcript(path):
     for line in trans_fileobj:
         line = line.strip()
         if line:
-            yield line.split(' ', maxsplit=1)
+            yield line.split(" ", maxsplit=1)


 def _parse_directory(root_dir: Path):
-    for trans_file in root_dir.glob('**/*.trans.txt'):
+    for trans_file in root_dir.glob("**/*.trans.txt"):
         trans_dir = trans_file.parent
         for id_, transcription in _parse_transcript(trans_file):
-            audio_path = trans_dir / f'{id_}.flac'
+            audio_path = trans_dir / f"{id_}.flac"
             yield id_, audio_path, transcription


 def _main():
     args = _parse_args()
     for id_, path, transcription in _parse_directory(args.input_dir):
-        print(f'{id_}\t{path}\t{transcription}')
+        print(f"{id_}\t{path}\t{transcription}")


-if __name__ == '__main__':
+if __name__ == "__main__":
     _main()
@@ -12,8 +12,8 @@ example: python parse_voxforge.py voxforge/de/Helge-20150608-aku
 Dataset can be obtained from http://www.repository.voxforge1.org/downloads/de/Trunk/Audio/Main/16kHz_16bit/
 """  # noqa: E501
-import os
 import argparse
+import os
 from pathlib import Path

@@ -22,11 +22,7 @@ def _parse_args():
         description=__doc__,
         formatter_class=argparse.RawTextHelpFormatter,
     )
-    parser.add_argument(
-        'input_dir',
-        type=Path,
-        help='Directory where `*.trans.txt` files are searched.'
-    )
+    parser.add_argument("input_dir", type=Path, help="Directory where `*.trans.txt` files are searched.")
     return parser.parse_args()

@@ -38,19 +34,19 @@ def _parse_prompts(path):
         if not line:
             continue
-        id_, transcript = line.split(' ', maxsplit=1)
+        id_, transcript = line.split(" ", maxsplit=1)
         if not transcript:
             continue
         transcript = transcript.upper()
-        filename = id_.split('/')[-1]
-        audio_path = base_dir / 'wav' / f'{filename}.wav'
+        filename = id_.split("/")[-1]
+        audio_path = base_dir / "wav" / f"{filename}.wav"
         if os.path.exists(audio_path):
             yield id_, audio_path, transcript


 def _parse_directory(root_dir: Path):
-    for prompt_file in root_dir.glob('**/PROMPTS'):
+    for prompt_file in root_dir.glob("**/PROMPTS"):
         try:
             yield from _parse_prompts(prompt_file)
         except UnicodeDecodeError:

@@ -60,8 +56,8 @@ def _parse_directory(root_dir: Path):
 def _main():
     args = _parse_args()
     for id_, path, transcription in _parse_directory(args.input_dir):
-        print(f'{id_}\t{path}\t{transcription}')
+        print(f"{id_}\t{path}\t{transcription}")


-if __name__ == '__main__':
+if __name__ == "__main__":
     _main()
@@ -29,7 +29,6 @@ from typing import Tuple, Callable, List
 import torch
 from torch import Tensor
 from torch.utils.data.dataset import random_split
-
 from torchaudio.datasets import LJSPEECH

@@ -45,8 +44,7 @@ class InverseSpectralNormalization(torch.nn.Module):
 class MapMemoryCache(torch.utils.data.Dataset):
-    r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
-    """
+    r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory."""

     def __init__(self, dataset):
         self.dataset = dataset

@@ -84,16 +82,17 @@ class Processed(torch.utils.data.Dataset):
         return text_norm, torch.squeeze(melspec, 0)


-def split_process_dataset(dataset: str,
-                          file_path: str,
-                          val_ratio: float,
-                          transforms: Callable,
-                          text_preprocessor: Callable[[str], List[int]],
-                          ) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
+def split_process_dataset(
+    dataset: str,
+    file_path: str,
+    val_ratio: float,
+    transforms: Callable,
+    text_preprocessor: Callable[[str], List[int]],
+) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
     """Returns the Training and validation datasets.

     Args:
-        dataset (str): The dataset to use. Avaliable options: [`'ljspeech'`]
+        dataset (str): The dataset to use. Available options: [`'ljspeech'`]
         file_path (str): Path to the data.
         val_ratio (float): Path to the data.
         transforms (callable): A function/transform that takes in a waveform and

@@ -105,7 +104,7 @@ def split_process_dataset(dataset: str,
         train_dataset (`torch.utils.data.Dataset`): The training set.
         val_dataset (`torch.utils.data.Dataset`): The validation set.
     """
-    if dataset == 'ljspeech':
+    if dataset == "ljspeech":
         data = LJSPEECH(root=file_path, download=False)

         val_length = int(len(data) * val_ratio)

@@ -123,8 +122,9 @@ def split_process_dataset(dataset: str,
     return train_dataset, val_dataset


-def text_mel_collate_fn(batch: Tuple[Tensor, Tensor],
-                        n_frames_per_step: int = 1) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+def text_mel_collate_fn(
+    batch: Tuple[Tensor, Tensor], n_frames_per_step: int = 1
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
     """The collate function padding and adjusting the data based on `n_frames_per_step`.

     Modified from https://github.com/NVIDIA/DeepLearningExamples

@@ -143,13 +143,14 @@ def text_mel_collate_fn(batch: Tuple[Tensor, Tensor],
             with shape (n_batch, max of ``mel_specgram_lengths``)
     """
     text_lengths, ids_sorted_decreasing = torch.sort(
-        torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True)
+        torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True
+    )
     max_input_len = text_lengths[0]
     text_padded = torch.zeros((len(batch), max_input_len), dtype=torch.int64)

     for i in range(len(ids_sorted_decreasing)):
         text = batch[ids_sorted_decreasing[i]][0]
-        text_padded[i, :text.size(0)] = text
+        text_padded[i, : text.size(0)] = text

     # Right zero-pad mel-spec
     num_mels = batch[0][1].size(0)

@@ -164,8 +165,8 @@ def text_mel_collate_fn(batch: Tuple[Tensor, Tensor],
     mel_specgram_lengths = torch.LongTensor(len(batch))
     for i in range(len(ids_sorted_decreasing)):
         mel = batch[ids_sorted_decreasing[i]][1]
-        mel_specgram_padded[i, :, :mel.size(1)] = mel
+        mel_specgram_padded[i, :, : mel.size(1)] = mel
         mel_specgram_lengths[i] = mel.size(1)
-        gate_padded[i, mel.size(1) - 1:] = 1
+        gate_padded[i, mel.size(1) - 1 :] = 1

     return text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths, gate_padded
@@ -2,19 +2,15 @@
 Text-to-speech pipeline using Tacotron2.
 """
-from functools import partial
 import argparse
 import os
 import random
 import sys
+from functools import partial

+import numpy as np
 import torch
 import torchaudio
-import numpy as np

-from torchaudio.models import Tacotron2
-from torchaudio.models import tacotron2 as pretrained_tacotron2
-from utils import prepare_input_sequence
 from datasets import InverseSpectralNormalization
 from text.text_preprocessing import (
     available_symbol_set,

@@ -22,6 +18,9 @@ from text.text_preprocessing import (
     get_symbol_list,
     text_to_sequence,
 )
+from torchaudio.models import Tacotron2
+from torchaudio.models import tacotron2 as pretrained_tacotron2
+from utils import prepare_input_sequence


 def parse_args():
@@ -33,113 +32,72 @@ def parse_args():
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument(
-        '--checkpoint-name',
+        "--checkpoint-name",
         type=str,
         default=None,
         choices=list(tacotron2_config_and_urls.keys()),
-        help='[string] The name of the checkpoint to load.'
-    )
-    parser.add_argument(
-        '--checkpoint-path',
-        type=str,
-        default=None,
-        help='[string] Path to the checkpoint file.'
-    )
-    parser.add_argument(
-        '--output-path',
-        type=str,
-        default="./audio.wav",
-        help='[string] Path to the output .wav file.'
+        help="[string] The name of the checkpoint to load.",
     )
+    parser.add_argument("--checkpoint-path", type=str, default=None, help="[string] Path to the checkpoint file.")
+    parser.add_argument("--output-path", type=str, default="./audio.wav", help="[string] Path to the output .wav file.")
     parser.add_argument(
-        '--input-text',
-        '-i',
+        "--input-text",
+        "-i",
         type=str,
         default="Hello world",
-        help='[string] Type in something here and TTS will generate it!'
+        help="[string] Type in something here and TTS will generate it!",
     )
     parser.add_argument(
-        '--vocoder',
-        default='nvidia_waveglow',
-        choices=['griffin_lim', 'wavernn', 'nvidia_waveglow'],
+        "--vocoder",
+        default="nvidia_waveglow",
+        choices=["griffin_lim", "wavernn", "nvidia_waveglow"],
         type=str,
         help="Select the vocoder to use.",
     )
     parser.add_argument(
-        "--jit",
-        default=False,
-        action="store_true",
-        help="If used, the model and inference function is jitted."
+        "--jit", default=False, action="store_true", help="If used, the model and inference function is jitted."
     )

-    preprocessor = parser.add_argument_group('text preprocessor setup')
+    preprocessor = parser.add_argument_group("text preprocessor setup")
     preprocessor.add_argument(
-        '--text-preprocessor',
-        default='english_characters',
+        "--text-preprocessor",
+        default="english_characters",
         type=str,
         choices=available_symbol_set,
-        help='select text preprocessor to use.'
+        help="select text preprocessor to use.",
     )
     preprocessor.add_argument(
-        '--phonemizer',
+        "--phonemizer",
         default="DeepPhonemizer",
         type=str,
         choices=available_phonemizers,
-        help='select phonemizer to use, only used when text-preprocessor is "english_phonemes"'
+        help='select phonemizer to use, only used when text-preprocessor is "english_phonemes"',
     )
     preprocessor.add_argument(
-        '--phonemizer-checkpoint',
+        "--phonemizer-checkpoint",
         default="./en_us_cmudict_forward.pt",
         type=str,
-        help='the path or name of the checkpoint for the phonemizer, '
-             'only used when text-preprocessor is "english_phonemes"'
+        help="the path or name of the checkpoint for the phonemizer, "
+        'only used when text-preprocessor is "english_phonemes"',
     )
     preprocessor.add_argument(
-        '--cmudict-root',
-        default="./",
-        type=str,
-        help='the root directory for storing CMU dictionary files'
+        "--cmudict-root", default="./", type=str, help="the root directory for storing CMU dictionary files"
     )

-    audio = parser.add_argument_group('audio parameters')
-    audio.add_argument(
-        '--sample-rate',
-        default=22050,
-        type=int,
-        help='Sampling rate'
-    )
-    audio.add_argument(
-        '--n-fft',
-        default=1024,
-        type=int,
-        help='Filter length for STFT'
-    )
-    audio.add_argument(
-        '--n-mels',
-        default=80,
-        type=int,
-        help=''
-    )
-    audio.add_argument(
-        '--mel-fmin',
-        default=0.0,
-        type=float,
-        help='Minimum mel frequency'
-    )
-    audio.add_argument(
-        '--mel-fmax',
-        default=8000.0,
-        type=float,
-        help='Maximum mel frequency'
-    )
+    audio = parser.add_argument_group("audio parameters")
+    audio.add_argument("--sample-rate", default=22050, type=int, help="Sampling rate")
+    audio.add_argument("--n-fft", default=1024, type=int, help="Filter length for STFT")
+    audio.add_argument("--n-mels", default=80, type=int, help="")
+    audio.add_argument("--mel-fmin", default=0.0, type=float, help="Minimum mel frequency")
+    audio.add_argument("--mel-fmax", default=8000.0, type=float, help="Maximum mel frequency")

     # parameters for WaveRNN
-    wavernn = parser.add_argument_group('WaveRNN parameters')
+    wavernn = parser.add_argument_group("WaveRNN parameters")
     wavernn.add_argument(
-        '--wavernn-checkpoint-name',
+        "--wavernn-checkpoint-name",
         default="wavernn_10k_epochs_8bits_ljspeech",
         choices=list(wavernn_config_and_urls.keys()),
-        help="Select the WaveRNN checkpoint."
+        help="Select the WaveRNN checkpoint.",
     )
     wavernn.add_argument(
         "--wavernn-loss",

@@ -152,13 +110,10 @@ def parse_args():
         "--wavernn-no-batch-inference",
         default=False,
         action="store_true",
-        help="Don't use batch inference for WaveRNN inference."
+        help="Don't use batch inference for WaveRNN inference.",
     )
     wavernn.add_argument(
-        "--wavernn-no-mulaw",
-        default=False,
-        action="store_true",
-        help="Don't use mulaw decoder to decode the signal."
+        "--wavernn-no-mulaw", default=False, action="store_true", help="Don't use mulaw decoder to decode the signal."
     )
     wavernn.add_argument(
         "--wavernn-batch-timesteps",
@@ -187,11 +142,11 @@ def unwrap_distributed(state_dict):
         unwrapped_state_dict: Unwrapped state_dict.
     """
-    return {k.replace('module.', ''): v for k, v in state_dict.items()}
+    return {k.replace("module.", ""): v for k, v in state_dict.items()}


 def nvidia_waveglow_vocode(mel_specgram, device, jit=False):
-    waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp16')
+    waveglow = torch.hub.load("NVIDIA/DeepLearningExamples:torchhub", "nvidia_waveglow", model_math="fp16")
     waveglow = waveglow.remove_weightnorm(waveglow)
     waveglow = waveglow.to(device)
     waveglow.eval()

@@ -205,13 +160,22 @@ def nvidia_waveglow_vocode(mel_specgram, device, jit=False):
     return waveform


-def wavernn_vocode(mel_specgram, wavernn_checkpoint_name, wavernn_loss, wavernn_no_mulaw,
-                   wavernn_no_batch_inference, wavernn_batch_timesteps, wavernn_batch_overlap,
-                   device, jit):
+def wavernn_vocode(
+    mel_specgram,
+    wavernn_checkpoint_name,
+    wavernn_loss,
+    wavernn_no_mulaw,
+    wavernn_no_batch_inference,
+    wavernn_batch_timesteps,
+    wavernn_batch_overlap,
+    device,
+    jit,
+):
     from torchaudio.models import wavernn
+
     sys.path.append(os.path.join(os.path.dirname(__file__), "../pipeline_wavernn"))
-    from wavernn_inference_wrapper import WaveRNNInferenceWrapper
     from processing import NormalizeDB
+    from wavernn_inference_wrapper import WaveRNNInferenceWrapper

     wavernn_model = wavernn(wavernn_checkpoint_name).eval().to(device)
     wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)

@@ -234,16 +198,26 @@ def wavernn_vocode(mel_specgram, wavernn_checkpoint_name, wavernn_loss, wavernn_
     mel_specgram = transforms(mel_specgram.cpu())

     with torch.no_grad():
-        waveform = wavernn_inference_model(mel_specgram.to(device),
-                                           loss_name=wavernn_loss,
-                                           mulaw=(not wavernn_no_mulaw),
-                                           batched=(not wavernn_no_batch_inference),
-                                           timesteps=wavernn_batch_timesteps,
-                                           overlap=wavernn_batch_overlap,)
+        waveform = wavernn_inference_model(
+            mel_specgram.to(device),
+            loss_name=wavernn_loss,
+            mulaw=(not wavernn_no_mulaw),
+            batched=(not wavernn_no_batch_inference),
+            timesteps=wavernn_batch_timesteps,
+            overlap=wavernn_batch_overlap,
+        )

     return waveform.unsqueeze(0)


-def griffin_lim_vocode(mel_specgram, n_fft, n_mels, sample_rate, mel_fmin, mel_fmax, jit, ):
+def griffin_lim_vocode(
+    mel_specgram,
+    n_fft,
+    n_mels,
+    sample_rate,
+    mel_fmin,
+    mel_fmax,
+    jit,
+):
     from torchaudio.transforms import GriffinLim, InverseMelScale

     inv_norm = InverseSpectralNormalization()

@@ -254,7 +228,7 @@ def griffin_lim_vocode(mel_specgram, n_fft, n_mels, sample_rate, mel_fmin, mel_f
         f_min=mel_fmin,
         f_max=mel_fmax,
         mel_scale="slaney",
-        norm='slaney',
+        norm="slaney",
     )
     griffin_lim = GriffinLim(
         n_fft=n_fft,

@@ -263,11 +237,7 @@ def griffin_lim_vocode(mel_specgram, n_fft, n_mels, sample_rate, mel_fmin, mel_f
         win_length=1024,
     )

-    vocoder = torch.nn.Sequential(
-        inv_norm,
-        inv_mel,
-        griffin_lim
-    )
+    vocoder = torch.nn.Sequential(inv_norm, inv_mel, griffin_lim)

     if jit:
         vocoder = torch.jit.script(vocoder)

@@ -286,8 +256,7 @@ def main(args):
     if args.checkpoint_path is None and args.checkpoint_name is None:
         raise ValueError("Either --checkpoint-path or --checkpoint-name must be specified.")
     elif args.checkpoint_path is not None and args.checkpoint_name is not None:
-        raise ValueError("Both --checkpoint-path and --checkpoint-name are specified, "
-                         "can only specify one.")
+        raise ValueError("Both --checkpoint-path and --checkpoint-name are specified, " "can only specify one.")

     n_symbols = len(get_symbol_list(args.text_preprocessor))
     text_preprocessor = partial(

@@ -301,21 +270,23 @@ def main(args):
     if args.checkpoint_path is not None:
         tacotron2 = Tacotron2(n_symbol=n_symbols)
         tacotron2.load_state_dict(
-            unwrap_distributed(torch.load(args.checkpoint_path, map_location=device)['state_dict']))
+            unwrap_distributed(torch.load(args.checkpoint_path, map_location=device)["state_dict"])
+        )
         tacotron2 = tacotron2.to(device).eval()
     elif args.checkpoint_name is not None:
         tacotron2 = pretrained_tacotron2(args.checkpoint_name).to(device).eval()

         if n_symbols != tacotron2.n_symbols:
-            raise ValueError("the number of symbols for text_preprocessor ({n_symbols}) "
-                             "should match the number of symbols for the"
-                             "pretrained tacotron2 ({tacotron2.n_symbols}).")
+            raise ValueError(
+                "the number of symbols for text_preprocessor ({n_symbols}) "
+                "should match the number of symbols for the"
+                "pretrained tacotron2 ({tacotron2.n_symbols})."
+            )

     if args.jit:
         tacotron2 = torch.jit.script(tacotron2)

-    sequences, lengths = prepare_input_sequence([args.input_text],
-                                                text_processor=text_preprocessor)
+    sequences, lengths = prepare_input_sequence([args.input_text], text_processor=text_preprocessor)
     sequences, lengths = sequences.long().to(device), lengths.long().to(device)
     with torch.no_grad():
         mel_specgram, _, _ = tacotron2.infer(sequences, lengths)

@@ -324,24 +295,28 @@
         waveform = nvidia_waveglow_vocode(mel_specgram=mel_specgram, device=device, jit=args.jit)

     elif args.vocoder == "wavernn":
-        waveform = wavernn_vocode(mel_specgram=mel_specgram,
-                                  wavernn_checkpoint_name=args.wavernn_checkpoint_name,
-                                  wavernn_loss=args.wavernn_loss,
-                                  wavernn_no_mulaw=args.wavernn_no_mulaw,
-                                  wavernn_no_batch_inference=args.wavernn_no_batch_inference,
-                                  wavernn_batch_timesteps=args.wavernn_batch_timesteps,
-                                  wavernn_batch_overlap=args.wavernn_batch_overlap,
-                                  device=device,
-                                  jit=args.jit)
+        waveform = wavernn_vocode(
+            mel_specgram=mel_specgram,
+            wavernn_checkpoint_name=args.wavernn_checkpoint_name,
+            wavernn_loss=args.wavernn_loss,
+            wavernn_no_mulaw=args.wavernn_no_mulaw,
+            wavernn_no_batch_inference=args.wavernn_no_batch_inference,
+            wavernn_batch_timesteps=args.wavernn_batch_timesteps,
+            wavernn_batch_overlap=args.wavernn_batch_overlap,
+            device=device,
+            jit=args.jit,
+        )

     elif args.vocoder == "griffin_lim":
-        waveform = griffin_lim_vocode(mel_specgram=mel_specgram,
-                                      n_fft=args.n_fft,
-                                      n_mels=args.n_mels,
-                                      sample_rate=args.sample_rate,
-                                      mel_fmin=args.mel_fmin,
-                                      mel_fmax=args.mel_fmax,
-                                      jit=args.jit)
+        waveform = griffin_lim_vocode(
+            mel_specgram=mel_specgram,
+            n_fft=args.n_fft,
+            n_mels=args.n_mels,
+            sample_rate=args.sample_rate,
+            mel_fmin=args.mel_fmin,
+            mel_fmax=args.mel_fmax,
+            jit=args.jit,
+        )

     torchaudio.save(args.output_path, waveform, args.sample_rate)
@@ -24,33 +24,34 @@
 Modified from https://github.com/keithito/tacotron
 """
-import inflect
 import re

+import inflect
+
 _inflect = inflect.engine()
-_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
-_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
-_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
-_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
-_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
-_number_re = re.compile(r'[0-9]+')
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+_number_re = re.compile(r"[0-9]+")


 def _remove_commas(text: str) -> str:
-    return re.sub(_comma_number_re, lambda m: m.group(1).replace(',', ''), text)
+    return re.sub(_comma_number_re, lambda m: m.group(1).replace(",", ""), text)


 def _expand_pounds(text: str) -> str:
-    return re.sub(_pounds_re, r'\1 pounds', text)
+    return re.sub(_pounds_re, r"\1 pounds", text)


 def _expand_dollars_repl_fn(m):
     """The replacement function for expanding dollars."""
     match = m.group(1)
-    parts = match.split('.')
+    parts = match.split(".")
     if len(parts) > 2:
-        return match + ' dollars'  # Unexpected format
+        return match + " dollars"  # Unexpected format
     dollars = int(parts[0]) if parts[0] else 0
     if len(parts) > 1 and parts[1]:
         if len(parts[1]) == 1:

@@ -61,17 +62,17 @@ def _expand_dollars_repl_fn(m):
     else:
         cents = 0
     if dollars and cents:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
     elif dollars:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        return '%s %s' % (dollars, dollar_unit)
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        return "%s %s" % (dollars, dollar_unit)
     elif cents:
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s' % (cents, cent_unit)
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s" % (cents, cent_unit)
     else:
-        return 'zero dollars'
+        return "zero dollars"


 def _expand_dollars(text: str) -> str:

@@ -79,7 +80,7 @@ def _expand_dollars(text: str) -> str:
 def _expand_decimal_point(text: str) -> str:
-    return re.sub(_decimal_number_re, lambda m: m.group(1).replace('.', ' point '), text)
+    return re.sub(_decimal_number_re, lambda m: m.group(1).replace(".", " point "), text)


 def _expand_ordinal(text: str) -> str:

@@ -91,15 +92,15 @@ def _expand_number_repl_fn(m):
     num = int(m.group(0))
     if num > 1000 and num < 3000:
         if num == 2000:
-            return 'two thousand'
+            return "two thousand"
         elif num > 2000 and num < 2010:
-            return 'two thousand ' + _inflect.number_to_words(num % 100)
+            return "two thousand " + _inflect.number_to_words(num % 100)
         elif num % 100 == 0:
-            return _inflect.number_to_words(num // 100) + ' hundred'
+            return _inflect.number_to_words(num // 100) + " hundred"
         else:
-            return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
     else:
-        return _inflect.number_to_words(num, andword='')
+        return _inflect.number_to_words(num, andword="")


 def _expand_number(text: str) -> str:
@@ -24,44 +24,47 @@
 Modified from https://github.com/keithito/tacotron
 """
-from typing import List, Union, Optional
 import re
+from typing import List, Union, Optional

-from unidecode import unidecode
 from torchaudio.datasets import CMUDict
+from unidecode import unidecode

 from .numbers import normalize_numbers

 # Regular expression matching whitespace:
-_whitespace_re = re.compile(r'\s+')
+_whitespace_re = re.compile(r"\s+")

 # List of (regular expression, replacement) pairs for abbreviations:
-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-    ('mrs', 'misess'),
-    ('mr', 'mister'),
-    ('dr', 'doctor'),
-    ('st', 'saint'),
-    ('co', 'company'),
-    ('jr', 'junior'),
-    ('maj', 'major'),
-    ('gen', 'general'),
-    ('drs', 'doctors'),
-    ('rev', 'reverend'),
-    ('lt', 'lieutenant'),
-    ('hon', 'honorable'),
-    ('sgt', 'sergeant'),
-    ('capt', 'captain'),
-    ('esq', 'esquire'),
-    ('ltd', 'limited'),
-    ('col', 'colonel'),
-    ('ft', 'fort'),
-]]
-
-_pad = '_'
-_punctuation = '!\'(),.:;? '
-_special = '-'
-_letters = 'abcdefghijklmnopqrstuvwxyz'
+_abbreviations = [
+    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+    for x in [
+        ("mrs", "misess"),
+        ("mr", "mister"),
+        ("dr", "doctor"),
+        ("st", "saint"),
+        ("co", "company"),
+        ("jr", "junior"),
+        ("maj", "major"),
+        ("gen", "general"),
+        ("drs", "doctors"),
+        ("rev", "reverend"),
+        ("lt", "lieutenant"),
+        ("hon", "honorable"),
+        ("sgt", "sergeant"),
+        ("capt", "captain"),
+        ("esq", "esquire"),
+        ("ltd", "limited"),
+        ("col", "colonel"),
+        ("ft", "fort"),
+    ]
+]
+
+_pad = "_"
+_punctuation = "!'(),.:;? "
+_special = "-"
+_letters = "abcdefghijklmnopqrstuvwxyz"

 symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters)

 _phonemizer = None

@@ -71,23 +74,25 @@ available_symbol_set = set(["english_characters", "english_phonemes"])
 available_phonemizers = set(["DeepPhonemizer"])


-def get_symbol_list(symbol_list: str = "english_characters",
-                    cmudict_root: Optional[str] = "./") -> List[str]:
+def get_symbol_list(symbol_list: str = "english_characters", cmudict_root: Optional[str] = "./") -> List[str]:
     if symbol_list == "english_characters":
         return [_pad] + list(_special) + list(_punctuation) + list(_letters)
     elif symbol_list == "english_phonemes":
         return [_pad] + list(_special) + list(_punctuation) + CMUDict(cmudict_root).symbols
     else:
-        raise ValueError(f"The `symbol_list` {symbol_list} is not supported."
-                         f"Supported `symbol_list` includes {available_symbol_set}.")
+        raise ValueError(
+            f"The `symbol_list` {symbol_list} is not supported."
+            f"Supported `symbol_list` includes {available_symbol_set}."
+        )


 def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
     if phonemizer == "DeepPhonemizer":
         from dp.phonemizer import Phonemizer
+
         global _phonemizer
-        _other_symbols = ''.join(list(_special) + list(_punctuation))
-        _phone_symbols_re = r'(\[[A-Z]+?\]|' + '[' + _other_symbols + '])'  # [\[([A-Z]+?)\]|[-!'(),.:;? ]]
+        _other_symbols = "".join(list(_special) + list(_punctuation))
+        _phone_symbols_re = r"(\[[A-Z]+?\]|" + "[" + _other_symbols + "])"  # [\[([A-Z]+?)\]|[-!'(),.:;? ]]
         if _phonemizer is None:
             # using a global variable so that we don't have to relode checkpoint

@@ -97,7 +102,7 @@ def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
         # Example:
         #   sent = "hello world!"
         #   '[HH][AH][L][OW] [W][ER][L][D]!'
-        sent = _phonemizer(sent, lang='en_us')
+        sent = _phonemizer(sent, lang="en_us")

         # ['[HH]', '[AH]', '[L]', '[OW]', ' ', '[W]', '[ER]', '[L]', '[D]', '!']
         ret = re.findall(_phone_symbols_re, sent)

@@ -107,16 +112,19 @@ def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
         return ret
     else:
-        raise ValueError(f"The `phonemizer` {phonemizer} is not supported. "
-                         "Supported `symbol_list` includes `'DeepPhonemizer'`.")
+        raise ValueError(
+            f"The `phonemizer` {phonemizer} is not supported. " "Supported `symbol_list` includes `'DeepPhonemizer'`."
+        )


-def text_to_sequence(sent: str,
-                     symbol_list: Union[str, List[str]] = "english_characters",
-                     phonemizer: Optional[str] = "DeepPhonemizer",
-                     checkpoint: Optional[str] = "./en_us_cmudict_forward.pt",
-                     cmudict_root: Optional[str] = "./") -> List[int]:
-    r'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+def text_to_sequence(
+    sent: str,
+    symbol_list: Union[str, List[str]] = "english_characters",
+    phonemizer: Optional[str] = "DeepPhonemizer",
+    checkpoint: Optional[str] = "./en_us_cmudict_forward.pt",
+    cmudict_root: Optional[str] = "./",
+) -> List[int]:
+    r"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

     Args:
         sent (str): The input sentence to convert to a sequence.

@@ -138,19 +146,20 @@ def text_to_sequence(sent: str,
         [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2]
         >>> text_to_sequence("hello world!", "english_phonemes")
         [54, 20, 65, 69, 11, 92, 44, 65, 38, 2]
-    '''
+    """

     if symbol_list == "english_phonemes":
         if any(param is None for param in [phonemizer, checkpoint, cmudict_root]):
             raise ValueError(
                 "When `symbol_list` is 'english_phonemes', "
-                "all of `phonemizer`, `checkpoint`, and `cmudict_root` must be provided.")
+                "all of `phonemizer`, `checkpoint`, and `cmudict_root` must be provided."
+            )

     sent = unidecode(sent)  # convert to ascii
     sent = sent.lower()  # lower case
     sent = normalize_numbers(sent)  # expand numbers
     for regex, replacement in _abbreviations:  # expand abbreviations
         sent = re.sub(regex, replacement, sent)
-    sent = re.sub(_whitespace_re, ' ', sent)  # collapse whitespace
+    sent = re.sub(_whitespace_re, " ", sent)  # collapse whitespace

     if isinstance(symbol_list, list):
         symbols = symbol_list
@@ -53,21 +53,19 @@ def pad_sequences(batch: List[Tensor]) -> Tuple[Tensor, Tensor]:
     Modified from https://github.com/NVIDIA/DeepLearningExamples.
     """
-    input_lengths, ids_sorted_decreasing = torch.sort(
-        torch.LongTensor([len(x) for x in batch]), dim=0, descending=True)
+    input_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor([len(x) for x in batch]), dim=0, descending=True)
     max_input_len = input_lengths[0]

     text_padded = torch.LongTensor(len(batch), max_input_len)
     text_padded.zero_()
     for i in range(len(ids_sorted_decreasing)):
         text = batch[ids_sorted_decreasing[i]]
-        text_padded[i, :text.size(0)] = text
+        text_padded[i, : text.size(0)] = text

     return text_padded, input_lengths


-def prepare_input_sequence(texts: List[str],
-                           text_processor: Callable[[str], List[int]]) -> Tuple[Tensor, Tensor]:
+def prepare_input_sequence(texts: List[str], text_processor: Callable[[str], List[int]]) -> Tuple[Tensor, Tensor]:
     d = []
     for text in texts:
         d.append(torch.IntTensor(text_processor(text)[:]))
@@ -51,7 +51,11 @@ class Processed(torch.utils.data.Dataset):

 def split_process_librispeech(
-    datasets, transforms, language_model, root, folder_in_archive,
+    datasets,
+    transforms,
+    language_model,
+    root,
+    folder_in_archive,
 ):

     def create(tags, cache=True):

@@ -66,7 +70,10 @@ def split_process_librispeech(
             [
                 Processed(
                     LIBRISPEECH(
-                        root, tag, folder_in_archive=folder_in_archive, download=False,
+                        root,
+                        tag,
+                        folder_in_archive=folder_in_archive,
+                        download=False,
                     ),
                     transform,
                     language_model.encode,
@@ -7,16 +7,15 @@ from time import time

 import torch
 import torchaudio
-
-from ctc_decoders import GreedyDecoder
-from datasets import collate_factory, split_process_librispeech
-from languagemodels import LanguageModel
-
 from torch.optim import SGD, Adadelta, Adam, AdamW
 from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
 from torch.utils.data import DataLoader
 from torchaudio.datasets.utils import bg_iterator
 from torchaudio.functional import edit_distance
 from torchaudio.models.wav2letter import Wav2Letter
+
+from ctc_decoders import GreedyDecoder
+from datasets import collate_factory, split_process_librispeech
+from languagemodels import LanguageModel
 from transforms import Normalize, UnsqueezeFirst
 from utils import MetricLogger, count_parameters, save_checkpoint

@@ -80,20 +79,14 @@ def parse_args():
         metavar="N",
         help="number of total epochs to run",
     )
-    parser.add_argument(
-        "--start-epoch", default=0, type=int, metavar="N", help="manual epoch number"
-    )
+    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="manual epoch number")
     parser.add_argument(
         "--reduce-lr-valid",
         action="store_true",
         help="reduce learning rate based on validation loss",
     )
-    parser.add_argument(
-        "--normalize", action="store_true", help="normalize model input"
-    )
-    parser.add_argument(
-        "--progress-bar", action="store_true", help="use progress bar while training"
-    )
+    parser.add_argument("--normalize", action="store_true", help="normalize model input")
+    parser.add_argument("--progress-bar", action="store_true", help="use progress bar while training")
     parser.add_argument(
         "--decoder",
         metavar="D",

@@ -101,9 +94,7 @@ def parse_args():
         choices=["greedy"],
         help="decoder to use",
     )
-    parser.add_argument(
-        "--batch-size", default=128, type=int, metavar="N", help="mini-batch size"
-    )
+    parser.add_argument("--batch-size", default=128, type=int, metavar="N", help="mini-batch size")
     parser.add_argument(
         "--n-bins",
         default=13,

@@ -139,12 +130,8 @@ def parse_args():
         metavar="GAMMA",
         help="learning rate exponential decay constant",
     )
-    parser.add_argument(
-        "--momentum", default=0.8, type=float, metavar="M", help="momentum"
-    )
-    parser.add_argument(
-        "--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay"
-    )
+    parser.add_argument("--momentum", default=0.8, type=float, metavar="M", help="momentum")
+    parser.add_argument("--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay")
     parser.add_argument("--eps", metavar="EPS", type=float, default=1e-8)
     parser.add_argument("--rho", metavar="RHO", type=float, default=0.95)
     parser.add_argument("--clip-grad", metavar="NORM", type=float, default=0.0)

@@ -172,13 +159,9 @@ def parse_args():
         type=str,
         help="select which part of librispeech to validate with",
     )
-    parser.add_argument(
-        "--distributed", action="store_true", help="enable DistributedDataParallel"
-    )
+    parser.add_argument("--distributed", action="store_true", help="enable DistributedDataParallel")
     parser.add_argument("--seed", type=int, default=0, help="random seed")
-    parser.add_argument(
-        "--world-size", type=int, default=8, help="the world size to initiate DPP"
-    )
+    parser.add_argument("--world-size", type=int, default=8, help="the world size to initiate DPP")
     parser.add_argument("--jit", action="store_true", help="if used, model is jitted")

     args = parser.parse_args()

@@ -263,9 +246,7 @@ def train_one_epoch(
     metric = MetricLogger("train", disable=disable_logger)
     metric["epoch"] = epoch

-    for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
-        data_loader, maxsize=2
-    ):
+    for inputs, targets, tensors_lengths, target_lengths in bg_iterator(data_loader, maxsize=2):

         start = time()
         inputs = inputs.to(device, non_blocking=True)

@@ -286,9 +267,7 @@ def train_one_epoch(
         loss.backward()

         if clip_grad > 0:
-            metric["gradient"] = torch.nn.utils.clip_grad_norm_(
-                model.parameters(), clip_grad
-            )
+            metric["gradient"] = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

         optimizer.step()

@@ -335,9 +314,7 @@ def evaluate(
         metric = MetricLogger("validation", disable=disable_logger)
         metric["epoch"] = epoch

-        for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
-            data_loader, maxsize=2
-        ):
+        for inputs, targets, tensors_lengths, target_lengths in bg_iterator(data_loader, maxsize=2):

             inputs = inputs.to(device, non_blocking=True)
             targets = targets.to(device, non_blocking=True)

@@ -351,9 +328,7 @@ def evaluate(
             # input_lengths: batch size
             # target_lengths: batch size

-            metric["cumulative loss"] += criterion(
-                outputs, targets, tensors_lengths, target_lengths
-            ).item()
+            metric["cumulative loss"] += criterion(outputs, targets, tensors_lengths, target_lengths).item()

             metric["dataset length"] += len(inputs)
             metric["iteration"] += 1

@@ -518,9 +493,7 @@ def main(rank, args):
     else:
         raise ValueError("Selected scheduler not supported")

-    criterion = torch.nn.CTCLoss(
-        blank=language_model.mapping[char_blank], zero_infinity=False
-    )
+    criterion = torch.nn.CTCLoss(blank=language_model.mapping[char_blank], zero_infinity=False)

     # Data Loader

@@ -569,9 +542,7 @@ def main(rank, args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         scheduler.load_state_dict(checkpoint["scheduler"])

-        logging.info(
-            "Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
-        )
+        logging.info("Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"])
     else:
         logging.info("Checkpoint: not found")

@@ -649,9 +620,7 @@ def main(rank, args):
 def spawn_main(main, args):
     if args.distributed:
-        torch.multiprocessing.spawn(
-            main, args=(args,), nprocs=args.world_size, join=True
-        )
+        torch.multiprocessing.spawn(main, args=(args,), nprocs=args.world_size, join=True)
     else:
         main(0, args)
 import random

 import torch
+from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
 from torch.utils.data.dataset import random_split
 from torchaudio.datasets import LJSPEECH, LIBRITTS
 from torchaudio.transforms import MuLawEncoding

-from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
-

 class MapMemoryCache(torch.utils.data.Dataset):
-    r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
-    """
+    r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory."""

     def __init__(self, dataset):
         self.dataset = dataset

@@ -47,16 +45,16 @@ class Processed(torch.utils.data.Dataset):
 def split_process_dataset(args, transforms):
-    if args.dataset == 'ljspeech':
+    if args.dataset == "ljspeech":
         data = LJSPEECH(root=args.file_path, download=False)

         val_length = int(len(data) * args.val_ratio)
         lengths = [len(data) - val_length, val_length]
         train_dataset, val_dataset = random_split(data, lengths)

-    elif args.dataset == 'libritts':
-        train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False)
-        val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False)
+    elif args.dataset == "libritts":
+        train_dataset = LIBRITTS(root=args.file_path, url="train-clean-100", download=False)
+        val_dataset = LIBRITTS(root=args.file_path, url="dev-clean", download=False)

     else:
         raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}")

@@ -88,14 +86,8 @@ def collate_factory(args):
         # random start postion in waveform
         wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets]

-        waveform_combine = [
-            x[0][wave_offsets[i]: wave_offsets[i] + wave_length + 1]
-            for i, x in enumerate(batch)
-        ]
-        specgram = [
-            x[1][:, spec_offsets[i]: spec_offsets[i] + spec_length]
-            for i, x in enumerate(batch)
-        ]
+        waveform_combine = [x[0][wave_offsets[i] : wave_offsets[i] + wave_length + 1] for i, x in enumerate(batch)]
+        specgram = [x[1][:, spec_offsets[i] : spec_offsets[i] + spec_length] for i, x in enumerate(batch)]

         specgram = torch.stack(specgram)
         waveform_combine = torch.stack(waveform_combine)
@@ -2,44 +2,46 @@ import argparse

 import torch
 import torchaudio
-from torchaudio.transforms import MelSpectrogram
-from torchaudio.datasets import LJSPEECH
+from processing import NormalizeDB
+from torchaudio.datasets import LJSPEECH
 from torchaudio.models import wavernn
 from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS
+from torchaudio.transforms import MelSpectrogram
 from wavernn_inference_wrapper import WaveRNNInferenceWrapper
-from processing import NormalizeDB


 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--output-wav-path", default="./output.wav", type=str, metavar="PATH",
+        "--output-wav-path",
+        default="./output.wav",
+        type=str,
+        metavar="PATH",
         help="The path to output the reconstructed wav file.",
     )
     parser.add_argument(
-        "--jit", default=False, action="store_true",
-        help="If used, the model and inference function is jitted."
-    )
-    parser.add_argument(
-        "--no-batch-inference", default=False, action="store_true",
-        help="Don't use batch inference."
+        "--jit", default=False, action="store_true", help="If used, the model and inference function is jitted."
     )
+    parser.add_argument("--no-batch-inference", default=False, action="store_true", help="Don't use batch inference.")
     parser.add_argument(
-        "--no-mulaw", default=False, action="store_true",
-        help="Don't use mulaw decoder to decoder the signal."
+        "--no-mulaw", default=False, action="store_true", help="Don't use mulaw decoder to decoder the signal."
     )
     parser.add_argument(
-        "--checkpoint-name", default="wavernn_10k_epochs_8bits_ljspeech",
+        "--checkpoint-name",
+        default="wavernn_10k_epochs_8bits_ljspeech",
         choices=list(_MODEL_CONFIG_AND_URLS.keys()),
-        help="Select the WaveRNN checkpoint."
+        help="Select the WaveRNN checkpoint.",
     )
     parser.add_argument(
-        "--batch-timesteps", default=100, type=int,
+        "--batch-timesteps",
+        default=100,
+        type=int,
         help="The time steps for each batch. Only used when batch inference is used",
     )
     parser.add_argument(
-        "--batch-overlap", default=5, type=int,
+        "--batch-overlap",
+        default=5,
+        type=int,
         help="The overlapping time steps between batches. Only used when batch inference is used",
     )
     args = parser.parse_args()

@@ -51,15 +53,15 @@ def main(args):
     waveform, sample_rate, _, _ = LJSPEECH("./", download=True)[0]

     mel_kwargs = {
-        'sample_rate': sample_rate,
-        'n_fft': 2048,
-        'f_min': 40.,
-        'n_mels': 80,
-        'win_length': 1100,
-        'hop_length': 275,
-        'mel_scale': 'slaney',
-        'norm': 'slaney',
-        'power': 1,
+        "sample_rate": sample_rate,
+        "n_fft": 2048,
+        "f_min": 40.0,
+        "n_mels": 80,
+        "win_length": 1100,
+        "hop_length": 275,
+        "mel_scale": "slaney",
+        "norm": "slaney",
+        "power": 1,
     }
     transforms = torch.nn.Sequential(
         MelSpectrogram(**mel_kwargs),

@@ -74,11 +76,13 @@ def main(args):
         wavernn_inference_model = torch.jit.script(wavernn_inference_model)

     with torch.no_grad():
-        output = wavernn_inference_model(mel_specgram.to(device),
-                                         mulaw=(not args.no_mulaw),
-                                         batched=(not args.no_batch_inference),
-                                         timesteps=args.batch_timesteps,
-                                         overlap=args.batch_overlap,)
+        output = wavernn_inference_model(
+            mel_specgram.to(device),
+            mulaw=(not args.no_mulaw),
+            batched=(not args.no_batch_inference),
+            timesteps=args.batch_timesteps,
+            overlap=args.batch_overlap,
+        )

     torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate)
...@@ -6,8 +6,7 @@ from torch.nn import functional as F ...@@ -6,8 +6,7 @@ from torch.nn import functional as F
class LongCrossEntropyLoss(nn.Module): class LongCrossEntropyLoss(nn.Module):
r""" CrossEntropy loss r"""CrossEntropy loss"""
"""
def __init__(self): def __init__(self):
super(LongCrossEntropyLoss, self).__init__() super(LongCrossEntropyLoss, self).__init__()
...@@ -21,7 +20,7 @@ class LongCrossEntropyLoss(nn.Module): ...@@ -21,7 +20,7 @@ class LongCrossEntropyLoss(nn.Module):
class MoLLoss(nn.Module): class MoLLoss(nn.Module):
r""" Discretized mixture of logistic distributions loss r"""Discretized mixture of logistic distributions loss
Adapted from wavenet vocoder Adapted from wavenet vocoder
(https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py) (https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py)
...@@ -57,10 +56,8 @@ class MoLLoss(nn.Module): ...@@ -57,10 +56,8 @@ class MoLLoss(nn.Module):
# unpack parameters (n_batch, n_time, num_mixtures) x 3 # unpack parameters (n_batch, n_time, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix] logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix: 2 * nr_mix] means = y_hat[:, :, nr_mix : 2 * nr_mix]
log_scales = torch.clamp( log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=self.log_scale_min)
y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min
)
# (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures) # (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures)
y = y.expand_as(means) y = y.expand_as(means)
...@@ -89,15 +86,11 @@ class MoLLoss(nn.Module): ...@@ -89,15 +86,11 @@ class MoLLoss(nn.Module):
inner_inner_cond = (cdf_delta > 1e-5).float() inner_inner_cond = (cdf_delta > 1e-5).float()
inner_inner_out = inner_inner_cond * torch.log( inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * (
torch.clamp(cdf_delta, min=1e-12)
) + (1.0 - inner_inner_cond) * (
log_pdf_mid - math.log((self.num_classes - 1) / 2) log_pdf_mid - math.log((self.num_classes - 1) / 2)
) )
inner_cond = (y > 0.999).float() inner_cond = (y > 0.999).float()
inner_out = ( inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
)
cond = (y < -0.999).float() cond = (y < -0.999).float()
log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
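To make the clamped expressions above easier to follow, this standalone sketch writes out the quantity they approximate in the well-behaved (interior-bin) case: the probability mass one logistic component assigns to the quantization bin containing y. Names here are illustrative, not taken from the diff.

import torch

def discretized_logistic_prob(y, mean, log_scale, num_classes=256):
    # Half-width of one quantization bin on the [-1, 1] axis.
    half_bin = 1.0 / (num_classes - 1)
    inv_std = torch.exp(-log_scale)
    cdf_plus = torch.sigmoid(inv_std * (y - mean + half_bin))
    cdf_min = torch.sigmoid(inv_std * (y - mean - half_bin))
    return cdf_plus - cdf_min  # the cdf_delta that the loss clamps at 1e-12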
...@@ -110,8 +103,7 @@ class MoLLoss(nn.Module): ...@@ -110,8 +103,7 @@ class MoLLoss(nn.Module):
def _log_sum_exp(x): def _log_sum_exp(x):
r""" Numerically stable log_sum_exp implementation that prevents overflow r"""Numerically stable log_sum_exp implementation that prevents overflow"""
"""
axis = len(x.size()) - 1 axis = len(x.size()) - 1
m, _ = torch.max(x, dim=axis) m, _ = torch.max(x, dim=axis)
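The stabilization in `_log_sum_exp` above is the usual max-subtraction trick: log(sum(exp(x))) = m + log(sum(exp(x - m))) with m = max(x). A tiny worked check:

import torch

x = torch.tensor([[1000.0, 1000.5]])
m, _ = torch.max(x, dim=-1, keepdim=True)
stable = m.squeeze(-1) + torch.log(torch.sum(torch.exp(x - m), dim=-1))
naive = torch.log(torch.sum(torch.exp(x), dim=-1))
# stable ≈ 1000.974, while naive overflows exp() and returns inf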
......
...@@ -8,14 +8,13 @@ from typing import List ...@@ -8,14 +8,13 @@ from typing import List
import torch import torch
import torchaudio import torchaudio
from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import NormalizeDB
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator from torchaudio.datasets.utils import bg_iterator
from torchaudio.models.wavernn import WaveRNN from torchaudio.models.wavernn import WaveRNN
from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint from utils import MetricLogger, count_parameters, save_checkpoint
...@@ -43,9 +42,7 @@ def parse_args(): ...@@ -43,9 +42,7 @@ def parse_args():
metavar="N", metavar="N",
help="number of total epochs to run", help="number of total epochs to run",
) )
parser.add_argument( parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="manual epoch number")
"--start-epoch", default=0, type=int, metavar="N", help="manual epoch number"
)
parser.add_argument( parser.add_argument(
"--print-freq", "--print-freq",
default=10, default=10,
...@@ -60,11 +57,13 @@ def parse_args(): ...@@ -60,11 +57,13 @@ def parse_args():
type=str, type=str,
help="select dataset to train with", help="select dataset to train with",
) )
parser.add_argument("--batch-size", default=256, type=int, metavar="N", help="mini-batch size")
parser.add_argument( parser.add_argument(
"--batch-size", default=256, type=int, metavar="N", help="mini-batch size" "--learning-rate",
) default=1e-4,
parser.add_argument( type=float,
"--learning-rate", default=1e-4, type=float, metavar="LR", help="learning rate", metavar="LR",
help="learning rate",
) )
parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0) parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0)
parser.add_argument( parser.add_argument(
...@@ -73,9 +72,7 @@ def parse_args(): ...@@ -73,9 +72,7 @@ def parse_args():
action="store_true", action="store_true",
help="if used, waveform is mulaw encoded", help="if used, waveform is mulaw encoded",
) )
parser.add_argument( parser.add_argument("--jit", default=False, action="store_true", help="if used, model is jitted")
"--jit", default=False, action="store_true", help="if used, model is jitted"
)
parser.add_argument( parser.add_argument(
"--upsample-scales", "--upsample-scales",
default=[5, 5, 11], default=[5, 5, 11],
...@@ -83,7 +80,10 @@ def parse_args(): ...@@ -83,7 +80,10 @@ def parse_args():
help="the list of upsample scales", help="the list of upsample scales",
) )
parser.add_argument( parser.add_argument(
"--n-bits", default=8, type=int, help="the bits of output waveform", "--n-bits",
default=8,
type=int,
help="the bits of output waveform",
) )
parser.add_argument( parser.add_argument(
"--sample-rate", "--sample-rate",
...@@ -98,10 +98,16 @@ def parse_args(): ...@@ -98,10 +98,16 @@ def parse_args():
help="the number of samples between the starts of consecutive frames", help="the number of samples between the starts of consecutive frames",
) )
parser.add_argument( parser.add_argument(
"--win-length", default=1100, type=int, help="the length of the STFT window", "--win-length",
default=1100,
type=int,
help="the length of the STFT window",
) )
parser.add_argument( parser.add_argument(
"--f-min", default=40.0, type=float, help="the minimum frequency", "--f-min",
default=40.0,
type=float,
help="the minimum frequency",
) )
parser.add_argument( parser.add_argument(
"--min-level-db", "--min-level-db",
...@@ -110,13 +116,22 @@ def parse_args(): ...@@ -110,13 +116,22 @@ def parse_args():
help="the minimum db value for spectrogam normalization", help="the minimum db value for spectrogam normalization",
) )
parser.add_argument( parser.add_argument(
"--n-res-block", default=10, type=int, help="the number of ResBlock in stack", "--n-res-block",
default=10,
type=int,
help="the number of ResBlock in stack",
) )
parser.add_argument( parser.add_argument(
"--n-rnn", default=512, type=int, help="the dimension of RNN layer", "--n-rnn",
default=512,
type=int,
help="the dimension of RNN layer",
) )
parser.add_argument( parser.add_argument(
"--n-fc", default=512, type=int, help="the dimension of fully connected layer", "--n-fc",
default=512,
type=int,
help="the dimension of fully connected layer",
) )
parser.add_argument( parser.add_argument(
"--kernel-size", "--kernel-size",
...@@ -125,7 +140,10 @@ def parse_args(): ...@@ -125,7 +140,10 @@ def parse_args():
help="the number of kernel size in the first Conv1d layer", help="the number of kernel size in the first Conv1d layer",
) )
parser.add_argument( parser.add_argument(
"--n-freq", default=80, type=int, help="the number of spectrogram bins to use", "--n-freq",
default=80,
type=int,
help="the number of spectrogram bins to use",
) )
parser.add_argument( parser.add_argument(
"--n-hidden-melresnet", "--n-hidden-melresnet",
...@@ -134,10 +152,16 @@ def parse_args(): ...@@ -134,10 +152,16 @@ def parse_args():
help="the number of hidden dimensions of resblock in melresnet", help="the number of hidden dimensions of resblock in melresnet",
) )
parser.add_argument( parser.add_argument(
"--n-output-melresnet", default=128, type=int, help="the output dimension of melresnet", "--n-output-melresnet",
default=128,
type=int,
help="the output dimension of melresnet",
) )
parser.add_argument( parser.add_argument(
"--n-fft", default=2048, type=int, help="the number of Fourier bins", "--n-fft",
default=2048,
type=int,
help="the number of Fourier bins",
) )
parser.add_argument( parser.add_argument(
"--loss", "--loss",
...@@ -159,10 +183,16 @@ def parse_args(): ...@@ -159,10 +183,16 @@ def parse_args():
help="the ratio of waveforms for validation", help="the ratio of waveforms for validation",
) )
parser.add_argument( parser.add_argument(
"--file-path", default="", type=str, help="the path of audio files", "--file-path",
default="",
type=str,
help="the path of audio files",
) )
parser.add_argument( parser.add_argument(
"--normalization", default=True, action="store_true", help="if True, spectrogram is normalized", "--normalization",
default=True,
action="store_true",
help="if True, spectrogram is normalized",
) )
args = parser.parse_args() args = parser.parse_args()
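One constraint worth calling out among the defaults above, stated as an assumption about WaveRNN's upsampling network rather than something this hunk enforces: the product of `--upsample-scales` must equal `--hop-length`, so that each spectrogram frame is stretched to exactly one hop of audio samples.

import math

upsample_scales = [5, 5, 11]  # default above
hop_length = 275              # default above
assert math.prod(upsample_scales) == hop_length  # 5 * 5 * 11 == 275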
...@@ -199,9 +229,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch): ...@@ -199,9 +229,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):
loss.backward() loss.backward()
if args.clip_grad > 0: if args.clip_grad > 0:
gradient = torch.nn.utils.clip_grad_norm_( gradient = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
model.parameters(), args.clip_grad
)
sums["gradient"] += gradient.item() sums["gradient"] += gradient.item()
metric["gradient"] = gradient.item() metric["gradient"] = gradient.item()
...@@ -271,8 +299,8 @@ def main(args): ...@@ -271,8 +299,8 @@ def main(args):
sample_rate=args.sample_rate, sample_rate=args.sample_rate,
n_mels=args.n_freq, n_mels=args.n_freq,
f_min=args.f_min, f_min=args.f_min,
mel_scale='slaney', mel_scale="slaney",
norm='slaney', norm="slaney",
**melkwargs, **melkwargs,
), ),
NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization), NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization),
...@@ -349,9 +377,7 @@ def main(args): ...@@ -349,9 +377,7 @@ def main(args):
model.load_state_dict(checkpoint["state_dict"]) model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"]) optimizer.load_state_dict(checkpoint["optimizer"])
logging.info( logging.info(f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}")
f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}"
)
else: else:
logging.info("Checkpoint: not found") logging.info("Checkpoint: not found")
...@@ -369,7 +395,12 @@ def main(args): ...@@ -369,7 +395,12 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
train_one_epoch( train_one_epoch(
model, criterion, optimizer, train_loader, devices[0], epoch, model,
criterion,
optimizer,
train_loader,
devices[0],
epoch,
) )
if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1: if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:
......
...@@ -3,8 +3,7 @@ import torch.nn as nn ...@@ -3,8 +3,7 @@ import torch.nn as nn
class NormalizeDB(nn.Module): class NormalizeDB(nn.Module):
r"""Normalize the spectrogram with a minimum db value r"""Normalize the spectrogram with a minimum db value"""
"""
def __init__(self, min_level_db, normalization): def __init__(self, min_level_db, normalization):
super().__init__() super().__init__()
...@@ -14,15 +13,12 @@ class NormalizeDB(nn.Module): ...@@ -14,15 +13,12 @@ class NormalizeDB(nn.Module):
def forward(self, specgram): def forward(self, specgram):
specgram = torch.log10(torch.clamp(specgram.squeeze(0), min=1e-5)) specgram = torch.log10(torch.clamp(specgram.squeeze(0), min=1e-5))
if self.normalization: if self.normalization:
return torch.clamp( return torch.clamp((self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1)
(self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1
)
return specgram return specgram
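A worked check of the normalization above, with `min_level_db = -100` as an illustrative value: a full-scale bin (0 dB) maps to 1.0 and a bin at the 1e-5 clamp floor (-100 dB) maps to 0.0.

import math

min_level_db = -100.0
for amp in (1.0, 1e-5):
    db = 20 * math.log10(amp)                  # 0.0 and -100.0 dB
    norm = (min_level_db - db) / min_level_db  # 1.0 and 0.0
    print(min(max(norm, 0.0), 1.0))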
def normalized_waveform_to_bits(waveform: torch.Tensor, bits: int) -> torch.Tensor: def normalized_waveform_to_bits(waveform: torch.Tensor, bits: int) -> torch.Tensor:
r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1] r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]"""
"""
assert abs(waveform).max() <= 1.0 assert abs(waveform).max() <= 1.0
waveform = (waveform + 1.0) * (2 ** bits - 1) / 2 waveform = (waveform + 1.0) * (2 ** bits - 1) / 2
...@@ -30,7 +26,6 @@ def normalized_waveform_to_bits(waveform: torch.Tensor, bits: int) -> torch.Tens ...@@ -30,7 +26,6 @@ def normalized_waveform_to_bits(waveform: torch.Tensor, bits: int) -> torch.Tens
def bits_to_normalized_waveform(label: torch.Tensor, bits: int) -> torch.Tensor: def bits_to_normalized_waveform(label: torch.Tensor, bits: int) -> torch.Tensor:
r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1] r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1]"""
"""
return 2 * label / (2 ** bits - 1.0) - 1.0 return 2 * label / (2 ** bits - 1.0) - 1.0
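A round-trip sketch for the two helpers above with `bits = 8`: [-1, 1] maps linearly onto the 256 integer labels and (up to quantization error) back. The `int()` truncation here stands in for the integer cast done downstream of the helper.

bits = 8
for w in (-1.0, 0.0, 1.0):
    label = int((w + 1.0) * (2 ** bits - 1) / 2)  # 0, 127, 255
    back = 2 * label / (2 ** bits - 1.0) - 1.0    # -1.0, ~-0.004, 1.0
    print(label, back)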
...@@ -7,8 +7,7 @@ import torch ...@@ -7,8 +7,7 @@ import torch
class MetricLogger: class MetricLogger:
r"""Logger for model metrics r"""Logger for model metrics"""
"""
def __init__(self, group, print_freq=1): def __init__(self, group, print_freq=1):
self.print_freq = print_freq self.print_freq = print_freq
...@@ -55,7 +54,6 @@ def save_checkpoint(state, is_best, filename): ...@@ -55,7 +54,6 @@ def save_checkpoint(state, is_best, filename):
def count_parameters(model): def count_parameters(model):
r"""Count the total number of parameters in the model r"""Count the total number of parameters in the model"""
"""
return sum(p.numel() for p in model.parameters() if p.requires_grad) return sum(p.numel() for p in model.parameters() if p.requires_grad)
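The counting idiom above, applied to a toy module to show what it measures:

import torch.nn as nn

layer = nn.Linear(80, 512)
n_params = sum(p.numel() for p in layer.parameters() if p.requires_grad)
# 80 * 512 weights + 512 biases = 41,472 trainable parameters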
...@@ -21,16 +21,15 @@ ...@@ -21,16 +21,15 @@
# ***************************************************************************** # *****************************************************************************
from torchaudio.models.wavernn import WaveRNN
import torch import torch
import torchaudio import torchaudio
from torch import Tensor
from processing import normalized_waveform_to_bits from processing import normalized_waveform_to_bits
from torch import Tensor
from torchaudio.models.wavernn import WaveRNN
def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor: def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
r'''Fold the tensor with overlap for quick batched inference. r"""Fold the tensor with overlap for quick batched inference.
Overlap will be used for crossfading in xfade_and_unfold(). Overlap will be used for crossfading in xfade_and_unfold().
x = [[h1, h2, ... hn]] x = [[h1, h2, ... hn]]
...@@ -47,7 +46,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor: ...@@ -47,7 +46,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
Return: Return:
folded (tensor): folded tensor of size (n_folds, timesteps + 2 * overlap, channel). folded (tensor): folded tensor of size (n_folds, timesteps + 2 * overlap, channel).
''' """
_, channels, total_len = x.size() _, channels, total_len = x.size()
...@@ -74,7 +73,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor: ...@@ -74,7 +73,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor: def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
r'''Applies a crossfade and unfolds into a 1d array. r"""Applies a crossfade and unfolds into a 1d array.
y = [[seq1], y = [[seq1],
[seq2], [seq2],
...@@ -93,7 +92,7 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor: ...@@ -93,7 +92,7 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
Returns: Returns:
unfolded waveform (Tensor) : waveform in a 1d tensor of size (channels, total_len). unfolded waveform (Tensor) : waveform in a 1d tensor of size (channels, total_len).
''' """
num_folds, channels, length = y.shape num_folds, channels, length = y.shape
timesteps = length - 2 * overlap timesteps = length - 2 * overlap
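The shape arithmetic behind the fold/unfold pair above, sketched with the script's default `timesteps=100, overlap=5` (the concrete total length is illustrative, and the fold-count formula is an assumption based on the standard WaveRNN batched-generation scheme):

timesteps, overlap = 100, 5
total_len = 330
fold_span = timesteps + 2 * overlap  # 110 steps per fold
# each fold shares `overlap` steps with its neighbours, which
# _xfade_and_unfold blends back together with a crossfade
n_folds = (total_len - overlap) // (timesteps + overlap)  # 3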
...@@ -130,17 +129,13 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor: ...@@ -130,17 +129,13 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
class WaveRNNInferenceWrapper(torch.nn.Module): class WaveRNNInferenceWrapper(torch.nn.Module):
def __init__(self, wavernn: WaveRNN): def __init__(self, wavernn: WaveRNN):
super().__init__() super().__init__()
self.wavernn_model = wavernn self.wavernn_model = wavernn
def forward(self, def forward(
specgram: Tensor, self, specgram: Tensor, mulaw: bool = True, batched: bool = True, timesteps: int = 100, overlap: int = 5
mulaw: bool = True, ) -> Tensor:
batched: bool = True,
timesteps: int = 100,
overlap: int = 5) -> Tensor:
r"""Inference function for WaveRNN. r"""Inference function for WaveRNN.
Based on the implementation from Based on the implementation from
......
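A hedged usage sketch for `WaveRNNInferenceWrapper` as defined above, tying together the defaults seen earlier in this diff. The model hyperparameters and the random `mel_specgram` stand-in are assumptions for illustration; real shape handling follows the wrapper's docstring.

import torch
from torchaudio.models.wavernn import WaveRNN

wavernn = WaveRNN(upsample_scales=[5, 5, 11], n_classes=256,
                  hop_length=275, n_freq=80)
wrapper = WaveRNNInferenceWrapper(wavernn)  # class defined in this file
mel_specgram = torch.randn(1, 80, 64)       # stand-in spectrogram
with torch.no_grad():
    waveform = wrapper(mel_specgram, mulaw=True, batched=True,
                       timesteps=100, overlap=5)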
from . import ( from . import train, trainer
train,
trainer
)
__all__ = ['train', 'trainer'] __all__ = ["train", "trainer"]
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Train Conv-TasNet""" """Train Conv-TasNet"""
import time
import pathlib
import argparse import argparse
import pathlib
import time
import conv_tasnet
import torch import torch
import torchaudio import torchaudio
import torchaudio.models import torchaudio.models
import conv_tasnet
from utils import dist_utils from utils import dist_utils
from utils.dataset import utils as dataset_utils from utils.dataset import utils as dataset_utils
...@@ -16,15 +15,14 @@ _LG = dist_utils.getLogger(__name__) ...@@ -16,15 +15,14 @@ _LG = dist_utils.getLogger(__name__)
def _parse_args(args): def _parse_args(args):
parser = argparse.ArgumentParser(description=__doc__,) parser = argparse.ArgumentParser(
description=__doc__,
)
parser.add_argument( parser.add_argument(
"--debug", "--debug", action="store_true", help="Enable debug behavior. Each epoch will end with just one batch."
action="store_true",
help="Enable debug behavior. Each epoch will end with just one batch.")
group = parser.add_argument_group("Model Options")
group.add_argument(
"--num-speakers", required=True, type=int, help="The number of speakers."
) )
group = parser.add_argument_group("Model Options")
group.add_argument("--num-speakers", required=True, type=int, help="The number of speakers.")
group = parser.add_argument_group("Dataset Options") group = parser.add_argument_group("Dataset Options")
group.add_argument( group.add_argument(
"--sample-rate", "--sample-rate",
...@@ -132,7 +130,8 @@ def _get_model( ...@@ -132,7 +130,8 @@ def _get_model(
_LG.info_on_master(" - X: %d", msk_num_layers) _LG.info_on_master(" - X: %d", msk_num_layers)
_LG.info_on_master(" - R: %d", msk_num_stacks) _LG.info_on_master(" - R: %d", msk_num_stacks)
_LG.info_on_master( _LG.info_on_master(
" - Receptive Field: %s [samples]", model.mask_generator.receptive_field, " - Receptive Field: %s [samples]",
model.mask_generator.receptive_field,
) )
return model return model
...@@ -141,11 +140,9 @@ def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_ ...@@ -141,11 +140,9 @@ def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_
train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset( train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
dataset_type, dataset_dir, num_speakers, sample_rate, task dataset_type, dataset_dir, num_speakers, sample_rate, task
) )
train_collate_fn = dataset_utils.get_collate_fn( train_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode="train", sample_rate=sample_rate, duration=4)
dataset_type, mode='train', sample_rate=sample_rate, duration=4
)
test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test') test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode="test")
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, train_dataset,
...@@ -173,8 +170,12 @@ def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_ ...@@ -173,8 +170,12 @@ def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_
def _write_header(log_path, args): def _write_header(log_path, args):
rows = [ rows = [
[f"# torch: {torch.__version__}", ], [
[f"# torchaudio: {torchaudio.__version__}", ] f"# torch: {torch.__version__}",
],
[
f"# torchaudio: {torchaudio.__version__}",
],
] ]
rows.append(["# arguments"]) rows.append(["# arguments"])
for key, item in vars(args).items(): for key, item in vars(args).items():
...@@ -212,9 +213,7 @@ def train(args): ...@@ -212,9 +213,7 @@ def train(args):
model = _get_model(num_sources=args.num_speakers) model = _get_model(num_sources=args.num_speakers)
model.to(device) model.to(device)
model = torch.nn.parallel.DistributedDataParallel( model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device] if torch.cuda.is_available() else None)
model, device_ids=[device] if torch.cuda.is_available() else None
)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
if args.resume: if args.resume:
...@@ -222,13 +221,9 @@ def train(args): ...@@ -222,13 +221,9 @@ def train(args):
model.module.load_state_dict(checkpoint["model"]) model.module.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"]) optimizer.load_state_dict(checkpoint["optimizer"])
else: else:
dist_utils.synchronize_params( dist_utils.synchronize_params(str(args.save_dir / "tmp.pt"), device, model, optimizer)
str(args.save_dir / "tmp.pt"), device, model, optimizer
)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)
optimizer, mode="max", factor=0.5, patience=3
)
train_loader, valid_loader, eval_loader = _get_dataloader( train_loader, valid_loader, eval_loader = _get_dataloader(
args.dataset, args.dataset,
......
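On the scheduler configured above: `mode="max"` matches a validation metric where higher is better (SI-SNRi here). A minimal runnable sketch of how such a scheduler is driven; the constant metric is a stand-in for a real validation result:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=3
)
for epoch in range(10):
    val_si_snri = 0.0            # stand-in; a stagnant metric
    scheduler.step(val_si_snri)  # halves the LR once patience is exhausted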
import time import time
from typing import Tuple
from collections import namedtuple from collections import namedtuple
from typing import Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from utils import dist_utils, metrics from utils import dist_utils, metrics
_LG = dist_utils.getLogger(__name__) _LG = dist_utils.getLogger(__name__)
Metric = namedtuple("SNR", ["si_snri", "sdri"]) Metric = namedtuple("SNR", ["si_snri", "sdri"])
Metric.__str__ = ( Metric.__str__ = lambda self: f"SI-SNRi: {self.si_snri:10.3e}, SDRi: {self.sdri:10.3e}"
lambda self: f"SI-SNRi: {self.si_snri:10.3e}, SDRi: {self.sdri:10.3e}"
)
def si_sdr_improvement( def si_sdr_improvement(
estimate: torch.Tensor, estimate: torch.Tensor, reference: torch.Tensor, mix: torch.Tensor, mask: torch.Tensor
reference: torch.Tensor,
mix: torch.Tensor,
mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute the improvement of scale-invariant SDR. (SI-SNRi) and bare SDR (SDRi). """Compute the improvement of scale-invariant SDR. (SI-SNRi) and bare SDR (SDRi).
...@@ -66,11 +60,7 @@ class OccasionalLogger: ...@@ -66,11 +60,7 @@ class OccasionalLogger:
def log(self, metric, progress, force=False): def log(self, metric, progress, force=False):
now = time.monotonic() now = time.monotonic()
if ( if force or now > self.last_time + self.time_interval or progress > self.last_progress + self.progress_interval:
force
or now > self.last_time + self.time_interval
or progress > self.last_progress + self.progress_interval
):
self.last_time = now self.last_time = now
self.last_progress = progress self.last_progress = progress
_LG.info_on_master("train: %s [%3d%%]", metric, 100 * progress) _LG.info_on_master("train: %s [%3d%%]", metric, 100 * progress)
...@@ -117,9 +107,7 @@ class Trainer: ...@@ -117,9 +107,7 @@ class Trainer:
loss = -si_snri loss = -si_snri
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_( torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip, norm_type=2.0)
self.model.parameters(), self.grad_clip, norm_type=2.0
)
self.optimizer.step() self.optimizer.step()
metric = Metric(si_snri.item(), sdri.item()) metric = Metric(si_snri.item(), sdri.item())
......