Unverified commit `084455a3`, authored by yangarbiter, committed by GitHub

Add phoneme text preprocessing for Tacotron2 (#1668)

parent 8094751f

README.md
@@ -5,7 +5,7 @@ This is an example pipeline for text-to-speech using Tacotron2.
 Required packages
 ```bash
-pip install librosa tqdm inflect
+pip install librosa tqdm inflect joblib
 ```
 To use tensorboard
@@ -13,7 +13,7 @@ To use tensorboard
 pip install tensorboard pillow
 ```
-## Training Tacotron2
+## Training Tacotron2 with character as input
 The training of Tacotron2 can be invoked with the following command.
@@ -26,7 +26,7 @@ python train.py \
     --batch-size 96 \
     --weight-decay 1e-6 \
     --grad-clip 1.0 \
-    --text-preprocessor character \
+    --text-preprocessor english_characters \
     --logging-dir ./logs \
     --checkpoint-path ./ckpt.pth \
     --dataset-path ./
@@ -42,4 +42,102 @@ be in `./logs`.
 If `./ckpt.pth` already exists, this script will automatically load the file and try to continue
 training from the checkpoint.
-This command takes around 36 hours to train on 8 NVIDIA Tesla V100 GPUs.
\ No newline at end of file
+This command takes around 36 hours to train on 8 NVIDIA Tesla V100 GPUs.
+
+To train the Tacotron2 model to work with the [pretrained WaveRNN](https://pytorch.org/audio/main/models.html#id10)
+with checkpoint_name `"wavernn_10k_epochs_8bits_ljspeech"`, please run the following command instead.
+```bash
+python train.py \
+    --learning-rate 1e-3 \
+    --epochs 1501 \
+    --anneal-steps 500 1000 1500 \
+    --anneal-factor 0.1 \
+    --sample-rate 22050 \
+    --n-fft 2048 \
+    --hop-length 275 \
+    --win-length 1100 \
+    --mel-fmin 40 \
+    --mel-fmax 11025 \
+    --batch-size 96 \
+    --weight-decay 1e-6 \
+    --grad-clip 1.0 \
+    --text-preprocessor english_characters \
+    --logging-dir ./wavernn_logs \
+    --checkpoint-path ./ckpt_wavernn.pth \
+    --dataset-path ./
+```
+
+## Training Tacotron2 with phoneme as input
+
+#### Dependencies
+
+This example uses [DeepPhonemizer](https://github.com/as-ideas/DeepPhonemizer) as
+the phonemizer (the function that turns text into phonemes).
+Please install it with the following command (the code is tested with version 0.0.15).
+```bash
+pip install deep-phonemizer==0.0.15
+```
+Then download the model weights from [their website](https://github.com/as-ideas/DeepPhonemizer).
+The checkpoint tested with this example is
+[https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_forward.pt](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_forward.pt).
+
+#### Running the training script
+
+The training of Tacotron2 with English phonemes as input can be invoked with the following command.
+```bash
+python train.py \
+    --workers 12 \
+    --learning-rate 1e-3 \
+    --epochs 1501 \
+    --anneal-steps 500 1000 1500 \
+    --anneal-factor 0.1 \
+    --batch-size 96 \
+    --weight-decay 1e-6 \
+    --grad-clip 1.0 \
+    --text-preprocessor english_phonemes \
+    --phonemizer DeepPhonemizer \
+    --phonemizer-checkpoint ./en_us_cmudict_forward.pt \
+    --cmudict-root ./ \
+    --logging-dir ./english_phonemes_logs \
+    --checkpoint-path ./english_phonemes_ckpt.pth \
+    --dataset-path ./
+```
+Similar to the previous examples, this command will save the log in the directory `./english_phonemes_logs`,
+and the checkpoint will be saved to `./english_phonemes_ckpt.pth`.
+
+To train the Tacotron2 model with English phonemes that works with the
+[pretrained WaveRNN](https://pytorch.org/audio/main/models.html#id10)
+with checkpoint_name `"wavernn_10k_epochs_8bits_ljspeech"`, please run the following command.
+```bash
+python train.py \
+    --workers 12 \
+    --learning-rate 1e-3 \
+    --epochs 1501 \
+    --anneal-steps 500 1000 1500 \
+    --anneal-factor 0.1 \
+    --sample-rate 22050 \
+    --n-fft 2048 \
+    --hop-length 275 \
+    --win-length 1100 \
+    --mel-fmin 40 \
+    --mel-fmax 11025 \
+    --batch-size 96 \
+    --weight-decay 1e-6 \
+    --grad-clip 1.0 \
+    --text-preprocessor english_phonemes \
+    --phonemizer DeepPhonemizer \
+    --phonemizer-checkpoint ./en_us_cmudict_forward.pt \
+    --cmudict-root ./ \
+    --logging-dir ./english_phonemes_wavernn_logs \
+    --checkpoint-path ./english_phonemes_wavernn_ckpt.pth \
+    --dataset-path ./
+```
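
For orientation, here is a minimal sketch of what the DeepPhonemizer dependency does, using the same `Phonemizer.from_checkpoint` and `lang='en_us'` API that the new preprocessing code below relies on. The checkpoint path assumes the file from the Dependencies section was saved to the working directory:

```python
from dp.phonemizer import Phonemizer

# Load the checkpoint downloaded in the Dependencies section (path assumed).
phonemizer = Phonemizer.from_checkpoint("./en_us_cmudict_forward.pt")

# DeepPhonemizer wraps each phoneme in brackets, e.g.
# "hello world!" -> "[HH][AH][L][OW] [W][ER][L][D]!"
print(phonemizer("hello world!", lang="en_us"))
```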

text/text_preprocessing.py
@@ -24,10 +24,11 @@
 Modified from https://github.com/keithito/tacotron
 """
-from typing import List
+from typing import List, Union, Optional
 import re
 from unidecode import unidecode
+from torchaudio.datasets import CMUDict

 from .numbers import normalize_numbers
@@ -63,18 +64,87 @@ _special = '-'
 _letters = 'abcdefghijklmnopqrstuvwxyz'

 symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters)
-_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+_phonemizer = None
+
+available_symbol_set = set(["english_characters", "english_phonemes"])
+available_phonemizers = set(["DeepPhonemizer"])
+
+
+def get_symbol_list(symbol_list: str = "english_characters",
+                    cmudict_root: Optional[str] = "./") -> List[str]:
+    if symbol_list == "english_characters":
+        return [_pad] + list(_special) + list(_punctuation) + list(_letters)
+    elif symbol_list == "english_phonemes":
+        return [_pad] + list(_special) + list(_punctuation) + CMUDict(cmudict_root).symbols
+    else:
+        raise ValueError(f"The `symbol_list` {symbol_list} is not supported. "
+                         f"Supported `symbol_list` includes {available_symbol_set}.")
+
+
+def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]:
+    if phonemizer == "DeepPhonemizer":
+        from dp.phonemizer import Phonemizer
+        global _phonemizer
+        _other_symbols = ''.join(list(_special) + list(_punctuation))
+        _phone_symbols_re = r'(\[[A-Z]+?\]|' + '[' + _other_symbols + '])'  # [\[([A-Z]+?)\]|[-!'(),.:;? ]]
+        if _phonemizer is None:
+            # Use a global variable so that the checkpoint does not have to be
+            # reloaded every time this function is called.
+            _phonemizer = Phonemizer.from_checkpoint(checkpoint)
+        # Example:
+        # sent = "hello world!"
+        # '[HH][AH][L][OW] [W][ER][L][D]!'
+        sent = _phonemizer(sent, lang='en_us')
+        # ['[HH]', '[AH]', '[L]', '[OW]', ' ', '[W]', '[ER]', '[L]', '[D]', '!']
+        ret = re.findall(_phone_symbols_re, sent)
+        # ['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!']
+        ret = [r.replace("[", "").replace("]", "") for r in ret]
+        return ret
+    else:
+        raise ValueError(f"The `phonemizer` {phonemizer} is not supported. "
+                         "Supported `phonemizer` includes `'DeepPhonemizer'`.")
+
+
-def text_to_sequence(sent: str) -> List[int]:
+def text_to_sequence(sent: str,
+                     symbol_list: Union[str, List[str]] = "english_characters",
+                     phonemizer: Optional[str] = "DeepPhonemizer",
+                     checkpoint: Optional[str] = "./en_us_cmudict_forward.pt",
+                     cmudict_root: Optional[str] = "./") -> List[int]:
     r'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

     Args:
         sent (str): The input sentence to convert to a sequence.
+        symbol_list (str or list of str, optional): When the input is a string, the available options are
+            "english_characters" and "english_phonemes". When the input is a list of strings, ``symbol_list``
+            is used directly as the symbols to encode. (Default: "english_characters")
+        phonemizer (str, optional): The phonemizer to use. Only used when ``symbol_list`` is "english_phonemes".
+            Available options include "DeepPhonemizer". (Default: "DeepPhonemizer")
+        checkpoint (str, optional): The path to the checkpoint of the phonemizer. Only used when ``symbol_list``
+            is "english_phonemes". (Default: "./en_us_cmudict_forward.pt")
+        cmudict_root (str, optional): The path to the directory where the CMUDict dataset is found or downloaded.
+            Only used when ``symbol_list`` is "english_phonemes". (Default: "./")

     Returns:
         List of integers corresponding to the symbols in the sentence.
+
+    Examples:
+        >>> text_to_sequence("hello world!", "english_characters")
+        [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2]
+        >>> text_to_sequence("hello world!", "english_phonemes")
+        [54, 20, 65, 69, 11, 92, 44, 65, 38, 2]
     '''
+    if symbol_list == "english_phonemes":
+        if any(param is None for param in [phonemizer, checkpoint, cmudict_root]):
+            raise ValueError(
+                "When `symbol_list` is 'english_phonemes', "
+                "all of `phonemizer`, `checkpoint`, and `cmudict_root` must be provided.")
+
     sent = unidecode(sent)  # convert to ascii
     sent = sent.lower()  # lower case
     sent = normalize_numbers(sent)  # expand numbers
@@ -82,4 +152,13 @@ def text_to_sequence(sent: str) -> List[int]:
         sent = re.sub(regex, replacement, sent)
     sent = re.sub(_whitespace_re, ' ', sent)  # collapse whitespace
+
+    if isinstance(symbol_list, list):
+        symbols = symbol_list
+    elif isinstance(symbol_list, str):
+        symbols = get_symbol_list(symbol_list, cmudict_root=cmudict_root)
+        if symbol_list == "english_phonemes":
+            sent = word_to_phonemes(sent, phonemizer=phonemizer, checkpoint=checkpoint)
+
+    _symbol_to_id = {s: i for i, s in enumerate(symbols)}
     return [_symbol_to_id[s] for s in sent if s in _symbol_to_id]
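
For reference, the reworked `text_to_sequence` can be exercised as below. This is a sketch assembled from the docstring examples above, run from within the example directory; the phoneme variant assumes the DeepPhonemizer checkpoint from the README and a writable `cmudict_root` for the CMUDict download:

```python
from text.text_preprocessing import text_to_sequence

# Character-based encoding: no phonemizer involved.
ids = text_to_sequence("hello world!", symbol_list="english_characters")
# [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2]

# Phoneme-based encoding: the text is phonemized first, then mapped to IDs.
ids = text_to_sequence(
    "hello world!",
    symbol_list="english_phonemes",
    phonemizer="DeepPhonemizer",
    checkpoint="./en_us_cmudict_forward.pt",
    cmudict_root="./",
)
# [54, 20, 65, 69, 11, 92, 44, 65, 38, 2]
```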

train.py
@@ -50,8 +50,14 @@ import matplotlib.pyplot as plt
 plt.switch_backend('agg')

 from datasets import text_mel_collate_fn, split_process_dataset, SpectralNormalization
-from utils import save_checkpoint, get_text_preprocessor
+from utils import save_checkpoint
 from loss import Tacotron2Loss
+from text.text_preprocessing import (
+    available_symbol_set,
+    available_phonemizers,
+    get_symbol_list,
+    text_to_sequence,
+)

 logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
@@ -76,13 +82,22 @@ def parse_args(parser):
     parser.add_argument('--anneal-factor', type=float, choices=[0.1, 0.3], default=0.1,
                         help='factor for annealing learning rate')
-    parser.add_argument('--text-preprocessor', default='character', type=str,
-                        choices=['character'], help='[string] Select text preprocessor to use.')
     parser.add_argument('--master-addr', default=None, type=str,
-                        help='The address to use for distributed training.')
+                        help='the address to use for distributed training')
     parser.add_argument('--master-port', default=None, type=str,
-                        help='The port to use for distributed training.')
+                        help='the port to use for distributed training')
+
+    preprocessor = parser.add_argument_group('text preprocessor setup')
+    preprocessor.add_argument('--text-preprocessor', default='english_characters', type=str,
+                              choices=available_symbol_set,
+                              help='select text preprocessor to use')
+    preprocessor.add_argument('--phonemizer', type=str, choices=available_phonemizers,
+                              help='select phonemizer to use; only used when text-preprocessor is "english_phonemes"')
+    preprocessor.add_argument('--phonemizer-checkpoint', type=str,
+                              help='the path or name of the checkpoint for the phonemizer, '
+                                   'only used when text-preprocessor is "english_phonemes"')
+    preprocessor.add_argument('--cmudict-root', default="./", type=str,
+                              help='the root directory for storing CMUDict files')

     # training
     training = parser.add_argument_group('training setup')
@@ -263,6 +278,36 @@ def log_additional_info(writer, model, loader, epoch):
     writer.add_image("trn/alignment", alignment[0], epoch, dataformats="HW")

+def get_datasets(args):
+    text_preprocessor = partial(
+        text_to_sequence,
+        symbol_list=args.text_preprocessor,
+        phonemizer=args.phonemizer,
+        checkpoint=args.phonemizer_checkpoint,
+        cmudict_root=args.cmudict_root,
+    )
+
+    transforms = torch.nn.Sequential(
+        torchaudio.transforms.MelSpectrogram(
+            sample_rate=args.sample_rate,
+            n_fft=args.n_fft,
+            win_length=args.win_length,
+            hop_length=args.hop_length,
+            f_min=args.mel_fmin,
+            f_max=args.mel_fmax,
+            n_mels=args.n_mels,
+            mel_scale='slaney',
+            normalized=False,
+            power=1,
+            norm='slaney',
+        ),
+        SpectralNormalization()
+    )
+    trainset, valset = split_process_dataset(
+        args.dataset, args.dataset_path, args.val_ratio, transforms, text_preprocessor)
+    return trainset, valset
+
+
 def train(rank, world_size, args):
     dist.init_process_group("nccl", rank=rank, world_size=world_size)
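
One design note on `get_datasets`: `functools.partial` freezes every CLI-derived preprocessing option once, so the dataset pipeline can call the preprocessor with just a sentence. A minimal self-contained sketch of the pattern (the toy `text_to_sequence` body here is a stand-in, not the real implementation):

```python
from functools import partial

# Toy stand-in for text.text_preprocessing.text_to_sequence (hypothetical body).
def text_to_sequence(sent, symbol_list="english_characters",
                     phonemizer=None, checkpoint=None, cmudict_root="./"):
    symbols = "abcdefghijklmnopqrstuvwxyz !"
    return [symbols.index(c) for c in sent if c in symbols]

# Bind all options once; downstream code then supplies only the sentence.
text_preprocessor = partial(text_to_sequence, symbol_list="english_characters")
print(text_preprocessor("hello world!"))
```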
@@ -281,7 +326,7 @@ def train(rank, world_size, args):
     torch.cuda.set_device(rank)

-    symbols, text_preprocessor = get_text_preprocessor(args.text_preprocessor)
+    symbols = get_symbol_list(args.text_preprocessor)

     model = Tacotron2(
         mask_padding=args.mask_padding,
@@ -330,24 +375,7 @@
         f"Checkpoint: loaded '{args.checkpoint_path}' at epoch {checkpoint['epoch']}"
     )

-    transforms = torch.nn.Sequential(
-        torchaudio.transforms.MelSpectrogram(
-            sample_rate=args.sample_rate,
-            n_fft=args.n_fft,
-            win_length=args.win_length,
-            hop_length=args.hop_length,
-            f_min=args.mel_fmin,
-            f_max=args.mel_fmax,
-            n_mels=args.n_mels,
-            mel_scale='slaney',
-            normalized=False,
-            power=1,
-            norm='slaney',
-        ),
-        SpectralNormalization()
-    )
-    trainset, valset = split_process_dataset(
-        args.dataset, args.dataset_path, args.val_ratio, transforms, text_preprocessor)
+    trainset, valset = get_datasets(args)

     train_sampler = torch.utils.data.distributed.DistributedSampler(
         trainset,
@@ -365,6 +393,8 @@
     loader_params = {
         "batch_size": args.batch_size,
         "num_workers": args.workers,
+        "prefetch_factor": 1024,
+        "persistent_workers": True,
         "shuffle": False,
         "pin_memory": True,
         "drop_last": False,
@@ -484,7 +514,8 @@ def main(args):
     if device_counts == 1:
         train(0, 1, args)
     else:
-        mp.spawn(train, args=(device_counts, args, ), nprocs=device_counts, join=True)
+        mp.spawn(train, args=(device_counts, args, ),
+                 nprocs=device_counts, join=True)

     logger.info(f"End time: {datetime.now()}")
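
A note on the two DataLoader settings added above: `persistent_workers=True` keeps worker processes alive across epochs instead of respawning them, and `prefetch_factor` sets how many batches each worker queues ahead. A minimal sketch of a loader built the same way (the toy dataset is hypothetical; the real script wires in its own dataset, sampler, and collate function):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    # Toy dataset standing in for the processed LJSpeech split (hypothetical).
    dataset = TensorDataset(torch.arange(100).float())

    loader = DataLoader(
        dataset,
        batch_size=8,
        num_workers=2,
        prefetch_factor=1024,     # batches each worker queues ahead of time
        persistent_workers=True,  # keep workers alive between epochs
        shuffle=False,
        pin_memory=True,
        drop_last=False,
    )
    for epoch in range(2):
        for (batch,) in loader:
            pass  # training step goes here; workers are reused in epoch 2
```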

utils.py
@@ -74,14 +74,3 @@ def prepare_input_sequence(texts: List[str],
     text_padded, input_lengths = pad_sequences(d)
     return text_padded, input_lengths
-
-
-def get_text_preprocessor(preprocessor_name: str) -> Tuple[List[str], Callable[[str], List[int]]]:
-    if preprocessor_name == "character":
-        from text.text_preprocessing import symbols
-        from text.text_preprocessing import text_to_sequence
-        text_preprocessor = text_to_sequence
-    else:
-        raise ValueError("The preprocessor_name ({preprocessor_name}) provided is not supported.")
-    return symbols, text_preprocessor