Unverified Commit 084455a3 authored by yangarbiter, committed by GitHub

Add phoneme text preprocessing for Tacotron2 (#1668)

parent 8094751f
@@ -5,7 +5,7 @@ This is an example pipeline for text-to-speech using Tacotron2.
Required packages
```bash
pip install librosa tqdm inflect
pip install librosa tqdm inflect joblib
```
To use tensorboard
@@ -13,7 +13,7 @@ To use tensorboard
pip install tensorboard pillow
```
## Training Tacotron2
## Training Tacotron2 with characters as input
The training of Tacotron2 can be invoked with the following command.
@@ -26,7 +26,7 @@ python train.py \
--batch-size 96 \
--weight-decay 1e-6 \
--grad-clip 1.0 \
--text-preprocessor character \
--text-preprocessor english_characters \
--logging-dir ./logs \
--checkpoint-path ./ckpt.pth \
--dataset-path ./
@@ -42,4 +42,102 @@ be in `./logs`.
If `./ckpt.pth` already exists, this script will automatically load the file and try to continue
training from the checkpoint.
This command takes around 36 hours to train on 8 NVIDIA Tesla V100 GPUs.
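
Note: the following is a minimal sketch (not the repository's exact code) of the resume behavior described above. The checkpoint key names (`state_dict`, `optimizer`, `epoch`) and the stand-in model are assumptions for illustration only.

```python
import os
import torch

# Stand-ins for the Tacotron2 model and optimizer constructed in train.py.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

checkpoint_path = "./ckpt.pth"
start_epoch = 0
if os.path.isfile(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])      # assumed key name
    optimizer.load_state_dict(checkpoint["optimizer"])   # assumed key name
    start_epoch = checkpoint["epoch"] + 1                # resume from the next epoch
```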
To train the Tacotron2 model so that it works with the [pretrained wavernn](https://pytorch.org/audio/main/models.html#id10)
with checkpoint_name `"wavernn_10k_epochs_8bits_ljspeech"`, please run the following command instead.
```bash
python train.py \
--learning-rate 1e-3 \
--epochs 1501 \
--anneal-steps 500 1000 1500 \
--anneal-factor 0.1 \
--sample-rate 22050 \
--n-fft 2048 \
--hop-length 275 \
--win-length 1100 \
--mel-fmin 40 \
--mel-fmax 11025 \
--batch-size 96 \
--weight-decay 1e-6 \
--grad-clip 1.0 \
--text-preprocessor english_characters \
--logging-dir ./wavernn_logs \
--checkpoint-path ./ckpt_wavernn.pth \
--dataset-path ./
```
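
For reference, the spectrogram flags above map onto the `torchaudio.transforms.MelSpectrogram` parameters used in `get_datasets()` in `train.py` (shown later in this diff). Below is a minimal sketch with the flag values plugged in; `n_mels=80` is an assumption, and the `SpectralNormalization` applied afterwards by the training script is omitted.

```python
import torch
import torchaudio

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050,   # --sample-rate
    n_fft=2048,          # --n-fft
    win_length=1100,     # --win-length
    hop_length=275,      # --hop-length
    f_min=40,            # --mel-fmin
    f_max=11025,         # --mel-fmax
    n_mels=80,           # assumed value of --n-mels
    mel_scale="slaney",
    normalized=False,
    power=1,
    norm="slaney",
)

waveform = torch.randn(1, 22050)   # one second of dummy audio
mel = mel_transform(waveform)      # shape: (1, n_mels, n_frames)
```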
## Training Tacotron2 with phonemes as input
#### Dependencies
This example uses [DeepPhonemizer](https://github.com/as-ideas/DeepPhonemizer) as
the phonemizer (the function that turns text into phonemes).
Please install it with the following command (the code is tested with version 0.0.15).
```bash
pip install deep-phonemizer==0.0.15
```
Then download the model weights from [their website](https://github.com/as-ideas/DeepPhonemizer).
The checkpoint tested with this example is
[https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_forward.pt](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_forward.pt).
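
Once downloaded, the checkpoint is loaded through DeepPhonemizer's `Phonemizer.from_checkpoint` API, the same way `word_to_phonemes()` in `text/text_preprocessing.py` (added below) does. A quick check that the checkpoint works:

```python
from dp.phonemizer import Phonemizer

phonemizer = Phonemizer.from_checkpoint("./en_us_cmudict_forward.pt")
print(phonemizer("hello world!", lang="en_us"))
# '[HH][AH][L][OW] [W][ER][L][D]!'
```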
#### Running training script
The training of Tacotron2 with English phonemes as input can be invoked with the following command.
```bash
python train.py \
--workers 12 \
--learning-rate 1e-3 \
--epochs 1501 \
--anneal-steps 500 1000 1500 \
--anneal-factor 0.1 \
--batch-size 96 \
--weight-decay 1e-6 \
--grad-clip 1.0 \
--text-preprocessor english_phonemes \
--phonemizer DeepPhonemizer \
--phonemizer-checkpoint ./en_us_cmudict_forward.pt \
--cmudict-root ./ \
--logging-dir ./english_phonemes_logs \
--checkpoint-path ./english_phonemes_ckpt.pth \
--dataset-path ./
```
Similar to the previous examples, this command saves the logs in the directory `./english_phonemes_logs`
and the checkpoint to `./english_phonemes_ckpt.pth`.
To train the Tacotron2 model with English phonemes that works with the
[pretrained wavernn](https://pytorch.org/audio/main/models.html#id10)
with checkpoint_name `"wavernn_10k_epochs_8bits_ljspeech"`, please run the following command.
```bash
python train.py \
--workers 12 \
--learning-rate 1e-3 \
--epochs 1501 \
--anneal-steps 500 1000 1500 \
--anneal-factor 0.1 \
--sample-rate 22050 \
--n-fft 2048 \
--hop-length 275 \
--win-length 1100 \
--mel-fmin 40 \
--mel-fmax 11025 \
--batch-size 96 \
--weight-decay 1e-6 \
--grad-clip 1.0 \
--text-preprocessor english_phonemes \
--phonemizer DeepPhonemizer \
--phonemizer-checkpoint ./en_us_cmudict_forward.pt \
--cmudict-root ./ \
--logging-dir ./english_phonemes_wavernn_logs \
--checkpoint-path ./english_phonemes_wavernn_ckpt.pth \
--dataset-path ./
```
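
For reference, the text preprocessing that these commands apply to each transcript can be exercised directly through the `text_to_sequence` function added in `text/text_preprocessing.py` (diff below). A minimal sketch, run from this example directory with the phonemizer checkpoint and CMUDict files in place:

```python
from text.text_preprocessing import text_to_sequence

ids = text_to_sequence(
    "hello world!",
    symbol_list="english_phonemes",
    phonemizer="DeepPhonemizer",
    checkpoint="./en_us_cmudict_forward.pt",
    cmudict_root="./",
)
print(ids)  # [54, 20, 65, 69, 11, 92, 44, 65, 38, 2] per the docstring example
```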
@@ -24,10 +24,11 @@
Modified from https://github.com/keithito/tacotron
"""
from typing import List
from typing import List, Union, Optional
import re
from unidecode import unidecode
from torchaudio.datasets import CMUDict
from .numbers import normalize_numbers
@@ -63,18 +64,87 @@ _special = '-'
_letters = 'abcdefghijklmnopqrstuvwxyz'
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters)
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_phonemizer = None
def text_to_sequence(sent: str) -> List[int]:
available_symbol_set = set(["english_characters", "english_phonemes"])
available_phonemizers = set(["DeepPhonemizer"])
def get_symbol_list(symbol_list: str = "english_characters",
                    cmudict_root: Optional[str] = "./") -> List[str]:
    if symbol_list == "english_characters":
        return [_pad] + list(_special) + list(_punctuation) + list(_letters)
    elif symbol_list == "english_phonemes":
        return [_pad] + list(_special) + list(_punctuation) + CMUDict(cmudict_root).symbols
    else:
        raise ValueError(f"The `symbol_list` {symbol_list} is not supported. "
                         f"Supported `symbol_list` includes {available_symbol_set}.")


def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
    if phonemizer == "DeepPhonemizer":
        from dp.phonemizer import Phonemizer

        global _phonemizer
        _other_symbols = ''.join(list(_special) + list(_punctuation))
        _phone_symbols_re = r'(\[[A-Z]+?\]|' + '[' + _other_symbols + '])'  # i.e. r"(\[[A-Z]+?\]|[-!'(),.:;? ])"

        if _phonemizer is None:
            # Use a global variable so that the checkpoint is not reloaded
            # every time this function is called.
            _phonemizer = Phonemizer.from_checkpoint(checkpoint)

        # Example:
        # sent = "hello world!"
        # '[HH][AH][L][OW] [W][ER][L][D]!'
        sent = _phonemizer(sent, lang='en_us')

        # ['[HH]', '[AH]', '[L]', '[OW]', ' ', '[W]', '[ER]', '[L]', '[D]', '!']
        ret = re.findall(_phone_symbols_re, sent)

        # ['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!']
        ret = [r.replace("[", "").replace("]", "") for r in ret]

        return ret
    else:
        raise ValueError(f"The `phonemizer` {phonemizer} is not supported. "
                         "Supported `phonemizer` includes `'DeepPhonemizer'`.")
def text_to_sequence(sent: str,
                     symbol_list: Union[str, List[str]] = "english_characters",
                     phonemizer: Optional[str] = "DeepPhonemizer",
                     checkpoint: Optional[str] = "./en_us_cmudict_forward.pt",
                     cmudict_root: Optional[str] = "./") -> List[int]:
    r'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    Args:
        sent (str): The input sentence to convert to a sequence.
        symbol_list (str or List of string, optional): When the input is a string, available options include
            "english_characters" and "english_phonemes". When the input is a list of string, ``symbol_list`` will
            directly be used as the symbols to encode. (Default: "english_characters")
        phonemizer (str, optional): The phonemizer to use. Only used when ``symbol_list`` is "english_phonemes".
            Available options include "DeepPhonemizer". (Default: "DeepPhonemizer")
        checkpoint (str, optional): The path to the checkpoint of the phonemizer. Only used when ``symbol_list`` is
            "english_phonemes". (Default: "./en_us_cmudict_forward.pt")
        cmudict_root (str, optional): The path to the directory where the CMUDict dataset is found or downloaded.
            Only used when ``symbol_list`` is "english_phonemes". (Default: "./")

    Returns:
        List of integers corresponding to the symbols in the sentence.

    Examples:
        >>> text_to_sequence("hello world!", "english_characters")
        [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2]
        >>> text_to_sequence("hello world!", "english_phonemes")
        [54, 20, 65, 69, 11, 92, 44, 65, 38, 2]
    '''
    if symbol_list == "english_phonemes":
        if any(param is None for param in [phonemizer, checkpoint, cmudict_root]):
            raise ValueError(
                "When `symbol_list` is 'english_phonemes', "
                "all of `phonemizer`, `checkpoint`, and `cmudict_root` must be provided.")

    sent = unidecode(sent)  # convert to ascii
    sent = sent.lower()  # lower case
    sent = normalize_numbers(sent)  # expand numbers
@@ -82,4 +152,13 @@ def text_to_sequence(sent: str) -> List[int]:
        sent = re.sub(regex, replacement, sent)
    sent = re.sub(_whitespace_re, ' ', sent)  # collapse whitespace

    if isinstance(symbol_list, list):
        symbols = symbol_list
    elif isinstance(symbol_list, str):
        symbols = get_symbol_list(symbol_list, cmudict_root=cmudict_root)

    if symbol_list == "english_phonemes":
        sent = word_to_phonemes(sent, phonemizer=phonemizer, checkpoint=checkpoint)

    _symbol_to_id = {s: i for i, s in enumerate(symbols)}
    return [_symbol_to_id[s] for s in sent if s in _symbol_to_id]
@@ -50,8 +50,14 @@ import matplotlib.pyplot as plt
plt.switch_backend('agg')
from datasets import text_mel_collate_fn, split_process_dataset, SpectralNormalization
from utils import save_checkpoint, get_text_preprocessor
from utils import save_checkpoint
from loss import Tacotron2Loss
from text.text_preprocessing import (
    available_symbol_set,
    available_phonemizers,
    get_symbol_list,
    text_to_sequence,
)
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
@@ -76,13 +82,22 @@ def parse_args(parser):
    parser.add_argument('--anneal-factor', type=float, choices=[0.1, 0.3], default=0.1,
                        help='factor for annealing learning rate')
    parser.add_argument('--text-preprocessor', default='character', type=str,
                        choices=['character'], help='[string] Select text preprocessor to use.')
    parser.add_argument('--master-addr', default=None, type=str,
                        help='The address to use for distributed training.')
                        help='the address to use for distributed training')
    parser.add_argument('--master-port', default=None, type=str,
                        help='The port to use for distributed training.')
                        help='the port to use for distributed training')

    preprocessor = parser.add_argument_group('text preprocessor setup')
    preprocessor.add_argument('--text-preprocessor', default='english_characters', type=str,
                              choices=available_symbol_set,
                              help='select text preprocessor to use.')
    preprocessor.add_argument('--phonemizer', type=str, choices=available_phonemizers,
                              help='select phonemizer to use, only used when text-preprocessor is "english_phonemes"')
    preprocessor.add_argument('--phonemizer-checkpoint', type=str,
                              help='the path or name of the checkpoint for the phonemizer, '
                                   'only used when text-preprocessor is "english_phonemes"')
    preprocessor.add_argument('--cmudict-root', default="./", type=str,
                              help='the root directory for storing cmudictionary files')

    # training
    training = parser.add_argument_group('training setup')
@@ -263,6 +278,36 @@ def log_additional_info(writer, model, loader, epoch):
    writer.add_image("trn/alignment", alignment[0], epoch, dataformats="HW")


def get_datasets(args):
    text_preprocessor = partial(
        text_to_sequence,
        symbol_list=args.text_preprocessor,
        phonemizer=args.phonemizer,
        checkpoint=args.phonemizer_checkpoint,
        cmudict_root=args.cmudict_root,
    )

    transforms = torch.nn.Sequential(
        torchaudio.transforms.MelSpectrogram(
            sample_rate=args.sample_rate,
            n_fft=args.n_fft,
            win_length=args.win_length,
            hop_length=args.hop_length,
            f_min=args.mel_fmin,
            f_max=args.mel_fmax,
            n_mels=args.n_mels,
            mel_scale='slaney',
            normalized=False,
            power=1,
            norm='slaney',
        ),
        SpectralNormalization()
    )
    trainset, valset = split_process_dataset(
        args.dataset, args.dataset_path, args.val_ratio, transforms, text_preprocessor)
    return trainset, valset
def train(rank, world_size, args):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
@@ -281,7 +326,7 @@ def train(rank, world_size, args):
    torch.cuda.set_device(rank)

    symbols, text_preprocessor = get_text_preprocessor(args.text_preprocessor)
    symbols = get_symbol_list(args.text_preprocessor)

    model = Tacotron2(
        mask_padding=args.mask_padding,
@@ -330,24 +375,7 @@ def train(rank, world_size, args):
            f"Checkpoint: loaded '{args.checkpoint_path}' at epoch {checkpoint['epoch']}"
        )

    transforms = torch.nn.Sequential(
        torchaudio.transforms.MelSpectrogram(
            sample_rate=args.sample_rate,
            n_fft=args.n_fft,
            win_length=args.win_length,
            hop_length=args.hop_length,
            f_min=args.mel_fmin,
            f_max=args.mel_fmax,
            n_mels=args.n_mels,
            mel_scale='slaney',
            normalized=False,
            power=1,
            norm='slaney',
        ),
        SpectralNormalization()
    )
    trainset, valset = split_process_dataset(
        args.dataset, args.dataset_path, args.val_ratio, transforms, text_preprocessor)
    trainset, valset = get_datasets(args)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset,
@@ -365,6 +393,8 @@ def train(rank, world_size, args):
    loader_params = {
        "batch_size": args.batch_size,
        "num_workers": args.workers,
        "prefetch_factor": 1024,
        "persistent_workers": True,
        "shuffle": False,
        "pin_memory": True,
        "drop_last": False,
@@ -484,7 +514,8 @@ def main(args):
    if device_counts == 1:
        train(0, 1, args)
    else:
        mp.spawn(train, args=(device_counts, args, ), nprocs=device_counts, join=True)
        mp.spawn(train, args=(device_counts, args, ),
                 nprocs=device_counts, join=True)

    logger.info(f"End time: {datetime.now()}")
@@ -74,14 +74,3 @@ def prepare_input_sequence(texts: List[str],
    text_padded, input_lengths = pad_sequences(d)
    return text_padded, input_lengths


def get_text_preprocessor(preprocessor_name: str) -> Tuple[List[str], Callable[[str], List[int]]]:
    if preprocessor_name == "character":
        from text.text_preprocessing import symbols
        from text.text_preprocessing import text_to_sequence
        text_preprocessor = text_to_sequence
    else:
        raise ValueError("The preprocessor_name ({preprocessor_name}) provided is not supported.")
    return symbols, text_preprocessor