Commit 51782715 authored by liugh5

update

parent 8b4e9acd
import os
import sys
import logging
import torch
from collections import defaultdict
from tensorboardX import SummaryWriter
from tqdm import tqdm
import soundfile as sf
import numpy as np
from kantts.utils.plot import plot_spectrogram, plot_alignment
def traversal_dict(d, func):
if not isinstance(d, dict):
logging.error("Not a dict: {}".format(d))
return
for k, v in d.items():
if isinstance(v, dict):
traversal_dict(v, func)
else:
func(k, v)
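# A quick usage sketch (hypothetical values): traversal_dict walks a nested
# dict and applies func(key, value) to every non-dict leaf, e.g.
#
#   traversal_dict({"a": 1, "b": {"c": 2}}, lambda k, v: print(k, v))
#   # prints "a 1" then "c 2"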
def distributed_init():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("RANK", 0))
distributed = world_size > 1
device = torch.device("cuda", local_rank)
if distributed:
torch.distributed.init_process_group(backend="nccl", init_method="env://")
logging.info(
"Distributed training, global world size: {}, local world size: {}, global rank: {}, local rank: {}".format(
world_size,
torch.cuda.device_count(),
torch.distributed.get_rank(),
local_rank,
)
)
logging.info("nccl backend: {}".format(torch.distributed.is_nccl_available()))
logging.info("mpi backend: {}".format(torch.distributed.is_mpi_available()))
device_ids = list(range(torch.cuda.device_count()))
logging.info(
"[{}] rank = {}, world_size = {}, n_gpus = {}, device_ids = {}".format(
os.getpid(),
torch.distributed.get_rank(),
torch.distributed.get_world_size(),
torch.cuda.device_count(),
device_ids,
)
)
return distributed, device, local_rank, world_size
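# Usage sketch (assumed launch setup, not part of this file): WORLD_SIZE and
# RANK are read from the environment, so a two-process job would be started
# roughly as
#
#   WORLD_SIZE=2 RANK=0 python train.py
#   WORLD_SIZE=2 RANK=1 python train.py
#
# and each process then calls:
#
#   distributed, device, local_rank, world_size = distributed_init()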
class Trainer(object):
def __init__(
self,
config,
model,
optimizer,
scheduler,
criterion,
device,
sampler,
train_loader,
valid_loader,
max_epochs=None,
max_steps=None,
save_dir=None,
save_interval=1,
valid_interval=1,
log_interval=10,
grad_clip=None,
):
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.criterion = criterion
self.device = device
self.sampler = sampler
self.train_loader = train_loader
self.valid_loader = valid_loader
self.max_epochs = max_epochs
self.steps = 1
self.epoch = 0
self.save_dir = save_dir
self.save_interval = save_interval
self.valid_interval = valid_interval
self.log_interval = log_interval
self.grad_clip = grad_clip
self.total_train_loss = defaultdict(float)
self.total_eval_loss = defaultdict(float)
self.config = config
self.distributed = self.config.get("distributed", False)
self.rank = self.config.get("rank", 0)
self.log_dir = os.path.join(save_dir, "log")
self.ckpt_dir = os.path.join(save_dir, "ckpt")
os.makedirs(self.log_dir, exist_ok=True)
os.makedirs(self.ckpt_dir, exist_ok=True)
self.writer = SummaryWriter(self.log_dir)
if max_epochs is None:
self.max_epochs = sys.maxsize
else:
self.max_epochs = int(max_epochs)
if max_steps is None:
self.max_steps = sys.maxsize
else:
self.max_steps = int(max_steps)
self.finish_training = False
def set_model_state(self, state="train"):
if state == "train":
if isinstance(self.model, dict):
for key in self.model.keys():
self.model[key].train()
else:
self.model.train()
elif state == "eval":
if isinstance(self.model, dict):
for key in self.model.keys():
self.model[key].eval()
else:
self.model.eval()
else:
raise ValueError("state must be either 'train' or 'eval'.")
def write_to_tensorboard(self, loss):
"""Write to tensorboard."""
for key, value in loss.items():
self.writer.add_scalar(key, value, self.steps)
    # FIXME: an example for a simple feedforward model
def save_checkpoint(self, checkpoint_path):
state_dict = {
"optimizer": self.optimizer.state_dict(),
"scheduler": self.scheduler.state_dict(),
"steps": self.steps,
"model": self.model.state_dict(),
}
# TODO: distributed training
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(state_dict, checkpoint_path)
def load_checkpoint(
self, checkpoint_path, restore_training_state=False, strict=True
):
state_dict = torch.load(checkpoint_path)
self.model.load_state_dict(state_dict["model"], strict=strict)
if restore_training_state:
if "optimizer" in state_dict:
self.optimizer["KanTtsSAMBERT"].load_state_dict(state_dict["optimizer"])
if "scheduler" in state_dict:
self.scheduler["KanTtsSAMBERT"].load_state_dict(state_dict["scheduler"])
if "steps" in state_dict:
self.steps = state_dict["steps"]
# TODO
def check_save_interval(self):
if self.ckpt_dir is not None and (self.steps) % self.save_interval == 0:
self.save_checkpoint(
os.path.join(self.ckpt_dir, "checkpoint_{}.pth".format(self.steps))
)
logging.info("Checkpoint saved at step {}".format(self.steps))
def check_log_interval(self):
if self.writer is not None and (self.steps) % self.log_interval == 0:
for key in self.total_train_loss.keys():
                self.total_train_loss[key] /= self.log_interval
logging.info(
f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
)
self.write_to_tensorboard(self.total_train_loss)
self.total_train_loss = defaultdict(float)
def log_learning_rate(key, sche):
logging.info("{} learning rate: {:.6f}".format(key, sche.get_lr()[0]))
self.write_to_tensorboard({"{}_lr".format(key): sche.get_lr()[0]})
traversal_dict(self.scheduler, log_learning_rate)
def check_eval_interval(self):
if self.valid_interval > 0 and (self.steps) % self.valid_interval == 0:
self.eval_epoch()
def check_stop_training(self):
if self.steps >= self.max_steps or self.epoch >= self.max_epochs:
self.finish_training = True
def train(self):
self.set_model_state("train")
while True:
self.train_epoch()
self.epoch += 1
self.check_stop_training()
if self.finish_training:
break
def train_epoch(self):
for batch in tqdm(self.train_loader):
self.train_step(batch)
if self.rank == 0:
self.check_eval_interval()
self.check_save_interval()
self.check_log_interval()
self.steps += 1
self.check_stop_training()
if self.finish_training:
break
logging.info("Epoch {} finished".format(self.epoch))
if self.distributed:
self.sampler["train"].set_epoch(self.epoch)
# TODO: implement train_step() for specific model
def train_step(self, batch):
data, target = batch
data, target = data.to(self.device), target.to(self.device)
self.optimizer.zero_grad()
output = self.model(data)
loss = self.criterion(output, target)
loss.backward()
if self.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
self.optimizer.step()
    # TODO: implement eval_step() for specific model
@torch.no_grad()
def eval_step(self, batch):
pass
def eval_epoch(self):
logging.info(f"(Epoch: {self.epoch}) Start evaluation.")
# change mode
self.set_model_state("eval")
self.total_eval_loss = defaultdict(float)
        # TODO: save some intermediate results
rand_idx = np.random.randint(0, len(self.valid_loader))
idx = 0
logging.info("Valid data size: {}".format(len(self.valid_loader)))
for batch in tqdm(self.valid_loader):
self.eval_step(batch)
if idx == rand_idx:
logging.info(
f"(Epoch: {self.epoch}) Random batch: {idx}, generating image."
)
                self.generate_and_save_intermediate_result(batch)
idx += 1
for key in self.total_eval_loss.keys():
            self.total_eval_loss[key] /= max(idx, 1)
logging.info(
f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
)
self.write_to_tensorboard(self.total_eval_loss)
logging.info("Epoch {} evaluation finished".format(self.epoch))
self.set_model_state("train")
@torch.no_grad()
    def generate_and_save_intermediate_result(self, batch):
pass
class GAN_Trainer(Trainer):
def __init__(
self,
config,
model,
optimizer,
scheduler,
criterion,
device,
sampler,
train_loader,
valid_loader,
max_epochs=None,
max_steps=None,
save_dir=None,
save_interval=1,
valid_interval=1,
log_interval=10,
grad_clip=None,
):
super().__init__(
config,
model,
optimizer,
scheduler,
criterion,
device,
sampler,
train_loader,
valid_loader,
max_epochs,
max_steps,
save_dir,
save_interval,
valid_interval,
log_interval,
grad_clip,
)
def set_model_state(self, state="train"):
if state == "train":
if isinstance(self.model, dict):
self.model["generator"].train()
for key in self.model["discriminator"].keys():
self.model["discriminator"][key].train()
else:
self.model.train()
elif state == "eval":
if isinstance(self.model, dict):
self.model["generator"].eval()
for key in self.model["discriminator"].keys():
self.model["discriminator"][key].eval()
else:
self.model.eval()
else:
raise ValueError("state must be either 'train' or 'eval'.")
@torch.no_grad()
    def generate_and_save_intermediate_result(self, batch):
"""Generate and save intermediate result."""
        # delayed import to avoid matplotlib backend errors
import matplotlib.pyplot as plt
# generate
y_batch, x_batch = batch
y_batch, x_batch = y_batch.to(self.device), x_batch.to(self.device)
y_batch_ = self.model["generator"](x_batch)
if self.model.get("pqmf", None):
y_mb_ = y_batch_
y_batch_ = self.model["pqmf"].synthesis(y_mb_)
# check directory
dirname = os.path.join(self.log_dir, f"predictions/{self.steps}steps")
if not os.path.exists(dirname):
os.makedirs(dirname)
for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 1):
# convert to ndarray
y, y_ = y.view(-1).cpu().numpy(), y_.view(-1).cpu().numpy()
# plot figure and save it
figname = os.path.join(dirname, f"{idx}.png")
plt.subplot(2, 1, 1)
plt.plot(y)
plt.title("groundtruth speech")
plt.subplot(2, 1, 2)
plt.plot(y_)
plt.title(f"generated speech @ {self.steps} steps")
plt.tight_layout()
plt.savefig(figname)
plt.close()
# save as wavfile
y = np.clip(y, -1, 1)
y_ = np.clip(y_, -1, 1)
sf.write(
figname.replace(".png", "_ref.wav"),
y,
self.config["audio_config"]["sampling_rate"],
"PCM_16",
)
sf.write(
figname.replace(".png", "_gen.wav"),
y_,
self.config["audio_config"]["sampling_rate"],
"PCM_16",
)
if idx >= self.config["num_save_intermediate_results"]:
break
@torch.no_grad()
def eval_step(self, batch):
y, x = batch
y, x = y.to(self.device), x.to(self.device)
y_ = self.model["generator"](x)
# reconstruct the signal from multi-band signal
if self.model.get("pqmf", None):
y_mb_ = y_
y_ = self.model["pqmf"].synthesis(y_mb_)
aux_loss = 0.0
        # multi-resolution stft loss
if self.criterion.get("stft_loss", None):
sc_loss, mag_loss = self.criterion["stft_loss"](y_, y)
aux_loss += (sc_loss + mag_loss) * self.criterion["stft_loss"].weights
self.total_eval_loss["eval/spectral_convergence_loss"] += sc_loss.item()
# subband multi-resolution stft loss
if self.criterion.get("subband_stft_loss", None):
aux_loss *= 0.5 # for balancing with subband stft loss
y_mb = self.model["pqmf"].analysis(y)
sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
self.total_eval_loss[
"eval/sub_spectral_convergence_loss"
] += sub_sc_loss.item()
self.total_eval_loss[
"eval/sub_log_stft_magnitude_loss"
] += sub_mag_loss.item()
aux_loss += (
0.5 * (sub_sc_loss + sub_mag_loss) * self.criterion["sub_stft"].weights
)
# mel spectrogram loss
if self.criterion.get("mel_loss", None):
mel_loss = self.criterion["mel_loss"](y_, y)
aux_loss += mel_loss * self.criterion["mel_loss"].weights
self.total_eval_loss["eval/mel_loss"] += mel_loss.item()
fmap_lst_ = []
adv_loss = 0.0
        # adversarial loss
for discriminator in self.model["discriminator"].keys():
p_, fmap_ = self.model["discriminator"][discriminator](y_)
fmap_lst_.append(fmap_)
adv_loss += (
self.criterion["generator_adv_loss"](p_)
* self.criterion["generator_adv_loss"].weights
)
gen_loss = aux_loss + adv_loss
if self.criterion.get("feat_match_loss", None):
fmap_lst = []
            # no need to track gradients for the ground-truth feature maps
for discriminator in self.model["discriminator"].keys():
with torch.no_grad():
p, fmap = self.model["discriminator"][discriminator](y)
fmap_lst.append(fmap)
fm_loss = 0.0
            for fmap, fmap_ in zip(fmap_lst, fmap_lst_):
                fm_loss += self.criterion["feat_match_loss"](fmap_, fmap)
self.total_eval_loss["eval/feature_matching_loss"] += fm_loss.item()
gen_loss += fm_loss * self.criterion["feat_match_loss"].weights
dis_loss = 0.0
for discriminator in self.model["discriminator"].keys():
p, fmap = self.model["discriminator"][discriminator](y)
p_, fmap_ = self.model["discriminator"][discriminator](y_.detach())
real_loss, fake_loss = self.criterion["discriminator_adv_loss"](p_, p)
dis_loss += real_loss + fake_loss
self.total_eval_loss["eval/real_loss"] += real_loss.item()
self.total_eval_loss["eval/fake_loss"] += fake_loss.item()
self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item()
self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item()
self.total_eval_loss["eval/generator_loss"] += gen_loss.item()
def train_step(self, batch):
y, x = batch
y, x = y.to(self.device), x.to(self.device)
if self.steps >= self.config.get("generator_train_start_steps", 0):
y_ = self.model["generator"](x)
# reconstruct the signal from multi-band signal
if self.model.get("pqmf", None):
y_mb_ = y_
y_ = self.model["pqmf"].synthesis(y_mb_)
# initialize
gen_loss = 0.0
            # multi-resolution stft loss
if self.criterion.get("stft_loss", None):
sc_loss, mag_loss = self.criterion["stft_loss"](y_, y)
gen_loss += (sc_loss + mag_loss) * self.criterion["stft_loss"].weights
self.total_train_loss[
"train/spectral_convergence_loss"
] += sc_loss.item()
self.total_train_loss[
"train/log_stft_magnitude_loss"
] += mag_loss.item()
# subband multi-resolution stft loss
if self.criterion.get("subband_stft_loss", None):
gen_loss *= 0.5 # for balancing with subband stft loss
y_mb = self.model["pqmf"].analysis(y)
sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
self.total_train_loss[
"train/sub_spectral_convergence_loss"
] += sub_sc_loss.item()
self.total_train_loss[
"train/sub_log_stft_magnitude_loss"
] += sub_mag_loss.item()
# mel spectrogram loss
if self.criterion.get("mel_loss", None):
mel_loss = self.criterion["mel_loss"](y_, y)
gen_loss += mel_loss * self.criterion["mel_loss"].weights
self.total_train_loss["train/mel_loss"] += mel_loss.item()
# adversarial loss
if self.steps > self.config["discriminator_train_start_steps"]:
adv_loss = 0.0
fmap_lst_ = []
for discriminator in self.model["discriminator"].keys():
p_, fmap_ = self.model["discriminator"][discriminator](y_)
fmap_lst_.append(fmap_)
adv_loss += self.criterion["generator_adv_loss"](p_)
self.total_train_loss["train/adversarial_loss"] += adv_loss.item()
gen_loss += adv_loss * self.criterion["generator_adv_loss"].weights
# feature matching loss
if self.criterion.get("feat_match_loss", None):
fmap_lst = []
                    # no need to track gradients for the ground-truth feature maps
for discriminator in self.model["discriminator"].keys():
with torch.no_grad():
p, fmap = self.model["discriminator"][discriminator](y)
fmap_lst.append(fmap)
fm_loss = 0.0
                    for fmap, fmap_ in zip(fmap_lst, fmap_lst_):
                        fm_loss += self.criterion["feat_match_loss"](fmap_, fmap)
self.total_train_loss[
"train/feature_matching_loss"
] += fm_loss.item()
gen_loss += fm_loss * self.criterion["feat_match_loss"].weights
self.total_train_loss["train/generator_loss"] += gen_loss.item()
# update generator
self.optimizer["generator"].zero_grad()
gen_loss.backward()
if self.config["generator_grad_norm"] > 0:
torch.nn.utils.clip_grad_norm_(
self.model["generator"].parameters(),
self.config["generator_grad_norm"],
)
self.optimizer["generator"].step()
self.scheduler["generator"].step()
# update discriminator
if self.steps > self.config["discriminator_train_start_steps"]:
            # re-compute y_, which leads to better quality
with torch.no_grad():
y_ = self.model["generator"](x)
if self.model.get("pqmf", None):
y_ = self.model["pqmf"].synthesis(y_)
# discriminator loss
dis_loss = 0.0
for discriminator in self.model["discriminator"].keys():
p, fmap = self.model["discriminator"][discriminator](y)
p_, fmap_ = self.model["discriminator"][discriminator](y_.detach())
real_loss, fake_loss = self.criterion["discriminator_adv_loss"](p_, p)
dis_loss += real_loss + fake_loss
self.total_train_loss["train/real_loss"] += real_loss.item()
self.total_train_loss["train/fake_loss"] += fake_loss.item()
self.total_train_loss["train/discriminator_loss"] += dis_loss.item()
# update discriminator
for key in self.optimizer["discriminator"].keys():
self.optimizer["discriminator"][key].zero_grad()
dis_loss.backward()
if self.config["discriminator_grad_norm"] > 0:
torch.nn.utils.clip_grad_norm_(
self.model["discriminator"].parameters(),
self.config["discriminator_grad_norm"],
)
for key in self.optimizer["discriminator"].keys():
self.optimizer["discriminator"][key].step()
for key in self.scheduler["discriminator"].keys():
self.scheduler["discriminator"][key].step()
def save_checkpoint(self, checkpoint_path):
state_dict = {
"optimizer": {
"generator": self.optimizer["generator"].state_dict(),
"discriminator": {},
},
"scheduler": {
"generator": self.scheduler["generator"].state_dict(),
"discriminator": {},
},
"steps": self.steps,
}
for model_name in self.optimizer["discriminator"].keys():
state_dict["optimizer"]["discriminator"][model_name] = self.optimizer[
"discriminator"
][model_name].state_dict()
for model_name in self.scheduler["discriminator"].keys():
state_dict["scheduler"]["discriminator"][model_name] = self.scheduler[
"discriminator"
][model_name].state_dict()
if not self.distributed:
model_state = self.model["generator"].state_dict()
else:
model_state = self.model["generator"].module.state_dict()
state_dict["model"] = {
"generator": model_state,
"discriminator": {},
}
for model_name in self.model["discriminator"].keys():
if not self.distributed:
model_state = self.model["discriminator"][model_name].state_dict()
else:
model_state = self.model["discriminator"][
model_name
].module.state_dict()
state_dict["model"]["discriminator"][model_name] = model_state
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(state_dict, checkpoint_path)
def load_checkpoint(
self, checkpoint_path, restore_training_state=False, strict=True
):
state_dict = torch.load(checkpoint_path, map_location="cpu")
if not self.distributed:
self.model["generator"].load_state_dict(
state_dict["model"]["generator"], strict=strict
)
else:
self.model["generator"].module.load_state_dict(
state_dict["model"]["generator"], strict=strict
)
for model_name in state_dict["model"]["discriminator"]:
if not self.distributed:
self.model["discriminator"][model_name].load_state_dict(
state_dict["model"]["discriminator"][model_name], strict=strict
)
else:
self.model["discriminator"][model_name].module.load_state_dict(
state_dict["model"]["discriminator"][model_name], strict=strict
)
if restore_training_state:
if "steps" in state_dict:
self.steps = state_dict["steps"]
if "optimizer" in state_dict:
self.optimizer["generator"].load_state_dict(
state_dict["optimizer"]["generator"]
)
for model_name in state_dict["optimizer"]["discriminator"].keys():
self.optimizer["discriminator"][model_name].load_state_dict(
state_dict["optimizer"]["discriminator"][model_name]
)
if "scheduler" in state_dict:
for model_name in state_dict["scheduler"]["discriminator"].keys():
self.scheduler["discriminator"][model_name].load_state_dict(
state_dict["scheduler"]["discriminator"][model_name]
)
self.scheduler["generator"].load_state_dict(
state_dict["scheduler"]["generator"]
)
class Sambert_Trainer(Trainer):
def __init__(
self,
config,
model,
optimizer,
scheduler,
criterion,
device,
sampler,
train_loader,
valid_loader,
max_epochs=None,
max_steps=None,
save_dir=None,
save_interval=1,
valid_interval=1,
log_interval=10,
grad_clip=None,
):
super().__init__(
config,
model,
optimizer,
scheduler,
criterion,
device,
sampler,
train_loader,
valid_loader,
max_epochs,
max_steps,
save_dir,
save_interval,
valid_interval,
log_interval,
grad_clip,
)
self.with_MAS = config["Model"]["KanTtsSAMBERT"]["params"].get("MAS", False)
self.fp_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("FP", False)
@torch.no_grad()
    def generate_and_save_intermediate_result(self, batch):
inputs_ling = batch["input_lings"].to(self.device)
inputs_emotion = batch["input_emotions"].to(self.device)
inputs_speaker = batch["input_speakers"].to(self.device)
valid_input_lengths = batch["valid_input_lengths"].to(self.device)
mel_targets = batch["mel_targets"].to(self.device)
# generate mel spectrograms
res = self.model["KanTtsSAMBERT"](
inputs_ling[0:1],
inputs_emotion[0:1],
inputs_speaker[0:1],
valid_input_lengths[0:1],
)
x_band_width = res["x_band_width"]
h_band_width = res["h_band_width"]
enc_slf_attn_lst = res["enc_slf_attn_lst"]
pnca_x_attn_lst = res["pnca_x_attn_lst"]
pnca_h_attn_lst = res["pnca_h_attn_lst"]
dec_outputs = res["dec_outputs"]
postnet_outputs = res["postnet_outputs"]
dirname = os.path.join(self.log_dir, f"predictions/{self.steps}steps")
if not os.path.exists(dirname):
os.makedirs(dirname)
for layer_id, slf_attn in enumerate(enc_slf_attn_lst):
for head_id in range(
self.config["Model"]["KanTtsSAMBERT"]["params"]["encoder_num_heads"]
):
fig = plot_alignment(
slf_attn[
head_id, : valid_input_lengths[0], : valid_input_lengths[0]
]
.cpu()
.numpy(),
info="valid_len_{}".format(valid_input_lengths[0].item()),
)
fig.savefig(
os.path.join(
dirname,
"enc_slf_attn_dev_layer{}_head{}".format(layer_id, head_id),
)
)
for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate(
zip(pnca_x_attn_lst, pnca_h_attn_lst)
):
for head_id in range(
self.config["Model"]["KanTtsSAMBERT"]["params"]["decoder_num_heads"]
):
fig = plot_alignment(
pnca_x_attn[head_id, :, :].cpu().numpy(),
info="x_band_width_{}".format(x_band_width),
)
fig.savefig(
os.path.join(
dirname,
"pnca_x_attn_dev_layer{}_head{}".format(layer_id, head_id),
)
)
fig = plot_alignment(
pnca_h_attn[head_id, :, :].cpu().numpy(),
info="h_band_width_{}".format(h_band_width),
)
fig.savefig(
os.path.join(
dirname,
"pnca_h_attn_dev_layer{}_head{}".format(layer_id, head_id),
)
)
target_mel = mel_targets[0].cpu().numpy()
coarse_mel = dec_outputs.squeeze(0).cpu().numpy()
output_mel = postnet_outputs.squeeze(0).cpu().numpy()
np.save(os.path.join(dirname, "coarse_mel.npy"), coarse_mel)
np.save(os.path.join(dirname, "output_mel.npy"), output_mel)
np.save(os.path.join(dirname, "target_mel.npy"), target_mel)
fig = plot_spectrogram(coarse_mel.T)
fig.savefig(os.path.join(dirname, "mel_dec_outputs"))
fig = plot_spectrogram(output_mel.T)
fig.savefig(os.path.join(dirname, "mel_postnet_outputs"))
@torch.no_grad()
def eval_step(self, batch):
inputs_ling = batch["input_lings"].to(self.device)
inputs_emotion = batch["input_emotions"].to(self.device)
inputs_speaker = batch["input_speakers"].to(self.device)
valid_input_lengths = batch["valid_input_lengths"].to(self.device)
valid_output_lengths = batch["valid_output_lengths"].to(self.device)
mel_targets = batch["mel_targets"].to(self.device)
durations = (
batch["durations"].to(self.device)
if batch["durations"] is not None
else None
)
pitch_contours = batch["pitch_contours"].to(self.device)
energy_contours = batch["energy_contours"].to(self.device)
attn_priors = (
batch["attn_priors"].to(self.device)
if batch["attn_priors"] is not None
else None
)
fp_label = None
if self.fp_enable:
fp_label = batch["fp_label"].to(self.device)
# generate mel spectrograms
res = self.model["KanTtsSAMBERT"](
inputs_ling,
inputs_emotion,
inputs_speaker,
valid_input_lengths,
output_lengths=valid_output_lengths,
mel_targets=mel_targets,
duration_targets=durations,
pitch_targets=pitch_contours,
energy_targets=energy_contours,
attn_priors=attn_priors,
fp_label=fp_label,
)
x_band_width = res["x_band_width"]
h_band_width = res["h_band_width"]
dec_outputs = res["dec_outputs"]
postnet_outputs = res["postnet_outputs"]
log_duration_predictions = res["log_duration_predictions"]
pitch_predictions = res["pitch_predictions"]
energy_predictions = res["energy_predictions"]
duration_targets = res["duration_targets"]
pitch_targets = res["pitch_targets"]
energy_targets = res["energy_targets"]
fp_predictions = res["fp_predictions"]
valid_inter_lengths = res["valid_inter_lengths"]
mel_loss_, mel_loss = self.criterion["MelReconLoss"](
valid_output_lengths, mel_targets, dec_outputs, postnet_outputs
)
dur_loss, pitch_loss, energy_loss = self.criterion["ProsodyReconLoss"](
valid_inter_lengths,
duration_targets,
pitch_targets,
energy_targets,
log_duration_predictions,
pitch_predictions,
energy_predictions,
)
loss_total = mel_loss_ + mel_loss + dur_loss + pitch_loss + energy_loss
if self.fp_enable:
fp_loss = self.criterion["FpCELoss"](
valid_input_lengths, fp_predictions, fp_label
)
loss_total = loss_total + fp_loss
if self.with_MAS:
attn_soft = res["attn_soft"]
attn_hard = res["attn_hard"]
attn_logprob = res["attn_logprob"]
attn_ctc_loss = self.criterion["AttentionCTCLoss"](
attn_logprob, valid_input_lengths, valid_output_lengths
)
attn_kl_loss = self.criterion["AttentionBinarizationLoss"](
self.epoch, attn_hard, attn_soft
)
loss_total += attn_ctc_loss + attn_kl_loss
self.total_eval_loss["eval/attn_ctc_loss"] += attn_ctc_loss.item()
self.total_eval_loss["eval/attn_kl_loss"] += attn_kl_loss.item()
self.total_eval_loss["eval/TotalLoss"] += loss_total.item()
self.total_eval_loss["eval/mel_loss_"] += mel_loss_.item()
self.total_eval_loss["eval/mel_loss"] += mel_loss.item()
self.total_eval_loss["eval/dur_loss"] += dur_loss.item()
self.total_eval_loss["eval/pitch_loss"] += pitch_loss.item()
self.total_eval_loss["eval/energy_loss"] += energy_loss.item()
if self.fp_enable:
self.total_eval_loss["eval/fp_loss"] += fp_loss.item()
self.total_eval_loss["eval/batch_size"] += mel_targets.size(0)
self.total_eval_loss["eval/x_band_width"] += x_band_width
self.total_eval_loss["eval/h_band_width"] += h_band_width
def train_step(self, batch):
inputs_ling = batch["input_lings"].to(self.device)
inputs_emotion = batch["input_emotions"].to(self.device)
inputs_speaker = batch["input_speakers"].to(self.device)
valid_input_lengths = batch["valid_input_lengths"].to(self.device)
valid_output_lengths = batch["valid_output_lengths"].to(self.device)
mel_targets = batch["mel_targets"].to(self.device)
durations = (
batch["durations"].to(self.device)
if batch["durations"] is not None
else None
)
pitch_contours = batch["pitch_contours"].to(self.device)
energy_contours = batch["energy_contours"].to(self.device)
attn_priors = (
batch["attn_priors"].to(self.device)
if batch["attn_priors"] is not None
else None
)
fp_label = None
if self.fp_enable:
fp_label = batch["fp_label"].to(self.device)
# generate mel spectrograms
res = self.model["KanTtsSAMBERT"](
inputs_ling,
inputs_emotion,
inputs_speaker,
valid_input_lengths,
output_lengths=valid_output_lengths,
mel_targets=mel_targets,
duration_targets=durations,
pitch_targets=pitch_contours,
energy_targets=energy_contours,
attn_priors=attn_priors,
fp_label=fp_label,
)
x_band_width = res["x_band_width"]
h_band_width = res["h_band_width"]
dec_outputs = res["dec_outputs"]
postnet_outputs = res["postnet_outputs"]
log_duration_predictions = res["log_duration_predictions"]
pitch_predictions = res["pitch_predictions"]
energy_predictions = res["energy_predictions"]
duration_targets = res["duration_targets"]
pitch_targets = res["pitch_targets"]
energy_targets = res["energy_targets"]
fp_predictions = res["fp_predictions"]
valid_inter_lengths = res["valid_inter_lengths"]
mel_loss_, mel_loss = self.criterion["MelReconLoss"](
valid_output_lengths, mel_targets, dec_outputs, postnet_outputs
)
dur_loss, pitch_loss, energy_loss = self.criterion["ProsodyReconLoss"](
valid_inter_lengths,
duration_targets,
pitch_targets,
energy_targets,
log_duration_predictions,
pitch_predictions,
energy_predictions,
)
loss_total = mel_loss_ + mel_loss + dur_loss + pitch_loss + energy_loss
if self.fp_enable:
fp_loss = self.criterion["FpCELoss"](
valid_input_lengths, fp_predictions, fp_label
)
loss_total = loss_total + fp_loss
if self.with_MAS:
attn_soft = res["attn_soft"]
attn_hard = res["attn_hard"]
attn_logprob = res["attn_logprob"]
attn_ctc_loss = self.criterion["AttentionCTCLoss"](
attn_logprob, valid_input_lengths, valid_output_lengths
)
attn_kl_loss = self.criterion["AttentionBinarizationLoss"](
self.epoch, attn_hard, attn_soft
)
loss_total += attn_ctc_loss + attn_kl_loss
self.total_train_loss["train/attn_ctc_loss"] += attn_ctc_loss.item()
self.total_train_loss["train/attn_kl_loss"] += attn_kl_loss.item()
self.total_train_loss["train/TotalLoss"] += loss_total.item()
self.total_train_loss["train/mel_loss_"] += mel_loss_.item()
self.total_train_loss["train/mel_loss"] += mel_loss.item()
self.total_train_loss["train/dur_loss"] += dur_loss.item()
self.total_train_loss["train/pitch_loss"] += pitch_loss.item()
self.total_train_loss["train/energy_loss"] += energy_loss.item()
if self.fp_enable:
self.total_train_loss["train/fp_loss"] += fp_loss.item()
self.total_train_loss["train/batch_size"] += mel_targets.size(0)
self.total_train_loss["train/x_band_width"] += x_band_width
self.total_train_loss["train/h_band_width"] += h_band_width
self.optimizer["KanTtsSAMBERT"].zero_grad()
loss_total.backward()
if self.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
self.model["KanTtsSAMBERT"].parameters(), self.grad_clip
)
self.optimizer["KanTtsSAMBERT"].step()
self.scheduler["KanTtsSAMBERT"].step()
def save_checkpoint(self, checkpoint_path):
if not self.distributed:
model_state = self.model["KanTtsSAMBERT"].state_dict()
else:
model_state = self.model["KanTtsSAMBERT"].module.state_dict()
state_dict = {
"optimizer": self.optimizer["KanTtsSAMBERT"].state_dict(),
"scheduler": self.scheduler["KanTtsSAMBERT"].state_dict(),
"steps": self.steps,
"model": model_state,
}
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(state_dict, checkpoint_path)
def load_checkpoint(
self, checkpoint_path, restore_training_state=False, strict=True
):
state_dict = torch.load(checkpoint_path)
if not self.distributed:
self.model["KanTtsSAMBERT"].load_state_dict(
state_dict["model"], strict=strict
)
else:
self.model["KanTtsSAMBERT"].module.load_state_dict(
state_dict["model"], strict=strict
)
if restore_training_state:
if "optimizer" in state_dict:
self.optimizer["KanTtsSAMBERT"].load_state_dict(state_dict["optimizer"])
if "scheduler" in state_dict:
self.scheduler["KanTtsSAMBERT"].load_state_dict(state_dict["scheduler"])
if "steps" in state_dict:
self.steps = state_dict["steps"]
class Textsy_BERT_Trainer(Trainer):
def __init__(
self,
config,
model,
optimizer,
scheduler,
criterion,
device,
sampler,
train_loader,
valid_loader,
max_epochs=None,
max_steps=None,
save_dir=None,
save_interval=1,
valid_interval=1,
log_interval=10,
grad_clip=None,
):
super().__init__(
config,
model,
optimizer,
scheduler,
criterion,
device,
sampler,
train_loader,
valid_loader,
max_epochs,
max_steps,
save_dir,
save_interval,
valid_interval,
log_interval,
grad_clip,
)
@torch.no_grad()
    def generate_and_save_intermediate_result(self, batch):
inputs_ling = batch["input_lings"].to(self.device)
valid_input_lengths = batch["valid_input_lengths"].to(self.device)
bert_masks = batch["bert_masks"].to(self.device)
targets = batch["targets"].to(self.device)
res = self.model["KanTtsTextsyBERT"](
inputs_ling[0:1],
valid_input_lengths[0:1],
)
logits = res["logits"]
enc_slf_attn_lst = res["enc_slf_attn_lst"]
preds = torch.argmax(logits, dim=-1).contiguous().view(-1)
dirname = os.path.join(self.log_dir, f"predictions/{self.steps}steps")
if not os.path.exists(dirname):
os.makedirs(dirname)
for layer_id, slf_attn in enumerate(enc_slf_attn_lst):
for head_id in range(
self.config["Model"]["KanTtsTextsyBERT"]["params"]["encoder_num_heads"]
):
fig = plot_alignment(
slf_attn[
head_id, : valid_input_lengths[0], : valid_input_lengths[0]
]
.cpu()
.numpy(),
info="valid_len_{}".format(valid_input_lengths[0].item()),
)
fig.savefig(
os.path.join(
dirname,
"enc_slf_attn_dev_layer{}_head{}".format(layer_id, head_id),
)
)
target = targets[0].cpu().numpy()
bert_mask = bert_masks[0].cpu().numpy()
pred = preds.cpu().numpy()
np.save(os.path.join(dirname, "pred.npy"), pred)
np.save(os.path.join(dirname, "target.npy"), target)
np.save(os.path.join(dirname, "bert_mask.npy"), bert_mask)
@torch.no_grad()
def eval_step(self, batch):
inputs_ling = batch["input_lings"].to(self.device)
valid_input_lengths = batch["valid_input_lengths"].to(self.device)
bert_masks = batch["bert_masks"].to(self.device)
targets = batch["targets"].to(self.device)
res = self.model["KanTtsTextsyBERT"](
inputs_ling,
valid_input_lengths,
)
logits = res["logits"]
loss_total, err = self.criterion["SeqCELoss"](
logits,
targets,
bert_masks,
)
loss_total = loss_total / logits.size(-1)
self.total_eval_loss["eval/TotalLoss"] += loss_total.item()
self.total_eval_loss["eval/Error"] += err.item()
self.total_eval_loss["eval/batch_size"] += targets.size(0)
def train_step(self, batch):
inputs_ling = batch["input_lings"].to(self.device)
valid_input_lengths = batch["valid_input_lengths"].to(self.device)
bert_masks = batch["bert_masks"].to(self.device)
targets = batch["targets"].to(self.device)
res = self.model["KanTtsTextsyBERT"](
inputs_ling,
valid_input_lengths,
)
logits = res["logits"]
loss_total, err = self.criterion["SeqCELoss"](
logits,
targets,
bert_masks,
)
loss_total = loss_total / logits.size(-1)
self.optimizer["KanTtsTextsyBERT"].zero_grad()
loss_total.backward()
if self.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
self.model["KanTtsTextsyBERT"].parameters(), self.grad_clip
)
self.optimizer["KanTtsTextsyBERT"].step()
self.scheduler["KanTtsTextsyBERT"].step()
self.total_train_loss["train/TotalLoss"] += loss_total.item()
self.total_train_loss["train/Error"] += err.item()
self.total_train_loss["train/batch_size"] += targets.size(0)
def save_checkpoint(self, checkpoint_path):
if not self.distributed:
model_state = self.model["KanTtsTextsyBERT"].state_dict()
else:
model_state = self.model["KanTtsTextsyBERT"].module.state_dict()
state_dict = {
"optimizer": self.optimizer["KanTtsTextsyBERT"].state_dict(),
"scheduler": self.scheduler["KanTtsTextsyBERT"].state_dict(),
"steps": self.steps,
"model": model_state,
}
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(state_dict, checkpoint_path)
def load_checkpoint(
self, checkpoint_path, restore_training_state=False, strict=True
):
state_dict = torch.load(checkpoint_path)
if not self.distributed:
self.model["KanTtsTextsyBERT"].load_state_dict(
state_dict["model"], strict=strict
)
else:
self.model["KanTtsTextsyBERT"].module.load_state_dict(
state_dict["model"], strict=strict
)
if restore_training_state:
self.optimizer["KanTtsTextsyBERT"].load_state_dict(state_dict["optimizer"])
self.scheduler["KanTtsTextsyBERT"].load_state_dict(state_dict["scheduler"])
self.steps = state_dict["steps"]
import torch
import librosa
from distutils.version import LooseVersion
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
if is_pytorch_17plus:
x_stft = torch.stft(
x, fft_size, hop_size, win_length, window, return_complex=False
)
else:
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
real = x_stft[..., 0]
imag = x_stft[..., 1]
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
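# Shape sketch (hypothetical inputs): for a batch of two one-second 22.05 kHz
# signals the returned magnitude spectrogram has fft_size // 2 + 1 = 513 bins:
#
#   x = torch.randn(2, 22050)
#   window = torch.hann_window(1024)
#   mag = stft(x, 1024, 256, 1024, window)  # shape (2, #frames, 513)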
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return 20 * torch.log10(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.pow(10.0, x * 0.05) / C
def spectral_normalize_torch(
magnitudes,
min_level_db=-100.0,
ref_level_db=20.0,
norm_abs_value=4.0,
symmetric=True,
):
output = dynamic_range_compression_torch(magnitudes) - ref_level_db
if symmetric:
return torch.clamp(
2 * norm_abs_value * ((output - min_level_db) / (-min_level_db))
- norm_abs_value,
min=-norm_abs_value,
max=norm_abs_value,
)
else:
return torch.clamp(
norm_abs_value * ((output - min_level_db) / (-min_level_db)),
min=0.0,
max=norm_abs_value,
)
def spectral_de_normalize_torch(
magnitudes,
min_level_db=-100.0,
ref_level_db=20.0,
norm_abs_value=4.0,
symmetric=True,
):
if symmetric:
magnitudes = torch.clamp(magnitudes, min=-norm_abs_value, max=norm_abs_value)
magnitudes = (magnitudes + norm_abs_value) * (-min_level_db) / (
2 * norm_abs_value
) + min_level_db
else:
magnitudes = torch.clamp(magnitudes, min=0.0, max=norm_abs_value)
magnitudes = (magnitudes) * (-min_level_db) / (norm_abs_value) + min_level_db
output = dynamic_range_decompression_torch(magnitudes + ref_level_db)
return output
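# Round-trip sketch: spectral_normalize_torch maps dB-compressed magnitudes
# into [-norm_abs_value, norm_abs_value], and spectral_de_normalize_torch
# inverts it for values inside the clamp range. With the defaults, a magnitude
# of 0.01 compresses to -40 dB, shifts to -60 dB after ref_level_db,
# normalizes to -0.8, and de-normalizes back to 0.01:
#
#   m = torch.tensor([0.01, 0.1, 1.0])
#   m_rec = spectral_de_normalize_torch(spectral_normalize_torch(m))  # ~= m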
class MelSpectrogram(torch.nn.Module):
"""Calculate Mel-spectrogram."""
def __init__(
self,
fs=22050,
fft_size=1024,
hop_size=256,
win_length=None,
window="hann",
num_mels=80,
fmin=80,
fmax=7600,
center=True,
normalized=False,
onesided=True,
eps=1e-10,
log_base=10.0,
pad_mode="constant",
):
"""Initialize MelSpectrogram module."""
super().__init__()
self.fft_size = fft_size
if win_length is None:
self.win_length = fft_size
else:
self.win_length = win_length
self.hop_size = hop_size
self.center = center
self.normalized = normalized
self.onesided = onesided
if window is not None and not hasattr(torch, f"{window}_window"):
raise ValueError(f"{window} window is not implemented")
self.window = window
self.eps = eps
self.pad_mode = pad_mode
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
melmat = librosa.filters.mel(
sr=fs,
n_fft=fft_size,
n_mels=num_mels,
fmin=fmin,
fmax=fmax,
)
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
self.stft_params = {
"n_fft": self.fft_size,
"win_length": self.win_length,
"hop_length": self.hop_size,
"center": self.center,
"normalized": self.normalized,
"onesided": self.onesided,
"pad_mode": self.pad_mode,
}
if is_pytorch_17plus:
self.stft_params["return_complex"] = False
self.log_base = log_base
if self.log_base is None:
self.log = torch.log
elif self.log_base == 2.0:
self.log = torch.log2
elif self.log_base == 10.0:
self.log = torch.log10
else:
raise ValueError(f"log_base: {log_base} is not supported.")
def forward(self, x):
"""Calculate Mel-spectrogram.
Args:
x (Tensor): Input waveform tensor (B, T) or (B, 1, T).
Returns:
Tensor: Mel-spectrogram (B, #mels, #frames).
"""
if x.dim() == 3:
# (B, C, T) -> (B*C, T)
x = x.reshape(-1, x.size(2))
if self.window is not None:
window_func = getattr(torch, f"{self.window}_window")
window = window_func(self.win_length, dtype=x.dtype, device=x.device)
else:
window = None
x_stft = torch.stft(x, window=window, **self.stft_params)
        # (B, #freqs, #frames, 2) -> (B, #frames, #freqs, 2)
x_stft = x_stft.transpose(1, 2)
x_power = x_stft[..., 0] ** 2 + x_stft[..., 1] ** 2
x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))
x_mel = torch.matmul(x_amp, self.melmat)
x_mel = torch.clamp(x_mel, min=self.eps)
x_mel = spectral_normalize_torch(x_mel)
# return self.log(x_mel).transpose(1, 2)
return x_mel.transpose(1, 2)
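# Usage sketch (default parameters): a one-second batch yields a normalized
# log-mel tensor shaped (B, num_mels, #frames):
#
#   mel_fn = MelSpectrogram(fs=22050, fft_size=1024, hop_size=256, num_mels=80)
#   mel = mel_fn(torch.randn(2, 22050))  # shape (2, 80, #frames)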
import ttsfrd
ENG_LANG_MAPPING = {
"PinYin": "zh-cn",
"English": "en-us",
"British": "en-gb",
"ZhHK": "hk_cantonese",
"Sichuan": "sichuan",
"Japanese": "japanese",
"WuuShangHai": "shanghai",
"Indonesian": "indonesian",
"Malay": "malay",
"Filipino": "filipino",
"Vietnamese": "vietnamese",
"Korean": "korean",
"Russian": "russian",
}
def text_to_mit_symbols(texts, resources_dir, speaker, lang="PinYin"):
fe = ttsfrd.TtsFrontendEngine()
fe.initialize(resources_dir)
fe.set_lang_type(ENG_LANG_MAPPING[lang])
symbols_lst = []
for idx, text in enumerate(texts):
text = text.strip()
res = fe.gen_tacotron_symbols(text)
res = res.replace("F7", speaker)
sentences = res.split("\n")
for sentence in sentences:
arr = sentence.split("\t")
# skip the empty line
if len(arr) != 2:
continue
sub_index, symbols = sentence.split("\t")
symbol_str = "{}_{}\t{}\n".format(idx, sub_index, symbols)
symbols_lst.append(symbol_str)
return symbols_lst
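# Usage sketch (the resource path and speaker name below are placeholders):
# the ttsfrd frontend turns raw text into tab-separated symbol lines, one
# list entry per sub-sentence:
#
#   symbols = text_to_mit_symbols(
#       ["text to synthesize"], "/path/to/resources", "F7", lang="PinYin"
#   )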
"""
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
"""
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
def convert_to_ascii(text):
return unidecode(text)
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
"""Pipeline for non-English text that transliterates to ASCII."""
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
"""Pipeline for English text, including number and abbreviation expansion."""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
return text
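# Worked example (derived from the rules above):
#
#   english_cleaners("Dr. Smith paid $3.50.")
#   # -> "doctor smith paid three dollars, fifty cents."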
emotion_types = [
"emotion_none",
"emotion_neutral",
"emotion_angry",
"emotion_disgust",
"emotion_fear",
"emotion_happy",
"emotion_sad",
"emotion_surprise",
"emotion_calm",
"emotion_gentle",
"emotion_relax",
"emotion_lyrical",
"emotion_serious",
"emotion_disgruntled",
"emotion_satisfied",
"emotion_disappointed",
"emotion_excited",
"emotion_anxiety",
"emotion_jealousy",
"emotion_hate",
"emotion_pity",
"emotion_pleasure",
"emotion_arousal",
"emotion_dominance",
"emotion_placeholder1",
"emotion_placeholder2",
"emotion_placeholder3",
"emotion_placeholder4",
"emotion_placeholder5",
"emotion_placeholder6",
"emotion_placeholder7",
"emotion_placeholder8",
"emotion_placeholder9",
]
import xml.etree.ElementTree as ET
from kantts.preprocess.languages import languages
import logging
import os
syllable_flags = [
"s_begin",
"s_end",
"s_none",
"s_both",
"s_middle",
]
word_segments = [
"word_begin",
"word_end",
"word_middle",
"word_both",
"word_none",
]
LANGUAGES_DIR = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
"preprocess",
"languages",
)
def parse_phoneset(phoneset_file):
"""Parse a phoneset file and return a list of symbols.
Args:
phoneset_file (str): Path to the phoneset file.
Returns:
list: A list of phones.
"""
ns = "{http://schemas.alibaba-inc.com/tts}"
phone_lst = []
phoneset_root = ET.parse(phoneset_file).getroot()
for phone_node in phoneset_root.findall(ns + "phone"):
phone_lst.append(phone_node.find(ns + "name").text)
for i in range(1, 5):
phone_lst.append("#{}".format(i))
return phone_lst
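# A minimal phoneset file that this parser accepts (hypothetical content):
#
#   <phoneset xmlns="http://schemas.alibaba-inc.com/tts">
#     <phone><name>a</name></phone>
#     <phone><name>b</name></phone>
#   </phoneset>
#
# parse_phoneset on such a file returns ["a", "b", "#1", "#2", "#3", "#4"],
# i.e. the phone names plus the four appended break symbols.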
def parse_tonelist(tonelist_file):
"""Parse a tonelist file and return a list of tones.
Args:
tonelist_file (str): Path to the tonelist file.
Returns:
dict: A dictionary of tones.
"""
tone_lst = []
with open(tonelist_file, "r") as f:
lines = f.readlines()
for line in lines:
tone = line.strip()
if tone != "":
tone_lst.append("tone{}".format(tone))
else:
tone_lst.append("tone_none")
return tone_lst
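# Example (hypothetical file): a tonelist with the lines "1", "2", "" and "5"
# decodes to ["tone1", "tone2", "tone_none", "tone5"].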
def get_language_symbols(language):
"""Get symbols of a language.
Args:
language (str): Language name.
"""
language_dict = languages.get(language, None)
if language_dict is None:
logging.error("Language %s not supported. Using PinYin as default", language)
language_dict = languages["PinYin"]
language = "PinYin"
language_dir = os.path.join(LANGUAGES_DIR, language)
phoneset_file = os.path.join(language_dir, language_dict["phoneset_path"])
tonelist_file = os.path.join(language_dir, language_dict["tonelist_path"])
phones = parse_phoneset(phoneset_file)
tones = parse_tonelist(tonelist_file)
return phones, tones, syllable_flags, word_segments
import abc
import os
import shutil
import re
import numpy as np
from . import cleaners as cleaners
from .emotion_types import emotion_types
from .lang_symbols import get_language_symbols
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception("Unknown cleaner: %s" % name)
text = cleaner(text)
return text
def get_fpdict(config):
    # emotion_neutral (F7) can be replaced with other emotion/speaker types from the corresponding lists in the config file.
default_sp = config["linguistic_unit"]["speaker_list"].split(",")[0]
en_sy = f"{{ge$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{en_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}" # NOQA: E501
a_sy = f"{{ga$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{a_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}" # NOQA: E501
e_sy = f"{{ge$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{e_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}" # NOQA: E501
ling_unit = KanTtsLinguisticUnit(config)
en_lings = ling_unit.encode_symbol_sequence(en_sy)
a_lings = ling_unit.encode_symbol_sequence(a_sy)
e_lings = ling_unit.encode_symbol_sequence(e_sy)
en_ling = np.stack(en_lings, axis=1)[:3, :4]
a_ling = np.stack(a_lings, axis=1)[:3, :4]
e_ling = np.stack(e_lings, axis=1)[:3, :4]
fp_dict = {1: en_ling, 2: a_ling, 3: e_ling}
return fp_dict
class LinguisticBaseUnit(abc.ABC):
def set_config_params(self, config_params):
self.config_params = config_params
def save(self, config, config_name, path):
"""Save config to file"""
t_path = os.path.join(path, config_name)
if config != t_path:
os.makedirs(path, exist_ok=True)
shutil.copyfile(config, os.path.join(path, config_name))
class KanTtsLinguisticUnit(LinguisticBaseUnit):
def __init__(self, config):
super(KanTtsLinguisticUnit, self).__init__()
# special symbol
self._pad = "_"
self._eos = "~"
self._mask = "@[MASK]"
self.unit_config = config["linguistic_unit"]
self.lang_type = self.unit_config.get("language", "PinYin")
(
self.lang_phones,
self.lang_tones,
self.lang_syllable_flags,
self.lang_word_segments,
) = get_language_symbols(self.lang_type)
self._cleaner_names = [
x.strip() for x in self.unit_config["cleaners"].split(",")
]
_lfeat_type_list = self.unit_config["lfeat_type_list"].strip().split(",")
self._lfeat_type_list = _lfeat_type_list
self.fp_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("FP", False)
if self.fp_enable:
self._fpadd_lfeat_type_list = [_lfeat_type_list[0], _lfeat_type_list[4]]
self.build()
def using_byte(self):
return "byte_index" in self._lfeat_type_list
def get_unit_size(self):
ling_unit_size = {}
if self.using_byte():
ling_unit_size["byte_index"] = len(self.byte_index)
else:
ling_unit_size["sy"] = len(self.sy)
ling_unit_size["tone"] = len(self.tone)
ling_unit_size["syllable_flag"] = len(self.syllable_flag)
ling_unit_size["word_segment"] = len(self.word_segment)
if "emo_category" in self._lfeat_type_list:
ling_unit_size["emotion"] = len(self.emo_category)
if "speaker_category" in self._lfeat_type_list:
ling_unit_size["speaker"] = len(self.speaker)
return ling_unit_size
def build(self):
self._sub_unit_dim = {}
self._sub_unit_pad = {}
if self.using_byte():
# Export all byte indices:
self.byte_index = ["@" + str(idx) for idx in range(256)] + [
self._pad,
self._eos,
self._mask,
]
self._byte_index_to_id = {s: i for i, s in enumerate(self.byte_index)}
self._id_to_byte_index = {i: s for i, s in enumerate(self.byte_index)}
self._sub_unit_dim["byte_index"] = len(self.byte_index)
self._sub_unit_pad["byte_index"] = self._byte_index_to_id["_"]
else:
# sy sub-unit
_characters = ""
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
# _arpabet = ['@' + s for s in cmudict.valid_symbols]
_arpabet = ["@" + s for s in self.lang_phones]
# Export all symbols:
self.sy = list(_characters) + _arpabet + [self._pad, self._eos, self._mask]
self._sy_to_id = {s: i for i, s in enumerate(self.sy)}
self._id_to_sy = {i: s for i, s in enumerate(self.sy)}
self._sub_unit_dim["sy"] = len(self.sy)
self._sub_unit_pad["sy"] = self._sy_to_id["_"]
# tone sub-unit
_characters = ""
# Export all tones:
self.tone = (
list(_characters) + self.lang_tones + [self._pad, self._eos, self._mask]
)
self._tone_to_id = {s: i for i, s in enumerate(self.tone)}
self._id_to_tone = {i: s for i, s in enumerate(self.tone)}
self._sub_unit_dim["tone"] = len(self.tone)
self._sub_unit_pad["tone"] = self._tone_to_id["_"]
# syllable flag sub-unit
_characters = ""
# Export all syllable_flags:
self.syllable_flag = (
list(_characters)
+ self.lang_syllable_flags
+ [self._pad, self._eos, self._mask]
)
self._syllable_flag_to_id = {s: i for i, s in enumerate(self.syllable_flag)}
self._id_to_syllable_flag = {i: s for i, s in enumerate(self.syllable_flag)}
self._sub_unit_dim["syllable_flag"] = len(self.syllable_flag)
self._sub_unit_pad["syllable_flag"] = self._syllable_flag_to_id["_"]
# word segment sub-unit
_characters = ""
# Export all syllable_flags:
self.word_segment = (
list(_characters)
+ self.lang_word_segments
+ [self._pad, self._eos, self._mask]
)
self._word_segment_to_id = {s: i for i, s in enumerate(self.word_segment)}
self._id_to_word_segment = {i: s for i, s in enumerate(self.word_segment)}
self._sub_unit_dim["word_segment"] = len(self.word_segment)
self._sub_unit_pad["word_segment"] = self._word_segment_to_id["_"]
if "emo_category" in self._lfeat_type_list:
# emotion category sub-unit
_characters = ""
self.emo_category = (
list(_characters) + emotion_types + [self._pad, self._eos, self._mask]
)
self._emo_category_to_id = {s: i for i, s in enumerate(self.emo_category)}
self._id_to_emo_category = {i: s for i, s in enumerate(self.emo_category)}
self._sub_unit_dim["emo_category"] = len(self.emo_category)
self._sub_unit_pad["emo_category"] = self._emo_category_to_id["_"]
if "speaker_category" in self._lfeat_type_list:
# speaker category sub-unit
_characters = ""
_ch_speakers = self.unit_config["speaker_list"].strip().split(",")
# Export all syllable_flags:
self.speaker = (
list(_characters) + _ch_speakers + [self._pad, self._eos, self._mask]
)
self._speaker_to_id = {s: i for i, s in enumerate(self.speaker)}
self._id_to_speaker = {i: s for i, s in enumerate(self.speaker)}
self._sub_unit_dim["speaker_category"] = len(self._speaker_to_id)
self._sub_unit_pad["speaker_category"] = self._speaker_to_id["_"]
def encode_symbol_sequence(self, lfeat_symbol):
lfeat_symbol = lfeat_symbol.strip().split(" ")
lfeat_symbol_separate = [""] * int(len(self._lfeat_type_list))
for this_lfeat_symbol in lfeat_symbol:
this_lfeat_symbol = this_lfeat_symbol.strip("{").strip("}").split("$")
# if len(this_lfeat_symbol) > len(self._lfeat_type_list):
# raise Exception(
# 'Length of this_lfeat_symbol in training data is longer than the length of lfeat_type_list, '\
# + str( len(this_lfeat_symbol))\
# + ' VS. '\
# + str(len(self._lfeat_type_list)))
index = 0
while index < len(lfeat_symbol_separate):
lfeat_symbol_separate[index] = (
lfeat_symbol_separate[index] + this_lfeat_symbol[index] + " "
)
index = index + 1
input_and_label_data = []
index = 0
while index < len(self._lfeat_type_list):
sequence = self.encode_sub_unit(
lfeat_symbol_separate[index].strip(), self._lfeat_type_list[index]
)
sequence_array = np.asarray(sequence, dtype=np.int32)
input_and_label_data.append(sequence_array)
index = index + 1
# # lfeat_type = 'emo_category'
# input_and_label_data.append(lfeat_symbol_separate[index].strip())
# index = index + 1
#
# # lfeat_type = 'speaker'
# input_and_label_data.append(lfeat_symbol_separate[index].strip())
return input_and_label_data
def decode_symbol_sequence(self, sequence):
result = []
for i, lfeat_type in enumerate(self._lfeat_type_list):
s = ""
sequence_item = sequence[i].tolist()
if lfeat_type == "sy":
s = self.decode_sy(sequence_item)
elif lfeat_type == "byte_index":
s = self.decode_byte_index(sequence_item)
elif lfeat_type == "tone":
s = self.decode_tone(sequence_item)
elif lfeat_type == "syllable_flag":
s = self.decode_syllable_flag(sequence_item)
elif lfeat_type == "word_segment":
s = self.decode_word_segment(sequence_item)
elif lfeat_type == "emo_category":
s = self.decode_emo_category(sequence_item)
elif lfeat_type == "speaker_category":
s = self.decode_speaker_category(sequence_item)
else:
raise Exception("Unknown lfeat type: %s" % lfeat_type)
result.append("%s:%s" % (lfeat_type, s))
return result
def encode_sub_unit(self, this_lfeat_symbol, lfeat_type):
sequence = []
if lfeat_type == "sy":
this_lfeat_symbol = this_lfeat_symbol.strip().split(" ")
this_lfeat_symbol_format = ""
index = 0
while index < len(this_lfeat_symbol):
this_lfeat_symbol_format = (
this_lfeat_symbol_format
+ "{"
+ this_lfeat_symbol[index]
+ "}"
+ " "
)
index = index + 1
sequence = self.encode_text(this_lfeat_symbol_format, self._cleaner_names)
elif lfeat_type == "byte_index":
sequence = self.encode_byte_index(this_lfeat_symbol)
elif lfeat_type == "tone":
sequence = self.encode_tone(this_lfeat_symbol)
elif lfeat_type == "syllable_flag":
sequence = self.encode_syllable_flag(this_lfeat_symbol)
elif lfeat_type == "word_segment":
sequence = self.encode_word_segment(this_lfeat_symbol)
elif lfeat_type == "emo_category":
sequence = self.encode_emo_category(this_lfeat_symbol)
elif lfeat_type == "speaker_category":
sequence = self.encode_speaker_category(this_lfeat_symbol)
else:
raise Exception("Unknown lfeat type: %s" % lfeat_type)
return sequence
def encode_text(self, text, cleaner_names):
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += self.encode_sy(_clean_text(text, cleaner_names))
break
sequence += self.encode_sy(_clean_text(m.group(1), cleaner_names))
            sequence += self.encode_arpabet(m.group(2))
text = m.group(3)
# Append EOS token
sequence.append(self._sy_to_id["~"])
return sequence
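    # Input format sketch (the phone names are assumed to exist in the
    # phoneset): symbols wrapped in curly braces are encoded as "@"-prefixed
    # phones, plain text runs through the cleaners, and an EOS id ("~") is
    # appended:
    #
    #   self.encode_text("{ge} {en_c}", ["basic_cleaners"])
    #   # -> [id("@ge"), id("@en_c"), id("~")]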
def encode_sy(self, sy):
return [self._sy_to_id[s] for s in sy if self.should_keep_sy(s)]
def decode_sy(self, id):
s = self._id_to_sy[id]
if len(s) > 1 and s[0] == "@":
s = s[1:]
return s
def should_keep_sy(self, s):
return s in self._sy_to_id and s != "_" and s != "~"
    def encode_arpabet(self, text):
        return self.encode_sy(["@" + s for s in text.split()])
def encode_byte_index(self, byte_index):
byte_indices = ["@" + s for s in byte_index.strip().split(" ")]
sequence = []
for this_byte_index in byte_indices:
sequence.append(self._byte_index_to_id[this_byte_index])
sequence.append(self._byte_index_to_id["~"])
return sequence
def decode_byte_index(self, id):
s = self._id_to_byte_index[id]
if len(s) > 1 and s[0] == "@":
s = s[1:]
return s
def encode_tone(self, tone):
tones = tone.strip().split(" ")
sequence = []
for this_tone in tones:
sequence.append(self._tone_to_id[this_tone])
sequence.append(self._tone_to_id["~"])
return sequence
def decode_tone(self, id):
return self._id_to_tone[id]
def encode_syllable_flag(self, syllable_flag):
syllable_flags = syllable_flag.strip().split(" ")
sequence = []
for this_syllable_flag in syllable_flags:
sequence.append(self._syllable_flag_to_id[this_syllable_flag])
sequence.append(self._syllable_flag_to_id["~"])
return sequence
def decode_syllable_flag(self, id):
return self._id_to_syllable_flag[id]
def encode_word_segment(self, word_segment):
word_segments = word_segment.strip().split(" ")
sequence = []
for this_word_segment in word_segments:
sequence.append(self._word_segment_to_id[this_word_segment])
sequence.append(self._word_segment_to_id["~"])
return sequence
def decode_word_segment(self, id):
return self._id_to_word_segment[id]
def encode_emo_category(self, emo_type):
emo_categories = emo_type.strip().split(" ")
sequence = []
for this_category in emo_categories:
sequence.append(self._emo_category_to_id[this_category])
sequence.append(self._emo_category_to_id["~"])
return sequence
def decode_emo_category(self, id):
return self._id_to_emo_category[id]
def encode_speaker_category(self, speaker):
speakers = speaker.strip().split(" ")
sequence = []
for this_speaker in speakers:
sequence.append(self._speaker_to_id[this_speaker])
sequence.append(self._speaker_to_id["~"])
return sequence
def decode_speaker_category(self, id):
return self._id_to_speaker[id]
import inflect
import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m):
return m.group(1).replace(",", "")
def _expand_decimal_point(m):
return m.group(1).replace(".", " point ")
def _expand_dollars(m):
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return "zero dollars"
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return "two thousand"
elif num > 2000 and num < 2010:
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(
num, andword="", zero="oh", group=2
).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword="")
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r"\1 pounds", text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
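# Worked examples (derived from the rules above):
#
#   normalize_numbers("in 2003")     # -> "in two thousand three"
#   normalize_numbers("1,000 feet")  # -> "one thousand feet"
#   normalize_numbers("3rd place")   # -> "third place"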
import logging
import subprocess
def logging_to_file(log_file):
logger = logging.getLogger()
handler = logging.FileHandler(log_file)
formatter = logging.Formatter(
"%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
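# Usage sketch: logging_to_file("train.log") attaches a file handler to the
# root logger, so subsequent logging.info(...) calls are mirrored to train.log.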
def get_git_revision_short_hash():
return (
subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
.decode("ascii")
.strip()
)
def get_git_revision_hash():
return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()