Commit 9dcc7a15 authored by flyingdown

init v0.10.0

parent db2b0b79
"""
Text-to-speech pipeline using Tacotron2.
"""
from functools import partial
import argparse
import os
import random
import sys
import torch
import torchaudio
import numpy as np
from torchaudio.models import Tacotron2
from torchaudio.models import tacotron2 as pretrained_tacotron2
from utils import prepare_input_sequence
from datasets import InverseSpectralNormalization
from text.text_preprocessing import (
available_symbol_set,
available_phonemizers,
get_symbol_list,
text_to_sequence,
)
def parse_args():
r"""
Parse commandline arguments.
"""
from torchaudio.models.tacotron2 import _MODEL_CONFIG_AND_URLS as tacotron2_config_and_urls
from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS as wavernn_config_and_urls
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--checkpoint-name',
type=str,
default=None,
choices=list(tacotron2_config_and_urls.keys()),
help='[string] The name of the checkpoint to load.'
)
parser.add_argument(
'--checkpoint-path',
type=str,
default=None,
help='[string] Path to the checkpoint file.'
)
parser.add_argument(
'--output-path',
type=str,
default="./audio.wav",
help='[string] Path to the output .wav file.'
)
parser.add_argument(
'--input-text',
'-i',
type=str,
default="Hello world",
help='[string] Type in something here and TTS will generate it!'
)
parser.add_argument(
'--vocoder',
default='nvidia_waveglow',
choices=['griffin_lim', 'wavernn', 'nvidia_waveglow'],
type=str,
help="Select the vocoder to use.",
)
parser.add_argument(
"--jit",
default=False,
action="store_true",
help="If used, the model and inference function is jitted."
)
preprocessor = parser.add_argument_group('text preprocessor setup')
preprocessor.add_argument(
'--text-preprocessor',
default='english_characters',
type=str,
choices=available_symbol_set,
help='select text preprocessor to use.'
)
preprocessor.add_argument(
'--phonemizer',
default="DeepPhonemizer",
type=str,
choices=available_phonemizers,
help='select phonemizer to use, only used when text-preprocessor is "english_phonemes"'
)
preprocessor.add_argument(
'--phonemizer-checkpoint',
default="./en_us_cmudict_forward.pt",
type=str,
help='the path or name of the checkpoint for the phonemizer, '
'only used when text-preprocessor is "english_phonemes"'
)
preprocessor.add_argument(
'--cmudict-root',
default="./",
type=str,
help='the root directory for storing CMU dictionary files'
)
audio = parser.add_argument_group('audio parameters')
audio.add_argument(
'--sample-rate',
default=22050,
type=int,
help='Sampling rate'
)
audio.add_argument(
'--n-fft',
default=1024,
type=int,
help='Filter length for STFT'
)
audio.add_argument(
'--n-mels',
default=80,
type=int,
help='Number of mel bins'
)
audio.add_argument(
'--mel-fmin',
default=0.0,
type=float,
help='Minimum mel frequency'
)
audio.add_argument(
'--mel-fmax',
default=8000.0,
type=float,
help='Maximum mel frequency'
)
# parameters for WaveRNN
wavernn = parser.add_argument_group('WaveRNN parameters')
wavernn.add_argument(
'--wavernn-checkpoint-name',
default="wavernn_10k_epochs_8bits_ljspeech",
choices=list(wavernn_config_and_urls.keys()),
help="Select the WaveRNN checkpoint."
)
wavernn.add_argument(
"--wavernn-loss",
default="crossentropy",
choices=["crossentropy"],
type=str,
help="The type of loss the WaveRNN pretrained model is trained on.",
)
wavernn.add_argument(
"--wavernn-no-batch-inference",
default=False,
action="store_true",
help="Don't use batch inference for WaveRNN inference."
)
wavernn.add_argument(
"--wavernn-no-mulaw",
default=False,
action="store_true",
help="Don't use mulaw decoder to decode the signal."
)
wavernn.add_argument(
"--wavernn-batch-timesteps",
default=11000,
type=int,
help="The time steps for each batch. Only used when batch inference is used",
)
wavernn.add_argument(
"--wavernn-batch-overlap",
default=550,
type=int,
help="The overlapping time steps between batches. Only used when batch inference is used",
)
return parser
def unwrap_distributed(state_dict):
r"""torch.distributed.DistributedDataParallel wraps the model with an additional "module.".
This function unwraps this layer so that the weights can be loaded on models with a single GPU.
Args:
state_dict: Original state_dict.
Return:
unwrapped_state_dict: Unwrapped state_dict.
"""
return {k.replace('module.', ''): v for k, v in state_dict.items()}
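# Example (hypothetical key): {"module.encoder.lstm.weight": w} -> {"encoder.lstm.weight": w}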
def nvidia_waveglow_vocode(mel_specgram, device, jit=False):
waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp16')
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow = waveglow.to(device)
waveglow.eval()
if jit:
raise ValueError("Vocoder option `nvidia_waveglow` is not jittable.")
with torch.no_grad():
waveform = waveglow.infer(mel_specgram).cpu()
return waveform
def wavernn_vocode(mel_specgram, wavernn_checkpoint_name, wavernn_loss, wavernn_no_mulaw,
wavernn_no_batch_inference, wavernn_batch_timesteps, wavernn_batch_overlap,
device, jit):
from torchaudio.models import wavernn
sys.path.append(os.path.join(os.path.dirname(__file__), "../pipeline_wavernn"))
from wavernn_inference_wrapper import WaveRNNInferenceWrapper
from processing import NormalizeDB
wavernn_model = wavernn(wavernn_checkpoint_name).eval().to(device)
wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)
if jit:
wavernn_inference_model = torch.jit.script(wavernn_inference_model)
# WaveRNN spectro setting for default checkpoint
# n_fft = 2048
# n_mels = 80
# win_length = 1100
# hop_length = 275
# f_min = 40
# f_max = 11025
transforms = torch.nn.Sequential(
InverseSpectralNormalization(),
NormalizeDB(min_level_db=-100, normalization=True),
)
mel_specgram = transforms(mel_specgram.cpu())
with torch.no_grad():
waveform = wavernn_inference_model(mel_specgram.to(device),
loss_name=wavernn_loss,
mulaw=(not wavernn_no_mulaw),
batched=(not wavernn_no_batch_inference),
timesteps=wavernn_batch_timesteps,
overlap=wavernn_batch_overlap,)
return waveform.unsqueeze(0)
def griffin_lim_vocode(mel_specgram, n_fft, n_mels, sample_rate, mel_fmin, mel_fmax, jit, ):
from torchaudio.transforms import GriffinLim, InverseMelScale
inv_norm = InverseSpectralNormalization()
inv_mel = InverseMelScale(
n_stft=(n_fft // 2 + 1),
n_mels=n_mels,
sample_rate=sample_rate,
f_min=mel_fmin,
f_max=mel_fmax,
mel_scale="slaney",
norm='slaney',
)
griffin_lim = GriffinLim(
n_fft=n_fft,
power=1,
hop_length=256,
win_length=1024,
)
vocoder = torch.nn.Sequential(
inv_norm,
inv_mel,
griffin_lim
)
if jit:
vocoder = torch.jit.script(vocoder)
waveform = vocoder(mel_specgram.cpu())
return waveform
def main(args):
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
if args.checkpoint_path is None and args.checkpoint_name is None:
raise ValueError("Either --checkpoint-path or --checkpoint-name must be specified.")
elif args.checkpoint_path is not None and args.checkpoint_name is not None:
raise ValueError("Both --checkpoint-path and --checkpoint-name are specified, "
"can only specify one.")
n_symbols = len(get_symbol_list(args.text_preprocessor))
text_preprocessor = partial(
text_to_sequence,
symbol_list=args.text_preprocessor,
phonemizer=args.phonemizer,
checkpoint=args.phonemizer_checkpoint,
cmudict_root=args.cmudict_root,
)
if args.checkpoint_path is not None:
tacotron2 = Tacotron2(n_symbol=n_symbols)
tacotron2.load_state_dict(
unwrap_distributed(torch.load(args.checkpoint_path, map_location=device)['state_dict']))
tacotron2 = tacotron2.to(device).eval()
elif args.checkpoint_name is not None:
tacotron2 = pretrained_tacotron2(args.checkpoint_name).to(device).eval()
if n_symbols != tacotron2.n_symbols:
raise ValueError("the number of symbols for text_preprocessor ({n_symbols}) "
"should match the number of symbols for the"
"pretrained tacotron2 ({tacotron2.n_symbols}).")
if args.jit:
tacotron2 = torch.jit.script(tacotron2)
sequences, lengths = prepare_input_sequence([args.input_text],
text_processor=text_preprocessor)
sequences, lengths = sequences.long().to(device), lengths.long().to(device)
with torch.no_grad():
mel_specgram, _, _ = tacotron2.infer(sequences, lengths)
if args.vocoder == "nvidia_waveglow":
waveform = nvidia_waveglow_vocode(mel_specgram=mel_specgram, device=device, jit=args.jit)
elif args.vocoder == "wavernn":
waveform = wavernn_vocode(mel_specgram=mel_specgram,
wavernn_checkpoint_name=args.wavernn_checkpoint_name,
wavernn_loss=args.wavernn_loss,
wavernn_no_mulaw=args.wavernn_no_mulaw,
wavernn_no_batch_inference=args.wavernn_no_batch_inference,
wavernn_batch_timesteps=args.wavernn_batch_timesteps,
wavernn_batch_overlap=args.wavernn_batch_overlap,
device=device,
jit=args.jit)
elif args.vocoder == "griffin_lim":
waveform = griffin_lim_vocode(mel_specgram=mel_specgram,
n_fft=args.n_fft,
n_mels=args.n_mels,
sample_rate=args.sample_rate,
mel_fmin=args.mel_fmin,
mel_fmax=args.mel_fmax,
jit=args.jit)
torchaudio.save(args.output_path, waveform, args.sample_rate)
if __name__ == "__main__":
parser = parse_args()
args, _ = parser.parse_known_args()
main(args)
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
from typing import Tuple
from torch import nn, Tensor
class Tacotron2Loss(nn.Module):
"""Tacotron2 loss function modified from:
https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py
"""
def __init__(self):
super().__init__()
self.mse_loss = nn.MSELoss(reduction="mean")
self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
def forward(
self,
model_outputs: Tuple[Tensor, Tensor, Tensor],
targets: Tuple[Tensor, Tensor],
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Pass the input through the Tacotron2 loss.
The original implementation was introduced in
*Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
[:footcite:`shen2018natural`].
Args:
model_outputs (tuple of three Tensors): The outputs of the
Tacotron2. These outputs should include three items:
(1) the predicted mel spectrogram before the postnet (``mel_specgram``)
with shape (batch, mel, time).
(2) predicted mel spectrogram after the postnet (``mel_specgram_postnet``)
with shape (batch, mel, time), and
(3) the stop token prediction (``gate_out``) with shape (batch, ).
targets (tuple of two Tensors): The ground truth mel spectrogram (batch, mel, time) and
stop token with shape (batch, ).
Returns:
mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram
with shape ``torch.Size([])``.
mel_postnet_loss (Tensor): The mean MSE of the mel_specgram_postnet and
ground truth mel spectrogram with shape ``torch.Size([])``.
gate_loss (Tensor): The mean binary cross entropy loss of
the prediction on the stop token with shape ``torch.Size([])``.
"""
mel_target, gate_target = targets[0], targets[1]
gate_target = gate_target.view(-1, 1)
mel_specgram, mel_specgram_postnet, gate_out = model_outputs
gate_out = gate_out.view(-1, 1)
mel_loss = self.mse_loss(mel_specgram, mel_target)
mel_postnet_loss = self.mse_loss(mel_specgram_postnet, mel_target)
gate_loss = self.bce_loss(gate_out, gate_target)
return mel_loss, mel_postnet_loss, gate_loss
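# A minimal usage sketch (dummy tensors only; shapes follow the docstring above):
#
#   import torch
#   criterion = Tacotron2Loss()
#   batch, n_mels, time = 2, 80, 100
#   mel = torch.rand(batch, n_mels, time)            # predicted mel before the postnet
#   mel_postnet = torch.rand(batch, n_mels, time)    # predicted mel after the postnet
#   gate_out = torch.rand(batch)                     # stop-token logits
#   mel_target = torch.rand(batch, n_mels, time)     # ground-truth mel spectrogram
#   gate_target = torch.rand(batch)                  # ground-truth stop token in [0, 1]
#   mel_loss, mel_postnet_loss, gate_loss = criterion(
#       (mel, mel_postnet, gate_out), (mel_target, gate_target))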
# *****************************************************************************
# Copyright (c) 2017 Keith Ito
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# *****************************************************************************
"""
Modified from https://github.com/keithito/tacotron
"""
import inflect
import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
def _remove_commas(text: str) -> str:
return re.sub(_comma_number_re, lambda m: m.group(1).replace(',', ''), text)
def _expand_pounds(text: str) -> str:
return re.sub(_pounds_re, r'\1 pounds', text)
def _expand_dollars_repl_fn(m):
"""The replacement function for expanding dollars."""
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
if len(parts) > 1 and parts[1]:
if len(parts[1]) == 1:
# handle the case where we have one digit after the decimal point
cents = int(parts[1]) * 10
else:
cents = int(parts[1])
else:
cents = 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_dollars(text: str) -> str:
return re.sub(_dollars_re, _expand_dollars_repl_fn, text)
def _expand_decimal_point(text: str) -> str:
return re.sub(_decimal_number_re, lambda m: m.group(1).replace('.', ' point '), text)
def _expand_ordinal(text: str) -> str:
return re.sub(_ordinal_re, lambda m: _inflect.number_to_words(m.group(0)), text)
def _expand_number_repl_fn(m):
"""The replacement function for expanding number."""
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='')
def _expand_number(text: str) -> str:
return re.sub(_number_re, _expand_number_repl_fn, text)
def normalize_numbers(text: str) -> str:
text = _remove_commas(text)
text = _expand_pounds(text)
text = _expand_dollars(text)
text = _expand_decimal_point(text)
text = _expand_ordinal(text)
text = _expand_number(text)
return text
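# A rough illustration of the pipeline above (exact wording depends on `inflect`):
#
#   normalize_numbers("in 1990 he paid $10")
#   # -> "in nineteen ninety he paid ten dollars"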
# *****************************************************************************
# Copyright (c) 2017 Keith Ito
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# *****************************************************************************
"""
Modified from https://github.com/keithito/tacotron
"""
from typing import List, Union, Optional
import re
from unidecode import unidecode
from torchaudio.datasets import CMUDict
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
_pad = '_'
_punctuation = '!\'(),.:;? '
_special = '-'
_letters = 'abcdefghijklmnopqrstuvwxyz'
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters)
_phonemizer = None
available_symbol_set = set(["english_characters", "english_phonemes"])
available_phonemizers = set(["DeepPhonemizer"])
def get_symbol_list(symbol_list: str = "english_characters",
cmudict_root: Optional[str] = "./") -> List[str]:
if symbol_list == "english_characters":
return [_pad] + list(_special) + list(_punctuation) + list(_letters)
elif symbol_list == "english_phonemes":
return [_pad] + list(_special) + list(_punctuation) + CMUDict(cmudict_root).symbols
else:
raise ValueError(f"The `symbol_list` {symbol_list} is not supported."
f"Supported `symbol_list` includes {available_symbol_set}.")
def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
if phonemizer == "DeepPhonemizer":
from dp.phonemizer import Phonemizer
global _phonemizer
_other_symbols = ''.join(list(_special) + list(_punctuation))
_phone_symbols_re = r'(\[[A-Z]+?\]|' + '[' + _other_symbols + '])' # [\[([A-Z]+?)\]|[-!'(),.:;? ]]
if _phonemizer is None:
# using a global variable so that we don't have to reload the checkpoint
# every time this function is called
_phonemizer = Phonemizer.from_checkpoint(checkpoint)
# Example:
# sent = "hello world!"
# '[HH][AH][L][OW] [W][ER][L][D]!'
sent = _phonemizer(sent, lang='en_us')
# ['[HH]', '[AH]', '[L]', '[OW]', ' ', '[W]', '[ER]', '[L]', '[D]', '!']
ret = re.findall(_phone_symbols_re, sent)
# ['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!']
ret = [r.replace("[", "").replace("]", "") for r in ret]
return ret
else:
raise ValueError(f"The `phonemizer` {phonemizer} is not supported. "
"Supported `symbol_list` includes `'DeepPhonemizer'`.")
def text_to_sequence(sent: str,
symbol_list: Union[str, List[str]] = "english_characters",
phonemizer: Optional[str] = "DeepPhonemizer",
checkpoint: Optional[str] = "./en_us_cmudict_forward.pt",
cmudict_root: Optional[str] = "./") -> List[int]:
r'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
sent (str): The input sentence to convert to a sequence.
symbol_list (str or List of string, optional): When the input is a string, available options include
"english_characters" and "english_phonemes". When the input is a list of string, ``symbol_list`` will
directly be used as the symbol to encode. (Default: "english_characters")
phonemizer (str or None, optional): The phonemizer to use. Only used when ``symbol_list`` is "english_phonemes".
Available options include "DeepPhonemizer". (Default: "DeepPhonemizer")
checkpoint (str or None, optional): The path to the checkpoint of the phonemizer. Only used when
``symbol_list`` is "english_phonemes". (Default: "./en_us_cmudict_forward.pt")
cmudict_root (str or None, optional): The path to the directory where the CMUDict dataset is found or
downloaded. Only used when ``symbol_list`` is "english_phonemes". (Default: "./")
Returns:
List of integers corresponding to the symbols in the sentence.
Examples:
>>> text_to_sequence("hello world!", "english_characters")
[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2]
>>> text_to_sequence("hello world!", "english_phonemes")
[54, 20, 65, 69, 11, 92, 44, 65, 38, 2]
'''
if symbol_list == "english_phonemes":
if any(param is None for param in [phonemizer, checkpoint, cmudict_root]):
raise ValueError(
"When `symbol_list` is 'english_phonemes', "
"all of `phonemizer`, `checkpoint`, and `cmudict_root` must be provided.")
sent = unidecode(sent) # convert to ascii
sent = sent.lower() # lower case
sent = normalize_numbers(sent) # expand numbers
for regex, replacement in _abbreviations: # expand abbreviations
sent = re.sub(regex, replacement, sent)
sent = re.sub(_whitespace_re, ' ', sent) # collapse whitespace
if isinstance(symbol_list, list):
symbols = symbol_list
elif isinstance(symbol_list, str):
symbols = get_symbol_list(symbol_list, cmudict_root=cmudict_root)
if symbol_list == "english_phonemes":
sent = word_to_phonemes(sent, phonemizer=phonemizer, checkpoint=checkpoint)
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
return [_symbol_to_id[s] for s in sent if s in _symbol_to_id]
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
"""
Modified from
https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/train.py
"""
import argparse
from datetime import datetime
from functools import partial
import logging
import random
import os
from time import time
import torch
import torchaudio
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchaudio.models import Tacotron2
from tqdm import tqdm
import matplotlib.pyplot as plt
plt.switch_backend('agg')
from datasets import text_mel_collate_fn, split_process_dataset, SpectralNormalization
from utils import save_checkpoint
from loss import Tacotron2Loss
from text.text_preprocessing import (
available_symbol_set,
available_phonemizers,
get_symbol_list,
text_to_sequence,
)
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(os.path.basename(__file__))
def parse_args(parser):
"""Parse commandline arguments."""
parser.add_argument("--dataset", default="ljspeech", choices=["ljspeech"], type=str,
help="select dataset to train with")
parser.add_argument('--logging-dir', type=str, default=None,
help='directory to save the log files')
parser.add_argument('--dataset-path', type=str, default='./',
help='path to dataset')
parser.add_argument("--val-ratio", default=0.1, type=float,
help="the ratio of waveforms for validation")
parser.add_argument('--anneal-steps', nargs='*',
help='epochs after which decrease learning rate')
parser.add_argument('--anneal-factor', type=float, choices=[0.1, 0.3], default=0.1,
help='factor for annealing learning rate')
parser.add_argument('--master-addr', default=None, type=str,
help='the address to use for distributed training')
parser.add_argument('--master-port', default=None, type=str,
help='the port to use for distributed training')
preprocessor = parser.add_argument_group('text preprocessor setup')
preprocessor.add_argument('--text-preprocessor', default='english_characters', type=str,
choices=available_symbol_set,
help='select text preprocessor to use.')
preprocessor.add_argument('--phonemizer', type=str, choices=available_phonemizers,
help='select phonemizer to use, only used when text-preprocessor is "english_phonemes"')
preprocessor.add_argument('--phonemizer-checkpoint', type=str,
help='the path or name of the checkpoint for the phonemizer, '
'only used when text-preprocessor is "english_phonemes"')
preprocessor.add_argument('--cmudict-root', default="./", type=str,
help='the root directory for storing CMU dictionary files')
# training
training = parser.add_argument_group('training setup')
training.add_argument('--epochs', type=int, required=True,
help='number of total epochs to run')
training.add_argument('--checkpoint-path', type=str, default='',
help='checkpoint path. If a file exists, '
'the program will load it and resume training.')
training.add_argument('--workers', default=8, type=int,
help="number of data loading workers")
training.add_argument("--validate-and-checkpoint-freq", default=10, type=int, metavar="N",
help="validation and saving checkpoint frequency in epochs",)
training.add_argument("--logging-freq", default=10, type=int, metavar="N",
help="logging frequency in epochs")
optimization = parser.add_argument_group('optimization setup')
optimization.add_argument('--learning-rate', default=1e-3, type=float,
help='initial learning rate')
optimization.add_argument('--weight-decay', default=1e-6, type=float,
help='weight decay')
optimization.add_argument('--batch-size', default=32, type=int,
help='batch size per GPU')
optimization.add_argument('--grad-clip', default=5.0, type=float,
help='clipping gradient with maximum gradient norm value')
# model parameters
model = parser.add_argument_group('model parameters')
model.add_argument('--mask-padding', action='store_true', default=False,
help='use mask padding')
model.add_argument('--symbols-embedding-dim', default=512, type=int,
help='input embedding dimension')
# encoder
model.add_argument('--encoder-embedding-dim', default=512, type=int,
help='encoder embedding dimension')
model.add_argument('--encoder-n-convolution', default=3, type=int,
help='number of encoder convolutions')
model.add_argument('--encoder-kernel-size', default=5, type=int,
help='encoder kernel size')
# decoder
model.add_argument('--n-frames-per-step', default=1, type=int,
help='number of frames processed per step (currently only 1 is supported)')
model.add_argument('--decoder-rnn-dim', default=1024, type=int,
help='number of units in decoder LSTM')
model.add_argument('--decoder-dropout', default=0.1, type=float,
help='dropout probability for decoder LSTM')
model.add_argument('--decoder-max-step', default=2000, type=int,
help='maximum number of output mel spectrograms')
model.add_argument('--decoder-no-early-stopping', action='store_true', default=False,
help='stop decoding only when all samples are finished')
# attention model
model.add_argument('--attention-hidden-dim', default=128, type=int,
help='dimension of attention hidden representation')
model.add_argument('--attention-rnn-dim', default=1024, type=int,
help='number of units in attention LSTM')
model.add_argument('--attention-location-n-filter', default=32, type=int,
help='number of filters for location-sensitive attention')
model.add_argument('--attention-location-kernel-size', default=31, type=int,
help='kernel size for location-sensitive attention')
model.add_argument('--attention-dropout', default=0.1, type=float,
help='dropout probability for attention LSTM')
model.add_argument('--prenet-dim', default=256, type=int,
help='number of ReLU units in prenet layers')
# mel-post processing network parameters
model.add_argument('--postnet-n-convolution', default=5, type=int,
help='number of postnet convolutions')
model.add_argument('--postnet-kernel-size', default=5, type=int,
help='postnet kernel size')
model.add_argument('--postnet-embedding-dim', default=512, type=int,
help='postnet embedding dimension')
model.add_argument('--gate-threshold', default=0.5, type=float,
help='probability threshold for stop token')
# audio parameters
audio = parser.add_argument_group('audio parameters')
audio.add_argument('--sample-rate', default=22050, type=int,
help='Sampling rate')
audio.add_argument('--n-fft', default=1024, type=int,
help='Filter length for STFT')
audio.add_argument('--hop-length', default=256, type=int,
help='Hop (stride) length')
audio.add_argument('--win-length', default=1024, type=int,
help='Window length')
audio.add_argument('--n-mels', default=80, type=int,
help='Number of mel bins')
audio.add_argument('--mel-fmin', default=0.0, type=float,
help='Minimum mel frequency')
audio.add_argument('--mel-fmax', default=8000.0, type=float,
help='Maximum mel frequency')
return parser
def adjust_learning_rate(epoch, optimizer, learning_rate,
anneal_steps, anneal_factor):
"""Adjust learning rate base on the initial setting."""
p = 0
if anneal_steps is not None:
for _, a_step in enumerate(anneal_steps):
if epoch >= int(a_step):
p = p + 1
if anneal_factor == 0.3:
lr = learning_rate * ((0.1 ** (p // 2)) * (1.0 if p % 2 == 0 else 0.3))
else:
lr = learning_rate * (anneal_factor ** p)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
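# Example (hypothetical values): with learning_rate=1e-3, anneal_steps=["500", "1000"] and
# anneal_factor=0.1, an epoch of 1200 passes both steps (p=2), so the learning rate becomes
# 1e-3 * 0.1 ** 2 = 1e-5.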
def to_gpu(x):
x = x.contiguous()
if torch.cuda.is_available():
x = x.cuda(non_blocking=True)
return x
def batch_to_gpu(batch):
text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths, gate_padded = batch
text_padded = to_gpu(text_padded).long()
text_lengths = to_gpu(text_lengths).long()
mel_specgram_padded = to_gpu(mel_specgram_padded).float()
gate_padded = to_gpu(gate_padded).float()
mel_specgram_lengths = to_gpu(mel_specgram_lengths).long()
x = (text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths)
y = (mel_specgram_padded, gate_padded)
return x, y
def training_step(model, train_batch, batch_idx):
(text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths), y = batch_to_gpu(train_batch)
y_pred = model(text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths)
y[0].requires_grad = False
y[1].requires_grad = False
losses = Tacotron2Loss()(y_pred[:3], y)
return losses[0] + losses[1] + losses[2], losses
def validation_step(model, val_batch, batch_idx):
(text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths), y = batch_to_gpu(val_batch)
y_pred = model(text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths)
losses = Tacotron2Loss()(y_pred[:3], y)
return losses[0] + losses[1] + losses[2], losses
def reduce_tensor(tensor, world_size):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
if rt.is_floating_point():
rt = rt / world_size
else:
rt = rt // world_size
return rt
def log_additional_info(writer, model, loader, epoch):
model.eval()
data = next(iter(loader))
with torch.no_grad():
(text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths), _ = batch_to_gpu(data)
y_pred = model(text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths)
mel_out, mel_out_postnet, gate_out, alignment = y_pred
fig = plt.figure()
ax = plt.gca()
ax.imshow(mel_out[0].cpu().numpy())
writer.add_figure("trn/mel_out", fig, epoch)
fig = plt.figure()
ax = plt.gca()
ax.imshow(mel_out_postnet[0].cpu().numpy())
writer.add_figure("trn/mel_out_postnet", fig, epoch)
writer.add_image("trn/gate_out", torch.tile(gate_out[:1], (10, 1)), epoch, dataformats="HW")
writer.add_image("trn/alignment", alignment[0], epoch, dataformats="HW")
def get_datasets(args):
text_preprocessor = partial(
text_to_sequence,
symbol_list=args.text_preprocessor,
phonemizer=args.phonemizer,
checkpoint=args.phonemizer_checkpoint,
cmudict_root=args.cmudict_root,
)
transforms = torch.nn.Sequential(
torchaudio.transforms.MelSpectrogram(
sample_rate=args.sample_rate,
n_fft=args.n_fft,
win_length=args.win_length,
hop_length=args.hop_length,
f_min=args.mel_fmin,
f_max=args.mel_fmax,
n_mels=args.n_mels,
mel_scale='slaney',
normalized=False,
power=1,
norm='slaney',
),
SpectralNormalization()
)
trainset, valset = split_process_dataset(
args.dataset, args.dataset_path, args.val_ratio, transforms, text_preprocessor)
return trainset, valset
def train(rank, world_size, args):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
if rank == 0 and args.logging_dir:
if not os.path.isdir(args.logging_dir):
os.makedirs(args.logging_dir)
filehandler = logging.FileHandler(os.path.join(args.logging_dir, 'train.log'))
filehandler.setLevel(logging.INFO)
logger.addHandler(filehandler)
writer = SummaryWriter(log_dir=args.logging_dir)
else:
writer = None
torch.manual_seed(0)
torch.cuda.set_device(rank)
symbols = get_symbol_list(args.text_preprocessor)
model = Tacotron2(
mask_padding=args.mask_padding,
n_mels=args.n_mels,
n_symbol=len(symbols),
n_frames_per_step=args.n_frames_per_step,
symbol_embedding_dim=args.symbols_embedding_dim,
encoder_embedding_dim=args.encoder_embedding_dim,
encoder_n_convolution=args.encoder_n_convolution,
encoder_kernel_size=args.encoder_kernel_size,
decoder_rnn_dim=args.decoder_rnn_dim,
decoder_max_step=args.decoder_max_step,
decoder_dropout=args.decoder_dropout,
decoder_early_stopping=(not args.decoder_no_early_stopping),
attention_rnn_dim=args.attention_rnn_dim,
attention_hidden_dim=args.attention_hidden_dim,
attention_location_n_filter=args.attention_location_n_filter,
attention_location_kernel_size=args.attention_location_kernel_size,
attention_dropout=args.attention_dropout,
prenet_dim=args.prenet_dim,
postnet_n_convolution=args.postnet_n_convolution,
postnet_kernel_size=args.postnet_kernel_size,
postnet_embedding_dim=args.postnet_embedding_dim,
gate_threshold=args.gate_threshold,
).cuda(rank)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
optimizer = Adam(model.parameters(), lr=args.learning_rate)
best_loss = float("inf")
start_epoch = 0
if args.checkpoint_path and os.path.isfile(args.checkpoint_path):
logger.info(f"Checkpoint: loading '{args.checkpoint_path}'")
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
checkpoint = torch.load(args.checkpoint_path, map_location=map_location)
start_epoch = checkpoint["epoch"]
best_loss = checkpoint["best_loss"]
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
logger.info(
f"Checkpoint: loaded '{args.checkpoint_path}' at epoch {checkpoint['epoch']}"
)
trainset, valset = get_datasets(args)
train_sampler = torch.utils.data.distributed.DistributedSampler(
trainset,
shuffle=True,
num_replicas=world_size,
rank=rank,
)
val_sampler = torch.utils.data.distributed.DistributedSampler(
valset,
shuffle=False,
num_replicas=world_size,
rank=rank,
)
loader_params = {
"batch_size": args.batch_size,
"num_workers": args.workers,
"prefetch_factor": 1024,
'persistent_workers': True,
"shuffle": False,
"pin_memory": True,
"drop_last": False,
"collate_fn": partial(text_mel_collate_fn, n_frames_per_step=args.n_frames_per_step),
}
train_loader = DataLoader(trainset, sampler=train_sampler, **loader_params)
val_loader = DataLoader(valset, sampler=val_sampler, **loader_params)
dist.barrier()
for epoch in range(start_epoch, args.epochs):
start = time()
model.train()
trn_loss, counts = 0, 0
if rank == 0:
iterator = tqdm(enumerate(train_loader), desc=f"Epoch {epoch}", total=len(train_loader))
else:
iterator = enumerate(train_loader)
for i, batch in iterator:
adjust_learning_rate(epoch, optimizer, args.learning_rate,
args.anneal_steps, args.anneal_factor)
model.zero_grad()
loss, losses = training_step(model, batch, i)
loss.backward()
torch.nn.utils.clip_grad_norm_(
model.parameters(), args.grad_clip)
optimizer.step()
if rank == 0 and writer:
global_iters = epoch * len(train_loader) + i
writer.add_scalar("trn/mel_loss", losses[0], global_iters)
writer.add_scalar("trn/mel_postnet_loss", losses[1], global_iters)
writer.add_scalar("trn/gate_loss", losses[2], global_iters)
trn_loss += loss * len(batch[0])
counts += len(batch[0])
trn_loss = trn_loss / counts
trn_loss = reduce_tensor(trn_loss, world_size)
if rank == 0:
logger.info(f"[Epoch: {epoch}] time: {time()-start}; trn_loss: {trn_loss}")
if writer:
writer.add_scalar("trn_loss", trn_loss, epoch)
if ((epoch + 1) % args.validate_and_checkpoint_freq == 0) or (epoch == args.epochs - 1):
val_start_time = time()
model.eval()
val_loss, counts = 0, 0
iterator = tqdm(enumerate(val_loader), desc=f"[Rank: {rank}; Epoch: {epoch}; Eval]", total=len(val_loader))
with torch.no_grad():
for val_batch_idx, val_batch in iterator:
val_loss = val_loss + validation_step(model, val_batch, val_batch_idx)[0] * len(val_batch[0])
counts = counts + len(val_batch[0])
val_loss = val_loss / counts
val_loss = reduce_tensor(val_loss, world_size)
if rank == 0 and writer:
writer.add_scalar("val_loss", val_loss, epoch)
log_additional_info(writer, model, val_loader, epoch)
if rank == 0:
is_best = val_loss < best_loss
best_loss = min(val_loss, best_loss)
logger.info(f"[Rank: {rank}, Epoch: {epoch}; Eval] time: {time()-val_start_time}; val_loss: {val_loss}")
logger.info(f"[Epoch: {epoch}] Saving checkpoint to {args.checkpoint_path}")
save_checkpoint(
{
"epoch": epoch + 1,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
},
is_best,
args.checkpoint_path,
)
dist.destroy_process_group()
def main(args):
logger.info("Start time: {}".format(str(datetime.now())))
torch.manual_seed(0)
random.seed(0)
if args.master_addr is not None:
os.environ['MASTER_ADDR'] = args.master_addr
elif 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR'] = 'localhost'
if args.master_port is not None:
os.environ['MASTER_PORT'] = args.master_port
elif 'MASTER_PORT' not in os.environ:
os.environ['MASTER_PORT'] = '17778'
device_counts = torch.cuda.device_count()
logger.info(f"# available GPUs: {device_counts}")
# download the dataset if it is not already downloaded
if args.dataset == 'ljspeech':
if not os.path.exists(os.path.join(args.dataset_path, 'LJSpeech-1.1')):
from torchaudio.datasets import LJSPEECH
LJSPEECH(root=args.dataset_path, download=True)
if device_counts == 1:
train(0, 1, args)
else:
mp.spawn(train, args=(device_counts, args, ),
nprocs=device_counts, join=True)
logger.info(f"End time: {datetime.now()}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
parser = parse_args(parser)
args, _ = parser.parse_known_args()
main(args)
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import logging
import os
import shutil
from typing import List, Tuple, Callable
import torch
from torch import Tensor
def save_checkpoint(state, is_best, filename):
r"""Save the model to a temporary file first, then copy it to filename,
in case signals interrupt the torch.save() process.
"""
torch.save(state, filename)
logging.info(f"Checkpoint saved to {filename}")
if is_best:
path, best_filename = os.path.split(filename)
best_filename = os.path.join(path, "best_" + best_filename)
shutil.copyfile(filename, best_filename)
logging.info(f"Current best checkpoint saved to {best_filename}")
def pad_sequences(batch: List[Tensor]) -> Tuple[Tensor, Tensor]:
r"""Right zero-pad all one-hot text sequences to max input length.
Modified from https://github.com/NVIDIA/DeepLearningExamples.
"""
input_lengths, ids_sorted_decreasing = torch.sort(
torch.LongTensor([len(x) for x in batch]), dim=0, descending=True)
max_input_len = input_lengths[0]
text_padded = torch.LongTensor(len(batch), max_input_len)
text_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
text = batch[ids_sorted_decreasing[i]]
text_padded[i, :text.size(0)] = text
return text_padded, input_lengths
def prepare_input_sequence(texts: List[str],
text_processor: Callable[[str], List[int]]) -> Tuple[Tensor, Tensor]:
d = []
for text in texts:
d.append(torch.IntTensor(text_processor(text)[:]))
text_padded, input_lengths = pad_sequences(d)
return text_padded, input_lengths
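# A minimal sketch (with a hypothetical character-level text_processor) of how these helpers combine:
#
#   processor = lambda text: [ord(c) for c in text.lower()]
#   sequences, lengths = prepare_input_sequence(["hello", "hi"], text_processor=processor)
#   # sequences: LongTensor of shape (2, 5), zero-padded and sorted by length (longest first)
#   # lengths:   LongTensor([5, 2])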
This is an example pipeline for speech recognition using a greedy or Viterbi CTC decoder, along with the Wav2Letter model trained on LibriSpeech, see [Wav2Letter: an End-to-End ConvNet-based Speech Recognition System](https://arxiv.org/pdf/1609.03193.pdf). Wav2Letter and LibriSpeech are available in torchaudio.
### Usage
More information about each command line parameter is available with the `--help` option. An example can be invoked as follows.
```bash
DATASET_ROOT=<Top>/<level>/<folder>
DATASET_FOLDER_IN_ARCHIVE='LibriSpeech'
python main.py \
--reduce-lr-valid \
--dataset-root "${DATASET_ROOT}" \
--dataset-folder-in-archive "${DATASET_FOLDER_IN_ARCHIVE}" \
--dataset-train train-clean-100 train-clean-360 train-other-500 \
--dataset-valid dev-clean \
--batch-size 128 \
--learning-rate .6 \
--momentum .8 \
--weight-decay .00001 \
--clip-grad 0. \
--gamma .99 \
--hop-length 160 \
--win-length 400 \
--n-bins 13 \
--normalize \
--optimizer adadelta \
--scheduler reduceonplateau \
--epochs 40
```
With these default parameters, we get 13.3% CER and 41.9% WER (character and word error rates, respectively) on dev-clean after 40 epochs while training on train-clean. The tail of the output is the following.
```json
...
{"name": "train", "epoch": 40, "batch char error": 925, "batch char total": 22563, "batch char error rate": 0.040996321411159865, "epoch char error": 1135098.0, "epoch char total": 23857713.0, "epoch char error rate": 0.047577821059378154, "batch word error": 791, "batch word total": 4308, "batch word error rate": 0.18361188486536675, "epoch word error": 942906.0, "epoch word total": 4569507.0, "epoch word error rate": 0.20634742435015418, "lr": 0.06, "batch size": 128, "n_channel": 13, "n_time": 1685, "dataset length": 132096.0, "iteration": 1032.0, "loss": 0.07428030669689178, "cumulative loss": 90.47326805442572, "average loss": 0.08766789540157531, "iteration time": 1.9895553588867188, "epoch time": 2036.8874564170837}
{"name": "train", "epoch": 40, "batch char error": 1131, "batch char total": 24260, "batch char error rate": 0.0466199505358615, "epoch char error": 1136229.0, "epoch char total": 23881973.0, "epoch char error rate": 0.04757684802675223, "batch word error": 957, "batch word total": 4657, "batch word error rate": 0.2054971011380717, "epoch word error": 943863.0, "epoch word total": 4574164.0, "epoch word error rate": 0.20634655862798099, "lr": 0.06, "batch size": 128, "n_channel": 13, "n_time": 1641, "dataset length": 132224.0, "iteration": 1033.0, "loss": 0.08775319904088974, "cumulative loss": 90.5610212534666, "average loss": 0.08766797798012256, "iteration time": 2.108018159866333, "epoch time": 2038.99547457695}
{"name": "train", "epoch": 40, "batch char error": 1099, "batch char total": 23526, "batch char error rate": 0.0467142735696676, "epoch char error": 1137328.0, "epoch char total": 23905499.0, "epoch char error rate": 0.04757599914563591, "batch word error": 936, "batch word total": 4544, "batch word error rate": 0.20598591549295775, "epoch word error": 944799.0, "epoch word total": 4578708.0, "epoch word error rate": 0.20634620071863066, "lr": 0.06, "batch size": 128, "n_channel": 13, "n_time": 1682, "dataset length": 132352.0, "iteration": 1034.0, "loss": 0.0791337713599205, "cumulative loss": 90.64015502482653, "average loss": 0.08765972439538348, "iteration time": 2.0329701900482178, "epoch time": 2041.0284447669983}
{"name": "train", "epoch": 40, "batch char error": 1023, "batch char total": 22399, "batch char error rate": 0.045671681771507655, "epoch char error": 1138351.0, "epoch char total": 23927898.0, "epoch char error rate": 0.04757421650660664, "batch word error": 863, "batch word total": 4318, "batch word error rate": 0.1998610467809171, "epoch word error": 945662.0, "epoch word total": 4583026.0, "epoch word error rate": 0.20634009058643787, "lr": 0.06, "batch size": 128, "n_channel": 13, "n_time": 1644, "dataset length": 132480.0, "iteration": 1035.0, "loss": 0.07874362915754318, "cumulative loss": 90.71889865398407, "average loss": 0.08765110981061262, "iteration time": 1.9106628894805908, "epoch time": 2042.9391076564789}
{"name": "validation", "epoch": 40, "cumulative loss": 12.095281183719635, "dataset length": 2688.0, "iteration": 21.0, "batch char error": 1867, "batch char total": 14792, "batch char error rate": 0.12621687398593834, "epoch char error": 37119.0, "epoch char total": 280923.0, "epoch char error rate": 0.13213229247872194, "batch word error": 1155, "batch word total": 2841, "batch word error rate": 0.4065469904963041, "epoch word error": 22601.0, "epoch word total": 54008.0, "epoch word error rate": 0.418475040734706, "average loss": 0.575965770653316, "validation time": 24.185853481292725}
```
As can be seen in the output above, the information reported at each iteration and epoch (e.g. loss, character error rate, word error rate) is printed to standard output as one JSON object per line. One way to import the output in Python is to save the standard output to a file and then load it with `pandas.read_json(filename, lines=True)`.
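For example, a minimal sketch, assuming the JSON lines were redirected to a file named `train.log` (a hypothetical name) with any non-JSON logging removed:

```python
import pandas as pd

# Each remaining line of the saved output is one JSON record.
log = pd.read_json("train.log", lines=True)

# Separate training records and inspect the running error rates.
train_log = log[log["name"] == "train"]
print(train_log[["epoch", "epoch char error rate", "epoch word error rate"]].tail())
```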
## Structure of pipeline
* `main.py` -- the entry point
* `ctc_decoders.py` -- the greedy CTC decoder
* `datasets.py` -- the function to split and process librispeech, a collate factory function
* `languagemodels.py` -- a class to encode and decode strings
* `metrics.py` -- the levenshtein edit distance
* `utils.py` -- functions to log metrics, save checkpoint, and count parameters
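The decoding path ties `ctc_decoders.py` and `languagemodels.py` together. A minimal sketch with dummy model output (the label set below is an assumption for illustration, not the one built in `main.py`):

```python
import torch
from ctc_decoders import GreedyDecoder
from languagemodels import LanguageModel

# Hypothetical label set: blank, space, apostrophe, then lowercase letters.
labels = "* '" + "abcdefghijklmnopqrstuvwxyz"
language_model = LanguageModel(labels, char_blank="*", char_space=" ")
decoder = GreedyDecoder()

# Dummy network output with shape (batch, time, number of classes).
outputs = torch.randn(2, 50, len(language_model))
indices = decoder(outputs)                        # (batch, time) best class per step
texts = language_model.decode(indices.tolist())   # collapse repeats, drop blanks
print(texts)
```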
from torch import topk
class GreedyDecoder:
def __call__(self, outputs):
"""Greedy Decoder. Returns highest probability of class labels for each timestep
Args:
outputs (torch.Tensor): shape (input length, batch size, number of classes (including blank))
Returns:
torch.Tensor: class labels per time step.
"""
_, indices = topk(outputs, k=1, dim=-1)
return indices[..., 0]
import torch
from torchaudio.datasets import LIBRISPEECH
class MapMemoryCache(torch.utils.data.Dataset):
"""
Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
"""
def __init__(self, dataset):
self.dataset = dataset
self._cache = [None] * len(dataset)
def __getitem__(self, n):
if self._cache[n] is not None:
return self._cache[n]
item = self.dataset[n]
self._cache[n] = item
return item
def __len__(self):
return len(self.dataset)
class Processed(torch.utils.data.Dataset):
def __init__(self, dataset, transforms, encode):
self.dataset = dataset
self.transforms = transforms
self.encode = encode
def __getitem__(self, key):
item = self.dataset[key]
return self.process_datapoint(item)
def __len__(self):
return len(self.dataset)
def process_datapoint(self, item):
transformed = item[0]
target = item[2].lower()
transformed = self.transforms(transformed)
transformed = transformed[0, ...].transpose(0, -1)
target = self.encode(target)
target = torch.tensor(target, dtype=torch.long, device=transformed.device)
return transformed, target
def split_process_librispeech(
datasets, transforms, language_model, root, folder_in_archive,
):
def create(tags, cache=True):
if isinstance(tags, str):
tags = [tags]
if isinstance(transforms, list):
transform_list = transforms
else:
transform_list = [transforms]
data = torch.utils.data.ConcatDataset(
[
Processed(
LIBRISPEECH(
root, tag, folder_in_archive=folder_in_archive, download=False,
),
transform,
language_model.encode,
)
for tag, transform in zip(tags, transform_list)
]
)
data = MapMemoryCache(data)
return data
# For performance, we cache all datasets
return tuple(create(dataset) for dataset in datasets)
def collate_factory(model_length_function, transforms=None):
if transforms is None:
transforms = torch.nn.Sequential()
def collate_fn(batch):
tensors = [transforms(b[0]) for b in batch if b]
tensors_lengths = torch.tensor(
[model_length_function(t) for t in tensors],
dtype=torch.long,
device=tensors[0].device,
)
tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
tensors = tensors.transpose(1, -1)
targets = [b[1] for b in batch if b]
target_lengths = torch.tensor(
[target.shape[0] for target in targets],
dtype=torch.long,
device=tensors.device,
)
targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
return tensors, targets, tensors_lengths, target_lengths
return collate_fn
import collections
import itertools
class LanguageModel:
def __init__(self, labels, char_blank, char_space):
self.char_space = char_space
self.char_blank = char_blank
labels = list(labels)
self.length = len(labels)
enumerated = list(enumerate(labels))
flipped = [(sub[1], sub[0]) for sub in enumerated]
d1 = collections.OrderedDict(enumerated)
d2 = collections.OrderedDict(flipped)
self.mapping = {**d1, **d2}
def encode(self, iterable):
if isinstance(iterable, list):
return [self.encode(i) for i in iterable]
else:
return [self.mapping[i] + self.mapping[self.char_blank] for i in iterable]
def decode(self, tensor):
if len(tensor) > 0 and isinstance(tensor[0], list):
return [self.decode(t) for t in tensor]
else:
# not idempotent: repeated characters are collapsed and blanks removed
x = (self.mapping[i] for i in tensor)
x = "".join(i for i, _ in itertools.groupby(x))
x = x.replace(self.char_blank, "")
# x = x.strip()
return x
def __len__(self):
return self.length
import argparse
import logging
import os
import string
from datetime import datetime
from time import time
import torch
import torchaudio
from torch.optim import SGD, Adadelta, Adam, AdamW
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.functional import edit_distance
from torchaudio.models.wav2letter import Wav2Letter
from ctc_decoders import GreedyDecoder
from datasets import collate_factory, split_process_librispeech
from languagemodels import LanguageModel
from transforms import Normalize, UnsqueezeFirst
from utils import MetricLogger, count_parameters, save_checkpoint
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--type",
metavar="T",
default="mfcc",
choices=["waveform", "mfcc"],
help="input type for model",
)
parser.add_argument(
"--freq-mask",
default=0,
type=int,
metavar="N",
help="maximal width of frequency mask",
)
parser.add_argument(
"--win-length",
default=400,
type=int,
metavar="N",
help="width of spectrogram window",
)
parser.add_argument(
"--hop-length",
default=160,
type=int,
metavar="N",
help="width of spectrogram window",
)
parser.add_argument(
"--time-mask",
default=0,
type=int,
metavar="N",
help="maximal width of time mask",
)
parser.add_argument(
"--workers",
default=0,
type=int,
metavar="N",
help="number of data loading workers",
)
parser.add_argument(
"--checkpoint",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint",
)
parser.add_argument(
"--epochs",
default=200,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch", default=0, type=int, metavar="N", help="manual epoch number"
)
parser.add_argument(
"--reduce-lr-valid",
action="store_true",
help="reduce learning rate based on validation loss",
)
parser.add_argument(
"--normalize", action="store_true", help="normalize model input"
)
parser.add_argument(
"--progress-bar", action="store_true", help="use progress bar while training"
)
parser.add_argument(
"--decoder",
metavar="D",
default="greedy",
choices=["greedy"],
help="decoder to use",
)
parser.add_argument(
"--batch-size", default=128, type=int, metavar="N", help="mini-batch size"
)
parser.add_argument(
"--n-bins",
default=13,
type=int,
metavar="N",
help="number of bins in transforms",
)
parser.add_argument(
"--optimizer",
metavar="OPT",
default="adadelta",
choices=["sgd", "adadelta", "adam", "adamw"],
help="optimizer to use",
)
parser.add_argument(
"--scheduler",
metavar="S",
default="reduceonplateau",
choices=["exponential", "reduceonplateau"],
help="optimizer to use",
)
parser.add_argument(
"--learning-rate",
default=0.6,
type=float,
metavar="LR",
help="initial learning rate",
)
parser.add_argument(
"--gamma",
default=0.99,
type=float,
metavar="GAMMA",
help="learning rate exponential decay constant",
)
parser.add_argument(
"--momentum", default=0.8, type=float, metavar="M", help="momentum"
)
parser.add_argument(
"--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay"
)
parser.add_argument("--eps", metavar="EPS", type=float, default=1e-8)
parser.add_argument("--rho", metavar="RHO", type=float, default=0.95)
parser.add_argument("--clip-grad", metavar="NORM", type=float, default=0.0)
parser.add_argument(
"--dataset-root",
type=str,
help="specify dataset root folder",
)
parser.add_argument(
"--dataset-folder-in-archive",
type=str,
help="specify dataset folder in archive",
)
parser.add_argument(
"--dataset-train",
default=["train-clean-100"],
nargs="+",
type=str,
help="select which part of librispeech to train with",
)
parser.add_argument(
"--dataset-valid",
default=["dev-clean"],
nargs="+",
type=str,
help="select which part of librispeech to validate with",
)
parser.add_argument(
"--distributed", action="store_true", help="enable DistributedDataParallel"
)
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument(
"--world-size", type=int, default=8, help="the world size to initiate DPP"
)
parser.add_argument("--jit", action="store_true", help="if used, model is jitted")
args = parser.parse_args()
logging.info(args)
return args
def setup_distributed(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
# initialize the process group
torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
def model_length_function(tensor):
if tensor.shape[1] == 1:
# waveform mode
return int(tensor.shape[0]) // 160 // 2 + 1
return int(tensor.shape[0]) // 2 + 1
def compute_error_rates(outputs, targets, decoder, language_model, metric):
output = outputs.transpose(0, 1).to("cpu")
output = decoder(output)
# Compute CER
output = language_model.decode(output.tolist())
target = language_model.decode(targets.tolist())
print_length = 20
for i in range(2):
# Print a few examples
output_print = output[i].ljust(print_length)[:print_length]
target_print = target[i].ljust(print_length)[:print_length]
logging.info("Target: %s Output: %s", target_print, output_print)
cers = [edit_distance(t, o) for t, o in zip(target, output)]
cers = sum(cers)
n = sum(len(t) for t in target)
metric["batch char error"] = cers
metric["batch char total"] = n
metric["batch char error rate"] = cers / n
metric["epoch char error"] += cers
metric["epoch char total"] += n
metric["epoch char error rate"] = metric["epoch char error"] / metric["epoch char total"]
# Compute WER
output = [o.split(language_model.char_space) for o in output]
target = [t.split(language_model.char_space) for t in target]
wers = [edit_distance(t, o) for t, o in zip(target, output)]
wers = sum(wers)
n = sum(len(t) for t in target)
metric["batch word error"] = wers
metric["batch word total"] = n
metric["batch word error rate"] = wers / n
metric["epoch word error"] += wers
metric["epoch word total"] += n
metric["epoch word error rate"] = metric["epoch word error"] / metric["epoch word total"]
def train_one_epoch(
model,
criterion,
optimizer,
scheduler,
data_loader,
decoder,
language_model,
device,
epoch,
clip_grad,
disable_logger=False,
reduce_lr_on_plateau=False,
):
model.train()
metric = MetricLogger("train", disable=disable_logger)
metric["epoch"] = epoch
for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
data_loader, maxsize=2
):
start = time()
inputs = inputs.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
# keep batch first for data parallel
outputs = model(inputs).transpose(-1, -2).transpose(0, 1)
# CTC
# outputs: input length, batch size, number of classes (including blank)
# targets: batch size, max target length
# input_lengths: batch size
# target_lengths: batch size
loss = criterion(outputs, targets, tensors_lengths, target_lengths)
optimizer.zero_grad()
loss.backward()
        if clip_grad > 0:
            # clip_grad_norm_ returns a tensor; store a plain float so MetricLogger stays JSON-serializable
            metric["gradient"] = torch.nn.utils.clip_grad_norm_(
                model.parameters(), clip_grad
            ).item()
optimizer.step()
compute_error_rates(outputs, targets, decoder, language_model, metric)
try:
metric["lr"] = scheduler.get_last_lr()[0]
except AttributeError:
metric["lr"] = optimizer.param_groups[0]["lr"]
metric["batch size"] = len(inputs)
metric["n_channel"] = inputs.shape[1]
metric["n_time"] = inputs.shape[-1]
metric["dataset length"] += metric["batch size"]
metric["iteration"] += 1
metric["loss"] = loss.item()
metric["cumulative loss"] += metric["loss"]
metric["average loss"] = metric["cumulative loss"] / metric["iteration"]
metric["iteration time"] = time() - start
metric["epoch time"] += metric["iteration time"]
metric()
if reduce_lr_on_plateau and isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(metric["average loss"])
elif not isinstance(scheduler, ReduceLROnPlateau):
scheduler.step()
def evaluate(
model,
criterion,
data_loader,
decoder,
language_model,
device,
epoch,
disable_logger=False,
):
with torch.no_grad():
model.eval()
start = time()
metric = MetricLogger("validation", disable=disable_logger)
metric["epoch"] = epoch
for inputs, targets, tensors_lengths, target_lengths in bg_iterator(
data_loader, maxsize=2
):
inputs = inputs.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
# keep batch first for data parallel
outputs = model(inputs).transpose(-1, -2).transpose(0, 1)
# CTC
# outputs: input length, batch size, number of classes (including blank)
# targets: batch size, max target length
# input_lengths: batch size
# target_lengths: batch size
metric["cumulative loss"] += criterion(
outputs, targets, tensors_lengths, target_lengths
).item()
metric["dataset length"] += len(inputs)
metric["iteration"] += 1
compute_error_rates(outputs, targets, decoder, language_model, metric)
metric["average loss"] = metric["cumulative loss"] / metric["iteration"]
metric["validation time"] = time() - start
metric()
return metric["average loss"]
def main(rank, args):
# Distributed setup
if args.distributed:
setup_distributed(rank, args.world_size)
not_main_rank = args.distributed and rank != 0
logging.info("Start time: %s", datetime.now())
# Explicitly set seed to make sure models created in separate processes
# start from same random weights and biases
torch.manual_seed(args.seed)
# Empty CUDA cache
torch.cuda.empty_cache()
# Change backend for flac files
torchaudio.set_audio_backend("soundfile")
# Transforms
melkwargs = {
"n_fft": args.win_length,
"n_mels": args.n_bins,
"hop_length": args.hop_length,
}
sample_rate_original = 16000
if args.type == "mfcc":
transforms = torch.nn.Sequential(
torchaudio.transforms.MFCC(
sample_rate=sample_rate_original,
n_mfcc=args.n_bins,
melkwargs=melkwargs,
),
)
num_features = args.n_bins
elif args.type == "waveform":
transforms = torch.nn.Sequential(UnsqueezeFirst())
num_features = 1
else:
raise ValueError("Model type not supported")
if args.normalize:
transforms = torch.nn.Sequential(transforms, Normalize())
augmentations = torch.nn.Sequential()
if args.freq_mask:
augmentations = torch.nn.Sequential(
augmentations,
torchaudio.transforms.FrequencyMasking(freq_mask_param=args.freq_mask),
)
if args.time_mask:
augmentations = torch.nn.Sequential(
augmentations,
torchaudio.transforms.TimeMasking(time_mask_param=args.time_mask),
)
# Text preprocessing
char_blank = "*"
char_space = " "
char_apostrophe = "'"
labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
language_model = LanguageModel(labels, char_blank, char_space)
# Dataset
training, validation = split_process_librispeech(
[args.dataset_train, args.dataset_valid],
[transforms, transforms],
language_model,
root=args.dataset_root,
folder_in_archive=args.dataset_folder_in_archive,
)
# Decoder
if args.decoder == "greedy":
decoder = GreedyDecoder()
else:
raise ValueError("Selected decoder not supported")
# Model
model = Wav2Letter(
num_classes=language_model.length,
input_type=args.type,
num_features=num_features,
)
if args.jit:
model = torch.jit.script(model)
if args.distributed:
n = torch.cuda.device_count() // args.world_size
devices = list(range(rank * n, (rank + 1) * n))
model = model.to(devices[0])
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices)
else:
devices = ["cuda" if torch.cuda.is_available() else "cpu"]
model = model.to(devices[0], non_blocking=True)
model = torch.nn.DataParallel(model)
n = count_parameters(model)
logging.info("Number of parameters: %s", n)
# Optimizer
if args.optimizer == "adadelta":
optimizer = Adadelta(
model.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay,
eps=args.eps,
rho=args.rho,
)
elif args.optimizer == "sgd":
optimizer = SGD(
model.parameters(),
lr=args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optimizer == "adam":
optimizer = Adam(
model.parameters(),
lr=args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
elif args.optimizer == "adamw":
optimizer = AdamW(
model.parameters(),
lr=args.learning_rate,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
else:
raise ValueError("Selected optimizer not supported")
if args.scheduler == "exponential":
scheduler = ExponentialLR(optimizer, gamma=args.gamma)
elif args.scheduler == "reduceonplateau":
scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3)
else:
raise ValueError("Selected scheduler not supported")
criterion = torch.nn.CTCLoss(
blank=language_model.mapping[char_blank], zero_infinity=False
)
# Data Loader
collate_fn_train = collate_factory(model_length_function, augmentations)
collate_fn_valid = collate_factory(model_length_function)
loader_training_params = {
"num_workers": args.workers,
"pin_memory": True,
"shuffle": True,
"drop_last": True,
}
loader_validation_params = loader_training_params.copy()
loader_validation_params["shuffle"] = False
loader_training = DataLoader(
training,
batch_size=args.batch_size,
collate_fn=collate_fn_train,
**loader_training_params,
)
loader_validation = DataLoader(
validation,
batch_size=args.batch_size,
collate_fn=collate_fn_valid,
**loader_validation_params,
)
# Setup checkpoint
best_loss = 1.0
load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint)
if args.distributed:
torch.distributed.barrier()
if load_checkpoint:
logging.info("Checkpoint: loading %s", args.checkpoint)
checkpoint = torch.load(args.checkpoint)
args.start_epoch = checkpoint["epoch"]
best_loss = checkpoint["best_loss"]
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
logging.info(
"Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
)
else:
logging.info("Checkpoint: not found")
save_checkpoint(
{
"epoch": args.start_epoch,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
},
False,
args.checkpoint,
not_main_rank,
)
if args.distributed:
torch.distributed.barrier()
torch.autograd.set_detect_anomaly(False)
for epoch in range(args.start_epoch, args.epochs):
logging.info("Epoch: %s", epoch)
train_one_epoch(
model,
criterion,
optimizer,
scheduler,
loader_training,
decoder,
language_model,
devices[0],
epoch,
args.clip_grad,
not_main_rank,
not args.reduce_lr_valid,
)
loss = evaluate(
model,
criterion,
loader_validation,
decoder,
language_model,
devices[0],
epoch,
not_main_rank,
)
if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(loss)
is_best = loss < best_loss
best_loss = min(loss, best_loss)
save_checkpoint(
{
"epoch": epoch + 1,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
},
is_best,
args.checkpoint,
not_main_rank,
)
logging.info("End time: %s", datetime.now())
if args.distributed:
torch.distributed.destroy_process_group()
def spawn_main(main, args):
if args.distributed:
torch.multiprocessing.spawn(
main, args=(args,), nprocs=args.world_size, join=True
)
else:
main(0, args)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
args = parse_args()
spawn_main(main, args)
import torch
class Normalize(torch.nn.Module):
def forward(self, tensor):
return (tensor - tensor.mean(-1, keepdim=True)) / tensor.std(-1, keepdim=True)
class UnsqueezeFirst(torch.nn.Module):
def forward(self, tensor):
return tensor.unsqueeze(0)
import json
import logging
import os
import shutil
from collections import defaultdict
import torch
class MetricLogger(defaultdict):
def __init__(self, name, print_freq=1, disable=False):
super().__init__(lambda: 0.0)
self.disable = disable
self.print_freq = print_freq
self._iter = 0
self["name"] = name
def __str__(self):
return json.dumps(self)
def __call__(self):
self._iter = (self._iter + 1) % self.print_freq
if not self.disable and not self._iter:
print(self, flush=True)
def save_checkpoint(state, is_best, filename, disable):
"""
    Save the model to a temporary file first,
    then rename it to filename, so that an interrupted
    torch.save() cannot corrupt an existing checkpoint.
"""
if disable:
return
if filename == "":
return
tempfile = filename + ".temp"
    # Remove any leftover tempfile in case a previous run was interrupted while renaming it to filename
if os.path.isfile(tempfile):
os.remove(tempfile)
torch.save(state, tempfile)
if os.path.isfile(tempfile):
os.rename(tempfile, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")
logging.warning("Checkpoint: saved")
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
This is an example vocoder pipeline using the WaveRNN model trained with LJSpeech. The WaveRNN model is based on the implementation from [this repository](https://github.com/fatchord/WaveRNN). The original implementation was
introduced in "Efficient Neural Audio Synthesis". Both the WaveRNN model and the LJSpeech dataset are available in torchaudio.
### Usage
An example can be invoked as follows.
```
python main.py \
--batch-size 256 \
--learning-rate 1e-4 \
--n-freq 80 \
--loss 'crossentropy' \
    --n-bits 8
```
For inference, an example can be invoked as follows.
Please refer to the [documentation](https://pytorch.org/audio/master/models.html#id10) for
available checkpoints.
```
python inference.py \
--checkpoint-name wavernn_10k_epochs_8bits_ljspeech \
--output-wav-path ./output.wav
```
This example would generate a file named `output.wav` in the current working directory.
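To sanity-check the result, the generated file can be loaded back with torchaudio. This is a minimal sketch assuming the default output path and the LJSpeech sample rate of 22050 Hz:
```python
import torchaudio

waveform, sample_rate = torchaudio.load("./output.wav")
print(waveform.shape, sample_rate)  # expected: a (1, num_samples) tensor and 22050
```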
### Output
The information reported at each iteration and epoch (e.g. loss) is printed to standard output in the form of one JSON object per line. Here is an example Python function to parse the output if it has been redirected to a file.
```python
def read_json(filename):
"""
Convert the standard output saved to filename into a pandas dataframe for analysis.
"""
import pandas
import json
with open(filename, "r") as f:
data = f.read()
    # json does not accept single-quoted strings
data = data.replace("'", '"')
data = [json.loads(l) for l in data.splitlines()]
return pandas.DataFrame(data)
```
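For example, assuming the training output was redirected to a file named `train.log` (a hypothetical name), the per-epoch metrics can be pulled out of the resulting dataframe:
```python
df = read_json("train.log")
# each MetricLogger line carries its logger name, e.g. "train_epoch" or "validation"
print(df[df["name"] == "train_epoch"][["epoch", "loss"]].tail())
```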
import random
import torch
from torch.utils.data.dataset import random_split
from torchaudio.datasets import LJSPEECH, LIBRITTS
from torchaudio.transforms import MuLawEncoding
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
class MapMemoryCache(torch.utils.data.Dataset):
r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
"""
def __init__(self, dataset):
self.dataset = dataset
self._cache = [None] * len(dataset)
def __getitem__(self, n):
if self._cache[n] is not None:
return self._cache[n]
item = self.dataset[n]
self._cache[n] = item
return item
def __len__(self):
return len(self.dataset)
class Processed(torch.utils.data.Dataset):
def __init__(self, dataset, transforms):
self.dataset = dataset
self.transforms = transforms
def __getitem__(self, key):
item = self.dataset[key]
return self.process_datapoint(item)
def __len__(self):
return len(self.dataset)
def process_datapoint(self, item):
specgram = self.transforms(item[0])
return item[0].squeeze(0), specgram
def split_process_dataset(args, transforms):
if args.dataset == 'ljspeech':
data = LJSPEECH(root=args.file_path, download=False)
val_length = int(len(data) * args.val_ratio)
lengths = [len(data) - val_length, val_length]
train_dataset, val_dataset = random_split(data, lengths)
elif args.dataset == 'libritts':
train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False)
val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False)
else:
raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}")
train_dataset = Processed(train_dataset, transforms)
val_dataset = Processed(val_dataset, transforms)
train_dataset = MapMemoryCache(train_dataset)
val_dataset = MapMemoryCache(val_dataset)
return train_dataset, val_dataset
def collate_factory(args):
def raw_collate(batch):
pad = (args.kernel_size - 1) // 2
# input waveform length
wave_length = args.hop_length * args.seq_len_factor
# input spectrogram length
spec_length = args.seq_len_factor + pad * 2
        # max start position in spectrogram
        max_offsets = [x[1].shape[-1] - (spec_length + pad * 2) for x in batch]
        # random start position in spectrogram
        spec_offsets = [random.randint(0, offset) for offset in max_offsets]
        # random start position in waveform
wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets]
waveform_combine = [
x[0][wave_offsets[i]: wave_offsets[i] + wave_length + 1]
for i, x in enumerate(batch)
]
specgram = [
x[1][:, spec_offsets[i]: spec_offsets[i] + spec_length]
for i, x in enumerate(batch)
]
specgram = torch.stack(specgram)
waveform_combine = torch.stack(waveform_combine)
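        # waveform_combine holds wave_length + 1 samples: the model input (waveform) and the
        # next-sample prediction target are the same crop shifted by one sample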
waveform = waveform_combine[:, :wave_length]
target = waveform_combine[:, 1:]
# waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'crossentropy'
if args.loss == "crossentropy":
if args.mulaw:
mulaw_encode = MuLawEncoding(2 ** args.n_bits)
waveform = mulaw_encode(waveform)
target = mulaw_encode(target)
waveform = bits_to_normalized_waveform(waveform, args.n_bits)
else:
target = normalized_waveform_to_bits(target, args.n_bits)
return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1)
return raw_collate
import argparse
import torch
import torchaudio
from torchaudio.transforms import MelSpectrogram
from torchaudio.models import wavernn
from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS
from torchaudio.datasets import LJSPEECH
from wavernn_inference_wrapper import WaveRNNInferenceWrapper
from processing import NormalizeDB
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--output-wav-path", default="./output.wav", type=str, metavar="PATH",
help="The path to output the reconstructed wav file.",
)
parser.add_argument(
"--jit", default=False, action="store_true",
help="If used, the model and inference function is jitted."
)
parser.add_argument(
"--no-batch-inference", default=False, action="store_true",
help="Don't use batch inference."
)
parser.add_argument(
"--no-mulaw", default=False, action="store_true",
help="Don't use mulaw decoder to decoder the signal."
)
parser.add_argument(
"--checkpoint-name", default="wavernn_10k_epochs_8bits_ljspeech",
choices=list(_MODEL_CONFIG_AND_URLS.keys()),
help="Select the WaveRNN checkpoint."
)
parser.add_argument(
"--batch-timesteps", default=100, type=int,
help="The time steps for each batch. Only used when batch inference is used",
)
parser.add_argument(
"--batch-overlap", default=5, type=int,
help="The overlapping time steps between batches. Only used when batch inference is used",
)
args = parser.parse_args()
return args
def main(args):
device = "cuda" if torch.cuda.is_available() else "cpu"
waveform, sample_rate, _, _ = LJSPEECH("./", download=True)[0]
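    # Note: these spectrogram parameters mirror the training defaults and must match
    # the configuration the selected WaveRNN checkpoint was trained with.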
mel_kwargs = {
'sample_rate': sample_rate,
'n_fft': 2048,
'f_min': 40.,
'n_mels': 80,
'win_length': 1100,
'hop_length': 275,
'mel_scale': 'slaney',
'norm': 'slaney',
'power': 1,
}
transforms = torch.nn.Sequential(
MelSpectrogram(**mel_kwargs),
NormalizeDB(min_level_db=-100, normalization=True),
)
mel_specgram = transforms(waveform)
wavernn_model = wavernn(args.checkpoint_name).eval().to(device)
wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)
if args.jit:
wavernn_inference_model = torch.jit.script(wavernn_inference_model)
with torch.no_grad():
output = wavernn_inference_model(mel_specgram.to(device),
mulaw=(not args.no_mulaw),
batched=(not args.no_batch_inference),
timesteps=args.batch_timesteps,
overlap=args.batch_overlap,)
torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate)
if __name__ == "__main__":
args = parse_args()
main(args)
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
class LongCrossEntropyLoss(nn.Module):
r""" CrossEntropy loss
"""
def __init__(self):
super(LongCrossEntropyLoss, self).__init__()
def forward(self, output, target):
output = output.transpose(1, 2)
target = target.long()
criterion = nn.CrossEntropyLoss()
return criterion(output, target)
class MoLLoss(nn.Module):
r""" Discretized mixture of logistic distributions loss
Adapted from wavenet vocoder
(https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py)
Explanation of loss (https://github.com/Rayhane-mamah/Tacotron-2/issues/155)
Args:
y_hat (Tensor): Predicted output (n_batch x n_time x n_channel)
y (Tensor): Target (n_batch x n_time x 1)
num_classes (int): Number of classes
log_scale_min (float): Log scale minimum value
reduce (bool): If True, the losses are averaged or summed for each minibatch
Returns
Tensor: loss
"""
def __init__(self, num_classes=65536, log_scale_min=None, reduce=True):
super(MoLLoss, self).__init__()
self.num_classes = num_classes
self.log_scale_min = log_scale_min
self.reduce = reduce
def forward(self, y_hat, y):
y = y.unsqueeze(-1)
if self.log_scale_min is None:
self.log_scale_min = math.log(1e-14)
assert y_hat.dim() == 3
assert y_hat.size(-1) % 3 == 0
nr_mix = y_hat.size(-1) // 3
# unpack parameters (n_batch, n_time, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix: 2 * nr_mix]
log_scales = torch.clamp(
y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min
)
# (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures)
y = y.expand_as(means)
centered_y = y - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_y + 1.0 / (self.num_classes - 1))
cdf_plus = torch.sigmoid(plus_in)
min_in = inv_stdv * (centered_y - 1.0 / (self.num_classes - 1))
cdf_min = torch.sigmoid(min_in)
# log probability for edge case of 0 (before scaling)
# equivalent: torch.log(F.sigmoid(plus_in))
log_cdf_plus = plus_in - F.softplus(plus_in)
# log probability for edge case of 255 (before scaling)
# equivalent: (1 - F.sigmoid(min_in)).log()
log_one_minus_cdf_min = -F.softplus(min_in)
# probability for all other cases
cdf_delta = cdf_plus - cdf_min
mid_in = inv_stdv * centered_y
# log probability in the center of the bin, to be used in extreme cases
log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
inner_inner_cond = (cdf_delta > 1e-5).float()
inner_inner_out = inner_inner_cond * torch.log(
torch.clamp(cdf_delta, min=1e-12)
) + (1.0 - inner_inner_cond) * (
log_pdf_mid - math.log((self.num_classes - 1) / 2)
)
inner_cond = (y > 0.999).float()
inner_out = (
inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
)
cond = (y < -0.999).float()
log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
log_probs = log_probs + F.log_softmax(logit_probs, -1)
if self.reduce:
return -torch.mean(_log_sum_exp(log_probs))
else:
return -_log_sum_exp(log_probs).unsqueeze(-1)
def _log_sum_exp(x):
r""" Numerically stable log_sum_exp implementation that prevents overflow
"""
axis = len(x.size()) - 1
m, _ = torch.max(x, dim=axis)
m2, _ = torch.max(x, dim=axis, keepdim=True)
return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
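# Shape note (illustrative, not part of the original file): with K logistic components,
# MoLLoss expects y_hat of shape (n_batch, n_time, 3 * K) -- mixture logits, means and
# log-scales concatenated along the last dimension -- and a target y of shape
# (n_batch, n_time) with values roughly in [-1, 1]; with reduce=True the forward pass
# returns a scalar negative log-likelihood.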
import argparse
import logging
import os
from collections import defaultdict
from datetime import datetime
from time import time
from typing import List
import torch
import torchaudio
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.models.wavernn import WaveRNN
from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--workers",
default=4,
type=int,
metavar="N",
help="number of data loading workers",
)
parser.add_argument(
"--checkpoint",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint",
)
parser.add_argument(
"--epochs",
default=8000,
type=int,
metavar="N",
help="number of total epochs to run",
)
parser.add_argument(
"--start-epoch", default=0, type=int, metavar="N", help="manual epoch number"
)
parser.add_argument(
"--print-freq",
default=10,
type=int,
metavar="N",
help="print frequency in epochs",
)
parser.add_argument(
"--dataset",
default="ljspeech",
choices=["ljspeech", "libritts"],
type=str,
help="select dataset to train with",
)
parser.add_argument(
"--batch-size", default=256, type=int, metavar="N", help="mini-batch size"
)
parser.add_argument(
"--learning-rate", default=1e-4, type=float, metavar="LR", help="learning rate",
)
parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0)
parser.add_argument(
"--mulaw",
default=True,
action="store_true",
help="if used, waveform is mulaw encoded",
)
parser.add_argument(
"--jit", default=False, action="store_true", help="if used, model is jitted"
)
    parser.add_argument(
        "--upsample-scales",
        default=[5, 5, 11],
        nargs="+",
        type=int,
        help="the list of upsample scales",
    )
parser.add_argument(
"--n-bits", default=8, type=int, help="the bits of output waveform",
)
parser.add_argument(
"--sample-rate",
default=22050,
type=int,
help="the rate of audio dimensions (samples per second)",
)
parser.add_argument(
"--hop-length",
default=275,
type=int,
help="the number of samples between the starts of consecutive frames",
)
parser.add_argument(
"--win-length", default=1100, type=int, help="the length of the STFT window",
)
parser.add_argument(
"--f-min", default=40.0, type=float, help="the minimum frequency",
)
parser.add_argument(
"--min-level-db",
default=-100,
type=float,
help="the minimum db value for spectrogam normalization",
)
parser.add_argument(
"--n-res-block", default=10, type=int, help="the number of ResBlock in stack",
)
parser.add_argument(
"--n-rnn", default=512, type=int, help="the dimension of RNN layer",
)
parser.add_argument(
"--n-fc", default=512, type=int, help="the dimension of fully connected layer",
)
parser.add_argument(
"--kernel-size",
default=5,
type=int,
help="the number of kernel size in the first Conv1d layer",
)
parser.add_argument(
"--n-freq", default=80, type=int, help="the number of spectrogram bins to use",
)
parser.add_argument(
"--n-hidden-melresnet",
default=128,
type=int,
help="the number of hidden dimensions of resblock in melresnet",
)
parser.add_argument(
"--n-output-melresnet", default=128, type=int, help="the output dimension of melresnet",
)
parser.add_argument(
"--n-fft", default=2048, type=int, help="the number of Fourier bins",
)
parser.add_argument(
"--loss",
default="crossentropy",
choices=["crossentropy", "mol"],
type=str,
help="the type of loss",
)
parser.add_argument(
"--seq-len-factor",
default=5,
type=int,
help="the length of each waveform to process per batch = hop_length * seq_len_factor",
)
parser.add_argument(
"--val-ratio",
default=0.1,
type=float,
help="the ratio of waveforms for validation",
)
parser.add_argument(
"--file-path", default="", type=str, help="the path of audio files",
)
parser.add_argument(
"--normalization", default=True, action="store_true", help="if True, spectrogram is normalized",
)
args = parser.parse_args()
return args
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):
model.train()
sums = defaultdict(lambda: 0.0)
start1 = time()
metric = MetricLogger("train_iteration")
metric["epoch"] = epoch
for waveform, specgram, target in bg_iterator(data_loader, maxsize=2):
start2 = time()
waveform = waveform.to(device)
specgram = specgram.to(device)
target = target.to(device)
output = model(waveform, specgram)
output, target = output.squeeze(1), target.squeeze(1)
loss = criterion(output, target)
loss_item = loss.item()
sums["loss"] += loss_item
metric["loss"] = loss_item
optimizer.zero_grad()
loss.backward()
if args.clip_grad > 0:
gradient = torch.nn.utils.clip_grad_norm_(
model.parameters(), args.clip_grad
)
sums["gradient"] += gradient.item()
metric["gradient"] = gradient.item()
optimizer.step()
metric["iteration"] = sums["iteration"]
metric["time"] = time() - start2
metric()
sums["iteration"] += 1
avg_loss = sums["loss"] / len(data_loader)
metric = MetricLogger("train_epoch")
metric["epoch"] = epoch
metric["loss"] = sums["loss"] / len(data_loader)
metric["gradient"] = avg_loss
metric["time"] = time() - start1
metric()
def validate(model, criterion, data_loader, device, epoch):
with torch.no_grad():
model.eval()
sums = defaultdict(lambda: 0.0)
start = time()
for waveform, specgram, target in bg_iterator(data_loader, maxsize=2):
waveform = waveform.to(device)
specgram = specgram.to(device)
target = target.to(device)
output = model(waveform, specgram)
output, target = output.squeeze(1), target.squeeze(1)
loss = criterion(output, target)
sums["loss"] += loss.item()
avg_loss = sums["loss"] / len(data_loader)
metric = MetricLogger("validation")
metric["epoch"] = epoch
metric["loss"] = avg_loss
metric["time"] = time() - start
metric()
return avg_loss
def main(args):
devices = ["cuda" if torch.cuda.is_available() else "cpu"]
logging.info("Start time: {}".format(str(datetime.now())))
melkwargs = {
"n_fft": args.n_fft,
"power": 1,
"hop_length": args.hop_length,
"win_length": args.win_length,
}
transforms = torch.nn.Sequential(
torchaudio.transforms.MelSpectrogram(
sample_rate=args.sample_rate,
n_mels=args.n_freq,
f_min=args.f_min,
mel_scale='slaney',
norm='slaney',
**melkwargs,
),
NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization),
)
train_dataset, val_dataset = split_process_dataset(args, transforms)
loader_training_params = {
"num_workers": args.workers,
"pin_memory": False,
"shuffle": True,
"drop_last": False,
}
loader_validation_params = loader_training_params.copy()
loader_validation_params["shuffle"] = False
collate_fn = collate_factory(args)
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
collate_fn=collate_fn,
**loader_training_params,
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
collate_fn=collate_fn,
**loader_validation_params,
)
n_classes = 2 ** args.n_bits if args.loss == "crossentropy" else 30
model = WaveRNN(
upsample_scales=args.upsample_scales,
n_classes=n_classes,
hop_length=args.hop_length,
n_res_block=args.n_res_block,
n_rnn=args.n_rnn,
n_fc=args.n_fc,
kernel_size=args.kernel_size,
n_freq=args.n_freq,
n_hidden=args.n_hidden_melresnet,
n_output=args.n_output_melresnet,
)
if args.jit:
model = torch.jit.script(model)
model = torch.nn.DataParallel(model)
model = model.to(devices[0], non_blocking=True)
n = count_parameters(model)
logging.info(f"Number of parameters: {n}")
# Optimizer
optimizer_params = {
"lr": args.learning_rate,
}
optimizer = Adam(model.parameters(), **optimizer_params)
criterion = LongCrossEntropyLoss() if args.loss == "crossentropy" else MoLLoss()
best_loss = 10.0
if args.checkpoint and os.path.isfile(args.checkpoint):
logging.info(f"Checkpoint: loading '{args.checkpoint}'")
checkpoint = torch.load(args.checkpoint)
args.start_epoch = checkpoint["epoch"]
best_loss = checkpoint["best_loss"]
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
logging.info(
f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}"
)
else:
logging.info("Checkpoint: not found")
save_checkpoint(
{
"epoch": args.start_epoch,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
},
False,
args.checkpoint,
)
for epoch in range(args.start_epoch, args.epochs):
train_one_epoch(
model, criterion, optimizer, train_loader, devices[0], epoch,
)
if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:
sum_loss = validate(model, criterion, val_loader, devices[0], epoch)
is_best = sum_loss < best_loss
best_loss = min(sum_loss, best_loss)
save_checkpoint(
{
"epoch": epoch + 1,
"state_dict": model.state_dict(),
"best_loss": best_loss,
"optimizer": optimizer.state_dict(),
},
is_best,
args.checkpoint,
)
logging.info(f"End time: {datetime.now()}")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
args = parse_args()
main(args)
import torch
import torch.nn as nn
class NormalizeDB(nn.Module):
r"""Normalize the spectrogram with a minimum db value
"""
def __init__(self, min_level_db, normalization):
super().__init__()
self.min_level_db = min_level_db
self.normalization = normalization
def forward(self, specgram):
specgram = torch.log10(torch.clamp(specgram.squeeze(0), min=1e-5))
if self.normalization:
return torch.clamp(
(self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1
)
return specgram
def normalized_waveform_to_bits(waveform: torch.Tensor, bits: int) -> torch.Tensor:
r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]
"""
assert abs(waveform).max() <= 1.0
waveform = (waveform + 1.0) * (2 ** bits - 1) / 2
return torch.clamp(waveform, 0, 2 ** bits - 1).int()
def bits_to_normalized_waveform(label: torch.Tensor, bits: int) -> torch.Tensor:
r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1]
"""
return 2 * label / (2 ** bits - 1.0) - 1.0
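# Round-trip sketch (illustrative values, not part of the original file):
#   normalized_waveform_to_bits(torch.tensor([-1.0, 0.0, 1.0]), 8)    -> tensor([0, 127, 255], dtype=torch.int32)
#   bits_to_normalized_waveform(torch.tensor([0.0, 127.0, 255.0]), 8) -> tensor([-1.0000, -0.0039, 1.0000])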