Commit ee10550a authored by liugh5's avatar liugh5
Browse files

Initial commit

parents
Pipeline #790 canceled with stages
import logging
import os
import sys
import argparse
import yaml
import time
import zipfile
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
try:
from kantts.datasets.dataset import BERT_Text_Dataset
from kantts.utils.log import logging_to_file, get_git_revision_hash
from kantts.utils.ling_unit import text_to_mit_symbols as text_to_symbols
except ImportError:
raise ImportError("Please install kantts.")
# Configure root logging once at import time: timestamped records including
# the emitting file and line number, at INFO level.
logging.basicConfig(
    format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
def gen_metafile(
    output_dir,
    split_ratio=0.98,
):
    """Split raw_metafile.txt into BERT train/valid lists, if not done yet.

    Args:
        output_dir (str): directory holding raw_metafile.txt; the split lists
            bert_train.lst / bert_valid.lst are written there as well.
        split_ratio (float): fraction of entries assigned to the train list.
    """
    raw_metafile = os.path.join(output_dir, "raw_metafile.txt")
    train_lst = os.path.join(output_dir, "bert_train.lst")
    valid_lst = os.path.join(output_dir, "bert_valid.lst")
    # Regenerate only when either split list is missing.
    if not (os.path.exists(train_lst) and os.path.exists(valid_lst)):
        BERT_Text_Dataset.gen_metafile(raw_metafile, output_dir, split_ratio)
        logging.info("BERT Text metafile generated.")
# TODO: Zh-CN as default
def process_mit_style_data(
    text_file,
    resources_zip_file,
    output_dir,
):
    """Convert a MIT-style text corpus into symbol metafiles for BERT training.

    Args:
        text_file (str): input text file, one utterance per line.
        resources_zip_file (str): zip archive of linguistic resources; it is
            extracted next to itself into a "resource" directory.
        output_dir (str): directory receiving raw_metafile.txt and the
            train/valid split lists.
    """
    os.makedirs(output_dir, exist_ok=True)
    logging_to_file(os.path.join(output_dir, "data_process_stdout.log"))

    resource_root_dir = os.path.dirname(resources_zip_file)
    resource_dir = os.path.join(resource_root_dir, "resource")
    # Extract the archive only once; later runs reuse the directory.
    if not os.path.exists(resource_dir):
        logging.info("Extracting resources...")
        with zipfile.ZipFile(resources_zip_file, "r") as zip_ref:
            zip_ref.extractall(resource_root_dir)

    # FIX: read/write text with an explicit UTF-8 encoding instead of the
    # platform's locale default — the corpus is Zh-CN text and would be
    # mis-decoded on non-UTF-8 locales.
    with open(text_file, "r", encoding="utf-8") as text_data:
        texts = text_data.readlines()

    logging.info("Converting text to symbols...")
    # NOTE(review): "F7" appears to be a hard-coded voice/speaker id — confirm.
    symbols_lst = text_to_symbols(texts, resource_dir, "F7")

    symbols_file = os.path.join(output_dir, "raw_metafile.txt")
    with open(symbols_file, "w", encoding="utf-8") as symbol_data:
        for symbol in symbols_lst:
            symbol_data.write(symbol)
    logging.info("Processing done.")

    # Generate BERT Text metafile
    # TODO: train/valid ratio setting
    gen_metafile(output_dir)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Dataset preprocessor")
    # All three paths are mandatory.
    for flag in ("--text_file", "--resources_zip_file", "--output_dir"):
        parser.add_argument(flag, type=str, required=True)
    cli_args = parser.parse_args()
    process_mit_style_data(
        cli_args.text_file,
        cli_args.resources_zip_file,
        cli_args.output_dir,
    )
import torch
import torch.nn.functional as F
from kantts.utils.audio_torch import stft, MelSpectrogram
from kantts.models.utils import get_mask_from_lengths
class MelReconLoss(torch.nn.Module):
    """Masked mel-spectrogram reconstruction loss.

    Averages the element-wise error between target mels and decoder outputs
    (and optionally postnet outputs) over non-padded frames only.
    """

    def __init__(self, loss_type="mae"):
        super(MelReconLoss, self).__init__()
        self.loss_type = loss_type
        criterion_map = {"mae": torch.nn.L1Loss, "mse": torch.nn.MSELoss}
        if loss_type not in criterion_map:
            raise ValueError("Unknown loss type: {}".format(loss_type))
        self.criterion = criterion_map[loss_type](reduction="none")

    def forward(self, output_lengths, mel_targets, dec_outputs, postnet_outputs=None):
        """Return (decoder_loss, postnet_loss); postnet_loss is 0.0 if absent."""
        # True on valid (non-padded) frames after inversion.
        valid_mask = ~get_mask_from_lengths(
            output_lengths, max_len=mel_targets.size(1)
        )
        denom = valid_mask.sum() * mel_targets.size(-1)
        dec_loss = (
            torch.sum(
                self.criterion(mel_targets, dec_outputs) * valid_mask.unsqueeze(-1)
            )
            / denom
        )
        if postnet_outputs is None:
            postnet_loss = 0.0
        else:
            postnet_loss = (
                torch.sum(
                    self.criterion(mel_targets, postnet_outputs)
                    * valid_mask.unsqueeze(-1)
                )
                / denom
            )
        return dec_loss, postnet_loss
class ProsodyReconLoss(torch.nn.Module):
    """Masked reconstruction loss for the duration/pitch/energy predictors."""

    def __init__(self, loss_type="mae"):
        super(ProsodyReconLoss, self).__init__()
        self.loss_type = loss_type
        criterion_map = {"mae": torch.nn.L1Loss, "mse": torch.nn.MSELoss}
        if loss_type not in criterion_map:
            raise ValueError("Unknown loss type: {}".format(loss_type))
        self.criterion = criterion_map[loss_type](reduction="none")

    def forward(
        self,
        input_lengths,
        duration_targets,
        pitch_targets,
        energy_targets,
        log_duration_predictions,
        pitch_predictions,
        energy_predictions,
    ):
        """Return (duration_loss, pitch_loss, energy_loss) over valid tokens."""
        valid_mask = ~get_mask_from_lengths(
            input_lengths, max_len=duration_targets.size(1)
        )
        n_valid = valid_mask.sum()

        def masked_mean(err):
            # Zero out padded positions, then average over the valid ones.
            return torch.sum(err * valid_mask) / n_valid

        # Durations are compared in the log domain; +1 guards log(0).
        dur_loss = masked_mean(
            self.criterion(
                torch.log(duration_targets.float() + 1), log_duration_predictions
            )
        )
        pitch_loss = masked_mean(self.criterion(pitch_targets, pitch_predictions))
        energy_loss = masked_mean(self.criterion(energy_targets, energy_predictions))
        return dur_loss, pitch_loss, energy_loss
class FpCELoss(torch.nn.Module):
    """Masked, class-weighted cross-entropy loss for filled-pause prediction.

    Args:
        loss_type (str): kept for API symmetry with the other losses; only
            cross-entropy is implemented.
        weight (sequence of float): per-class weights for the CE criterion.
    """

    def __init__(self, loss_type="ce", weight=(1, 4, 4, 8)):
        super(FpCELoss, self).__init__()
        self.loss_type = loss_type
        # FIX: the default used to be a mutable list (shared across calls) and
        # the tensor was moved with a hard-coded .cuda(), which crashes on
        # CPU-only hosts.  CrossEntropyLoss registers `weight` as a buffer,
        # so criterion_builder's .to(device) places it on the right device.
        weight_ce = torch.FloatTensor(list(weight))
        self.criterion = torch.nn.CrossEntropyLoss(weight=weight_ce, reduction="none")

    def forward(self, input_lengths, fp_pd, fp_label):
        """Return the CE loss averaged over valid (non-padded) positions.

        Args:
            input_lengths (LongTensor): valid length per batch item.
            fp_pd (Tensor): logits of shape (B, T, num_classes).
            fp_label (LongTensor): class targets of shape (B, T).
        """
        input_masks = ~get_mask_from_lengths(input_lengths, max_len=fp_label.size(1))
        valid_inputs = input_masks.sum()
        # CrossEntropyLoss expects (B, C, T), hence the transpose.
        fp_loss = (
            torch.sum(self.criterion(fp_pd.transpose(2, 1), fp_label) * input_masks)
            / valid_inputs
        )
        return fp_loss
class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators=True,
        loss_type="mse",
    ):
        """Initialize GeneratorAversarialLoss module."""
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        self.criterion = self._mse_loss if loss_type == "mse" else self._hinge_loss

    def forward(self, outputs):
        """Calculate generator adversarial loss.

        Args:
            outputs (Tensor or list): Discriminator output(s).

        Returns:
            Tensor: Generator adversarial loss value.
        """
        if not isinstance(outputs, (tuple, list)):
            return self.criterion(outputs)
        total = 0.0
        for single_out in outputs:
            total += self.criterion(single_out)
        if self.average_by_discriminators:
            total /= len(outputs)
        return total

    def _mse_loss(self, x):
        # Push discriminator scores toward 1 ("real").
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x):
        return -x.mean()
class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Discriminator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators=True,
        loss_type="mse",
    ):
        """Initialize DiscriminatorAversarialLoss module."""
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(self, outputs_hat, outputs):
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Tensor or list): Discriminator output(s) computed
                from generated audio.
            outputs (Tensor or list): Discriminator output(s) computed from
                ground-truth audio.

        Returns:
            Tensor: Discriminator real loss value.
            Tensor: Discriminator fake loss value.
        """
        if not isinstance(outputs, (tuple, list)):
            return self.real_criterion(outputs), self.fake_criterion(outputs_hat)
        real_loss = 0.0
        fake_loss = 0.0
        for fake_out, real_out in zip(outputs_hat, outputs):
            if isinstance(fake_out, (tuple, list)):
                # NOTE(kan-bayashi): case including feature maps — the score
                # is the last element.
                fake_out = fake_out[-1]
                real_out = real_out[-1]
            real_loss += self.real_criterion(real_out)
            fake_loss += self.fake_criterion(fake_out)
        if self.average_by_discriminators:
            n_pairs = min(len(outputs_hat), len(outputs))
            real_loss /= n_pairs
            fake_loss /= n_pairs
        return real_loss, fake_loss

    def _mse_real_loss(self, x):
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x):
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x):
        return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))

    def _hinge_fake_loss(self, x):
        return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
class FeatureMatchLoss(torch.nn.Module):
    """Feature matching loss module."""

    def __init__(
        self,
        average_by_layers=True,
        average_by_discriminators=True,
    ):
        """Initialize FeatureMatchLoss module."""
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators

    def forward(self, feats_hat, feats):
        """Calculate feature matching loss.

        Args:
            feats_hat (list): Per-discriminator lists of feature maps from
                generated audio.
            feats (list): Per-discriminator lists of feature maps from
                ground-truth audio.

        Returns:
            Tensor: Feature matching loss value.
        """
        total = 0.0
        n_discs = 0
        for disc_feats_hat, disc_feats in zip(feats_hat, feats):
            per_disc = 0.0
            n_layers = 0
            for layer_hat, layer_ref in zip(disc_feats_hat, disc_feats):
                # Ground-truth features are detached: only the generator learns.
                per_disc += F.l1_loss(layer_hat, layer_ref.detach())
                n_layers += 1
            if self.average_by_layers:
                per_disc /= n_layers
            total += per_disc
            n_discs += 1
        if self.average_by_discriminators:
            total /= n_discs
        return total
class MelSpectrogramLoss(torch.nn.Module):
    """Mel-spectrogram loss."""

    def __init__(
        self,
        fs=22050,
        fft_size=1024,
        hop_size=256,
        win_length=None,
        window="hann",
        num_mels=80,
        fmin=80,
        fmax=7600,
        center=True,
        normalized=False,
        onesided=True,
        eps=1e-10,
        log_base=10.0,
    ):
        """Initialize Mel-spectrogram loss with the given analysis settings."""
        super().__init__()
        # Every analysis parameter is forwarded verbatim to the extractor.
        self.mel_spectrogram = MelSpectrogram(
            fs=fs,
            fft_size=fft_size,
            hop_size=hop_size,
            win_length=win_length,
            window=window,
            num_mels=num_mels,
            fmin=fmin,
            fmax=fmax,
            center=center,
            normalized=normalized,
            onesided=onesided,
            eps=eps,
            log_base=log_base,
        )

    def forward(self, y_hat, y):
        """Return the L1 distance between mel spectrograms of y_hat and y.

        Args:
            y_hat (Tensor): Generated single tensor (B, 1, T).
            y (Tensor): Groundtruth single tensor (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram loss value.
        """
        generated_mel = self.mel_spectrogram(y_hat)
        reference_mel = self.mel_spectrogram(y)
        return F.l1_loss(generated_mel, reference_mel)
class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super(SpectralConvergenceLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Return ||y_mag - x_mag||_F / ||y_mag||_F.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal
                (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Spectral convergence loss value.
        """
        residual_norm = torch.norm(y_mag - x_mag, p="fro")
        reference_norm = torch.norm(y_mag, p="fro")
        return residual_norm / reference_norm
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initialize log STFT magnitude loss module."""
        super(LogSTFTMagnitudeLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Return the mean L1 distance between log-magnitude spectrograms.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal
                (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        log_ref = torch.log(y_mag)
        log_pred = torch.log(x_mag)
        return F.l1_loss(log_ref, log_pred)
class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(
        self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"
    ):
        """Initialize STFT loss module for one analysis resolution."""
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.spectral_convergence_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        # NOTE(kan-bayashi): Use register_buffer to fix #223
        window_fn = getattr(torch, window)
        self.register_buffer("window", window_fn(win_length))

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        stft_args = (self.fft_size, self.shift_size, self.win_length, self.window)
        x_mag = stft(x, *stft_args)
        y_mag = stft(y, *stft_args)
        return (
            self.spectral_convergence_loss(x_mag, y_mag),
            self.log_stft_magnitude_loss(x_mag, y_mag),
        )
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(
        self,
        fft_sizes=[1024, 2048, 512],
        hop_sizes=[120, 240, 50],
        win_lengths=[600, 1200, 240],
        window="hann_window",
    ):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.
        """
        super(MultiResolutionSTFTLoss, self).__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        # One STFTLoss per (fft, hop, win) resolution.
        self.stft_losses = torch.nn.ModuleList(
            STFTLoss(n_fft, hop, win, window)
            for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths)
        )

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T) or (B, #subband, T).
            y (Tensor): Groundtruth signal (B, T) or (B, #subband, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        if x.dim() == 3:
            # (B, C, T) -> (B x C, T): treat each subband as a batch item.
            x = x.view(-1, x.size(2))
            y = y.view(-1, y.size(2))
        sc_total = 0.0
        mag_total = 0.0
        for resolution_loss in self.stft_losses:
            sc_part, mag_part = resolution_loss(x, y)
            sc_total += sc_part
            mag_total += mag_part
        n_res = len(self.stft_losses)
        return sc_total / n_res, mag_total / n_res
class SeqCELoss(torch.nn.Module):
    """Masked sequence cross-entropy loss with prediction-error rate."""

    def __init__(self, loss_type="ce"):
        super(SeqCELoss, self).__init__()
        self.loss_type = loss_type
        self.criterion = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, logits, targets, masks):
        """Return (loss, err) averaged over positions where masks == 1.

        Args:
            logits (Tensor): (..., num_classes) scores.
            targets (LongTensor): class indices, same leading shape as logits.
            masks (Tensor): 1 for valid positions, 0 for padding.
        """
        num_classes = logits.size(-1)
        flat_logits = logits.contiguous().view(-1, num_classes)
        flat_targets = targets.contiguous().view(-1)
        flat_masks = masks.contiguous().view(-1)
        valid = flat_masks.sum()
        per_token_loss = self.criterion(flat_logits, flat_targets)
        loss = (per_token_loss * flat_masks).sum() / valid
        # Error rate: fraction of valid positions whose argmax is wrong.
        predictions = flat_logits.argmax(dim=-1)
        err = torch.sum((predictions != flat_targets) * flat_masks) / valid
        return loss, err
class AttentionBinarizationLoss(torch.nn.Module):
    """KL-style loss pushing soft attention toward a hard alignment.

    The loss weight ramps linearly from 0 to 1 over ``warmup_epoch`` epochs
    starting at ``start_epoch`` (0 before, capped at 1 after).
    """

    def __init__(self, start_epoch=0, warmup_epoch=100):
        super(AttentionBinarizationLoss, self).__init__()
        self.start_epoch = start_epoch
        self.warmup_epoch = warmup_epoch

    def forward(self, epoch, hard_attention, soft_attention, eps=1e-12):
        # Only look at soft probabilities where the hard alignment is active;
        # the clamp keeps log() finite.
        selected = torch.clamp(soft_attention[hard_attention == 1], min=eps)
        kl_loss = -torch.log(selected).sum() / hard_attention.sum()
        if epoch < self.start_epoch:
            ramp = 0
        else:
            ramp = min(1.0, (epoch - self.start_epoch) / self.warmup_epoch)
        return kl_loss * ramp
class AttentionCTCLoss(torch.nn.Module):
    """CTC-based alignment loss on attention log-probabilities.

    Treats each batch item's (query, key) attention matrix as a CTC posterior
    over key indices and forces a monotonic alignment 1..key_len.
    """

    def __init__(self, blank_logprob=-1):
        super(AttentionCTCLoss, self).__init__()
        self.log_softmax = torch.nn.LogSoftmax(dim=3)
        self.blank_logprob = blank_logprob
        self.CTCLoss = torch.nn.CTCLoss(zero_infinity=True)

    def forward(self, attn_logprob, in_lens, out_lens):
        """Return the batch-averaged CTC alignment cost.

        Args:
            attn_logprob (Tensor): attention log-probs (B, 1, T_out, T_in).
            in_lens (LongTensor): key (text) lengths per batch item.
            out_lens (LongTensor): query (frame) lengths per batch item.
        """
        batch_size = attn_logprob.shape[0]
        # Prepend a "blank" column so CTC class 0 exists.
        padded = F.pad(
            input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0), value=self.blank_logprob
        )
        total_cost = 0.0
        for b in range(batch_size):
            key_len = in_lens[b]
            query_len = out_lens[b]
            # Target is the strictly monotonic sequence 1..key_len.
            target = torch.arange(1, key_len + 1).unsqueeze(0)
            # (1, T_out, T_in+1) -> (T_out, 1, T_in+1), trimmed to valid region.
            logprob = padded[b].permute(1, 0, 2)[:query_len, :, : key_len + 1]
            logprob = self.log_softmax(logprob[None])[0]
            total_cost += self.CTCLoss(
                logprob,
                target,
                input_lengths=out_lens[b : b + 1],
                target_lengths=in_lens[b : b + 1],
            )
        return total_cost / batch_size
# TODO: create a mapping for new loss functions
# Registry consumed by criterion_builder: maps the keys permitted in the
# config's "Loss" section to their loss classes.  "subband_stft_loss" reuses
# MultiResolutionSTFTLoss (the trainer feeds it PQMF subband signals).
loss_dict = {
    "generator_adv_loss": GeneratorAdversarialLoss,
    "discriminator_adv_loss": DiscriminatorAdversarialLoss,
    "stft_loss": MultiResolutionSTFTLoss,
    "mel_loss": MelSpectrogramLoss,
    "subband_stft_loss": MultiResolutionSTFTLoss,
    "feat_match_loss": FeatureMatchLoss,
    "MelReconLoss": MelReconLoss,
    "ProsodyReconLoss": ProsodyReconLoss,
    "SeqCELoss": SeqCELoss,
    "AttentionBinarizationLoss": AttentionBinarizationLoss,
    "AttentionCTCLoss": AttentionCTCLoss,
    "FpCELoss": FpCELoss,
}
def criterion_builder(config, device="cpu"):
    """Criterion builder.

    Instantiates every enabled loss from config["Loss"], moves it to
    ``device`` and attaches its configured weight as a ``.weights`` attribute.

    Args:
        config (dict): Config dictionary.
        device (str or torch.device): target device for the criteria.

    Returns:
        criterion (dict): Loss dictionary.

    Raises:
        NotImplementedError: if a configured loss name is not registered.
    """
    criterion = {}
    for name, loss_config in config["Loss"].items():
        if name not in loss_dict:
            raise NotImplementedError("{} is not implemented".format(name))
        if not loss_config["enable"]:
            continue
        module = loss_dict[name](**loss_config.get("params", {})).to(device)
        module.weights = loss_config.get("weights", 1.0)
        criterion[name] = module
    return criterion
from torch.optim.lr_scheduler import * # NOQA
from torch.optim.lr_scheduler import _LRScheduler # NOQA
"""Noam Scheduler."""
class FindLR(_LRScheduler):
    """LR-range-test scheduler: grows each LR geometrically toward max_lr.

    inspired by fast.ai @https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
    """

    def __init__(self, optimizer, max_steps, max_lr=10):
        self.max_steps = max_steps
        self.max_lr = max_lr
        super().__init__(optimizer)

    def get_lr(self):
        # Geometric interpolation: step 0 yields base_lr, step max_steps - 1
        # yields exactly max_lr.
        progress = self.last_epoch / (self.max_steps - 1)
        return [lr * (self.max_lr / lr) ** progress for lr in self.base_lrs]
class NoamLR(_LRScheduler):
"""
Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally
to the inverse square root of the step number, scaled by the inverse square root of the
dimensionality of the model. Time will tell if this is just madness or it's actually important.
Parameters
----------
warmup_steps: ``int``, required.
The number of steps to linearly increase the learning rate.
"""
def __init__(self, optimizer, warmup_steps):
self.warmup_steps = warmup_steps
super().__init__(optimizer)
def get_lr(self):
last_epoch = max(1, self.last_epoch)
scale = self.warmup_steps ** 0.5 * min(
last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5)
)
return [base_lr * scale for base_lr in self.base_lrs]
import os
import sys
import logging
import torch
from collections import defaultdict
from tensorboardX import SummaryWriter
from tqdm import tqdm
import soundfile as sf
import numpy as np
from kantts.utils.plot import plot_spectrogram, plot_alignment
def traversal_dict(d, func):
    """Depth-first walk of a nested dict, applying func(key, value) to every
    non-dict leaf.  Logs an error and returns if ``d`` is not a dict."""
    if not isinstance(d, dict):
        logging.error("Not a dict: {}".format(d))
        return
    for key, value in d.items():
        if not isinstance(value, dict):
            func(key, value)
        else:
            traversal_dict(value, func)
def distributed_init():
    """Read torchrun-style env vars and initialize distributed training.

    Returns:
        tuple: (distributed (bool), device (torch.device),
        local_rank (int), world_size (int)).
    """
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # NOTE(review): "RANK" is the *global* rank in torchrun terminology but is
    # used as the CUDA device index here — assumes single-node training.
    local_rank = int(os.environ.get("RANK", 0))
    distributed = world_size > 1
    device = torch.device("cuda", local_rank)
    if distributed:
        dist = torch.distributed
        dist.init_process_group(backend="nccl", init_method="env://")
        n_gpus = torch.cuda.device_count()
        logging.info(
            "Distributed training, global world size: {}, local world size: {}, global rank: {}, local rank: {}".format(
                world_size,
                n_gpus,
                dist.get_rank(),
                local_rank,
            )
        )
        logging.info("nccl backend: {}".format(dist.is_nccl_available()))
        logging.info("mpi backend: {}".format(dist.is_mpi_available()))
        logging.info(
            "[{}] rank = {}, world_size = {}, n_gpus = {}, device_ids = {}".format(
                os.getpid(),
                dist.get_rank(),
                dist.get_world_size(),
                n_gpus,
                list(range(n_gpus)),
            )
        )
    return distributed, device, local_rank, world_size
class Trainer(object):
    """Generic training-loop driver: iterates epochs/steps, logs to
    tensorboard, saves checkpoints and runs periodic evaluation.

    ``model``/``optimizer``/``scheduler``/``criterion`` may be single objects
    or dicts of objects (see set_model_state / load_checkpoint).  Subclasses
    implement train_step()/eval_step()/genearete_and_save_intermediate_result()
    for a concrete model.
    """

    def __init__(
        self,
        config,
        model,
        optimizer,
        scheduler,
        criterion,
        device,
        sampler,
        train_loader,
        valid_loader,
        max_epochs=None,
        max_steps=None,
        save_dir=None,
        save_interval=1,
        valid_interval=1,
        log_interval=10,
        grad_clip=None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.device = device
        self.sampler = sampler
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.max_epochs = max_epochs
        # Global step counter starts at 1 so "steps % interval" fires on
        # exact multiples of the intervals.
        self.steps = 1
        self.epoch = 0
        self.save_dir = save_dir
        self.save_interval = save_interval
        self.valid_interval = valid_interval
        self.log_interval = log_interval
        self.grad_clip = grad_clip
        # Running loss sums, flushed whenever a log/eval report is emitted.
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)
        self.config = config
        self.distributed = self.config.get("distributed", False)
        self.rank = self.config.get("rank", 0)
        self.log_dir = os.path.join(save_dir, "log")
        self.ckpt_dir = os.path.join(save_dir, "ckpt")
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        # None means "unbounded": sys.maxsize keeps the comparison in
        # check_stop_training from ever triggering on that bound.
        if max_epochs is None:
            self.max_epochs = sys.maxsize
        else:
            self.max_epochs = int(max_epochs)
        if max_steps is None:
            self.max_steps = sys.maxsize
        else:
            self.max_steps = int(max_steps)
        self.finish_training = False

    def set_model_state(self, state="train"):
        """Put the model (or every model in a dict of models) into the given
        train/eval mode."""
        if state == "train":
            if isinstance(self.model, dict):
                for key in self.model.keys():
                    self.model[key].train()
            else:
                self.model.train()
        elif state == "eval":
            if isinstance(self.model, dict):
                for key in self.model.keys():
                    self.model[key].eval()
            else:
                self.model.eval()
        else:
            raise ValueError("state must be either 'train' or 'eval'.")

    def write_to_tensorboard(self, loss):
        """Write to tensorboard."""
        # One scalar per loss key, indexed by the current global step.
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    # FIXME: an example for simple feedforward model
    def save_checkpoint(self, checkpoint_path):
        """Save optimizer/scheduler/model state plus the global step counter.

        NOTE(review): assumes single (non-dict) optimizer/scheduler/model;
        subclasses holding dicts need their own override — confirm.
        """
        state_dict = {
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "steps": self.steps,
            "model": self.model.state_dict(),
        }
        # TODO: distributed training
        # NOTE(review): this tests the checkpoint *file* path but creates its
        # parent directory (without exist_ok) — verify the intended behavior.
        if not os.path.exists(checkpoint_path):
            os.makedirs(os.path.dirname(checkpoint_path))
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(
        self, checkpoint_path, restore_training_state=False, strict=True
    ):
        """Load model weights; optionally restore optimizer/scheduler/steps.

        NOTE(review): restoring indexes optimizer/scheduler with the key
        "KanTtsSAMBERT" (dict-style), which is inconsistent with
        save_checkpoint above — confirm which layout is expected.
        """
        state_dict = torch.load(checkpoint_path)
        self.model.load_state_dict(state_dict["model"], strict=strict)
        if restore_training_state:
            if "optimizer" in state_dict:
                self.optimizer["KanTtsSAMBERT"].load_state_dict(state_dict["optimizer"])
            if "scheduler" in state_dict:
                self.scheduler["KanTtsSAMBERT"].load_state_dict(state_dict["scheduler"])
            if "steps" in state_dict:
                self.steps = state_dict["steps"]

    # TODO
    def check_save_interval(self):
        # Save a checkpoint every `save_interval` global steps.
        if self.ckpt_dir is not None and (self.steps) % self.save_interval == 0:
            self.save_checkpoint(
                os.path.join(self.ckpt_dir, "checkpoint_{}.pth".format(self.steps))
            )
            logging.info("Checkpoint saved at step {}".format(self.steps))

    def check_log_interval(self):
        # Emit averaged training losses every `log_interval` steps, then
        # reset the accumulators.
        if self.writer is not None and (self.steps) % self.log_interval == 0:
            for key in self.total_train_loss.keys():
                # NOTE(review): averages over config["log_interval_steps"],
                # not self.log_interval — confirm the two values always match.
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logging.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self.write_to_tensorboard(self.total_train_loss)
            self.total_train_loss = defaultdict(float)

            def log_learning_rate(key, sche):
                # Leaf callback for traversal_dict: log each scheduler's LR.
                logging.info("{} learning rate: {:.6f}".format(key, sche.get_lr()[0]))
                self.write_to_tensorboard({"{}_lr".format(key): sche.get_lr()[0]})

            traversal_dict(self.scheduler, log_learning_rate)

    def check_eval_interval(self):
        # Run a full evaluation pass every `valid_interval` steps.
        if self.valid_interval > 0 and (self.steps) % self.valid_interval == 0:
            self.eval_epoch()

    def check_stop_training(self):
        # Either bound (global steps or epochs) ends training.
        if self.steps >= self.max_steps or self.epoch >= self.max_epochs:
            self.finish_training = True

    def train(self):
        """Run epochs until a stop condition set by check_stop_training."""
        self.set_model_state("train")
        while True:
            self.train_epoch()
            self.epoch += 1
            self.check_stop_training()
            if self.finish_training:
                break

    def train_epoch(self):
        """One pass over train_loader; rank 0 also evaluates/saves/logs."""
        for batch in tqdm(self.train_loader):
            self.train_step(batch)
            if self.rank == 0:
                self.check_eval_interval()
                self.check_save_interval()
                self.check_log_interval()
            self.steps += 1
            self.check_stop_training()
            if self.finish_training:
                break
        logging.info("Epoch {} finished".format(self.epoch))
        if self.distributed:
            # Reshuffle the distributed sampler's shards for the next epoch.
            self.sampler["train"].set_epoch(self.epoch)

    # TODO: implement train_step() for specific model
    def train_step(self, batch):
        """Default single-model step: forward, loss, backward, clip, update."""
        data, target = batch
        data, target = data.to(self.device), target.to(self.device)
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.criterion(output, target)
        loss.backward()
        if self.grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.optimizer.step()

    # TODO: implement eval_epoch() for specific model
    @torch.no_grad()
    def eval_step(self, batch):
        # Subclass hook: accumulate losses into self.total_eval_loss.
        pass

    def eval_epoch(self):
        """Evaluate on valid_loader, log averaged losses, and render one
        randomly chosen batch to intermediate artifacts."""
        logging.info(f"(Epoch: {self.epoch}) Start evaluation.")
        # change mode
        self.set_model_state("eval")
        self.total_eval_loss = defaultdict(float)
        # TODO: save some intermidiate results
        rand_idx = np.random.randint(0, len(self.valid_loader))
        idx = 0
        logging.info("Valid data size: {}".format(len(self.valid_loader)))
        for batch in tqdm(self.valid_loader):
            self.eval_step(batch)
            if idx == rand_idx:
                logging.info(
                    f"(Epoch: {self.epoch}) Random batch: {idx}, generating image."
                )
                self.genearete_and_save_intermediate_result(batch)
            idx += 1
        for key in self.total_eval_loss.keys():
            # NOTE(review): idx equals the number of batches after the loop,
            # so dividing by idx + 1 over-counts by one — confirm if intended.
            self.total_eval_loss[key] /= idx + 1
            logging.info(
                f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
            )
        self.write_to_tensorboard(self.total_eval_loss)
        logging.info("Epoch {} evaluation finished".format(self.epoch))
        self.set_model_state("train")

    @torch.no_grad()
    def genearete_and_save_intermediate_result(self, batch):
        # Subclass hook (method-name typo preserved: it is part of the API).
        pass
class GAN_Trainer(Trainer):
def __init__(
self,
config,
model,
optimizer,
scheduler,
criterion,
device,
sampler,
train_loader,
valid_loader,
max_epochs=None,
max_steps=None,
save_dir=None,
save_interval=1,
valid_interval=1,
log_interval=10,
grad_clip=None,
):
super().__init__(
config,
model,
optimizer,
scheduler,
criterion,
device,
sampler,
train_loader,
valid_loader,
max_epochs,
max_steps,
save_dir,
save_interval,
valid_interval,
log_interval,
grad_clip,
)
def set_model_state(self, state="train"):
if state == "train":
if isinstance(self.model, dict):
self.model["generator"].train()
for key in self.model["discriminator"].keys():
self.model["discriminator"][key].train()
else:
self.model.train()
elif state == "eval":
if isinstance(self.model, dict):
self.model["generator"].eval()
for key in self.model["discriminator"].keys():
self.model["discriminator"][key].eval()
else:
self.model.eval()
else:
raise ValueError("state must be either 'train' or 'eval'.")
    @torch.no_grad()
    def genearete_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result.

        Renders up to config["num_save_intermediate_results"] items of the
        batch as waveform plots (PNG) plus reference/generated WAV files under
        log_dir/predictions/<steps>steps.
        """
        # delayed import to avoid error related backend error
        import matplotlib.pyplot as plt

        # generate
        y_batch, x_batch = batch
        y_batch, x_batch = y_batch.to(self.device), x_batch.to(self.device)
        y_batch_ = self.model["generator"](x_batch)
        # With a PQMF model the generator emits subbands; synthesize them
        # back to a full-band waveform for plotting/saving.
        if self.model.get("pqmf", None):
            y_mb_ = y_batch_
            y_batch_ = self.model["pqmf"].synthesis(y_mb_)
        # check directory
        dirname = os.path.join(self.log_dir, f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        # idx starts at 1 so filenames are 1-based.
        for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 1):
            # convert to ndarray
            y, y_ = y.view(-1).cpu().numpy(), y_.view(-1).cpu().numpy()
            # plot figure and save it
            figname = os.path.join(dirname, f"{idx}.png")
            plt.subplot(2, 1, 1)
            plt.plot(y)
            plt.title("groundtruth speech")
            plt.subplot(2, 1, 2)
            plt.plot(y_)
            plt.title(f"generated speech @ {self.steps} steps")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()
            # save as wavfile (clipped to the valid PCM range)
            y = np.clip(y, -1, 1)
            y_ = np.clip(y_, -1, 1)
            sf.write(
                figname.replace(".png", "_ref.wav"),
                y,
                self.config["audio_config"]["sampling_rate"],
                "PCM_16",
            )
            sf.write(
                figname.replace(".png", "_gen.wav"),
                y_,
                self.config["audio_config"]["sampling_rate"],
                "PCM_16",
            )
            if idx >= self.config["num_save_intermediate_results"]:
                break
@torch.no_grad()
def eval_step(self, batch):
y, x = batch
y, x = y.to(self.device), x.to(self.device)
y_ = self.model["generator"](x)
# reconstruct the signal from multi-band signal
if self.model.get("pqmf", None):
y_mb_ = y_
y_ = self.model["pqmf"].synthesis(y_mb_)
aux_loss = 0.0
# multi-resolution sfft loss
if self.criterion.get("stft_loss", None):
sc_loss, mag_loss = self.criterion["stft_loss"](y_, y)
aux_loss += (sc_loss + mag_loss) * self.criterion["stft_loss"].weights
self.total_eval_loss["eval/spectral_convergence_loss"] += sc_loss.item()
# subband multi-resolution stft loss
if self.criterion.get("subband_stft_loss", None):
aux_loss *= 0.5 # for balancing with subband stft loss
y_mb = self.model["pqmf"].analysis(y)
sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
self.total_eval_loss[
"eval/sub_spectral_convergence_loss"
] += sub_sc_loss.item()
self.total_eval_loss[
"eval/sub_log_stft_magnitude_loss"
] += sub_mag_loss.item()
aux_loss += (
0.5 * (sub_sc_loss + sub_mag_loss) * self.criterion["sub_stft"].weights
)
# mel spectrogram loss
if self.criterion.get("mel_loss", None):
mel_loss = self.criterion["mel_loss"](y_, y)
aux_loss += mel_loss * self.criterion["mel_loss"].weights
self.total_eval_loss["eval/mel_loss"] += mel_loss.item()
fmap_lst_ = []
adv_loss = 0.0
# adversiral loss
for discriminator in self.model["discriminator"].keys():
p_, fmap_ = self.model["discriminator"][discriminator](y_)
fmap_lst_.append(fmap_)
adv_loss += (
self.criterion["generator_adv_loss"](p_)
* self.criterion["generator_adv_loss"].weights
)
gen_loss = aux_loss + adv_loss
if self.criterion.get("feat_match_loss", None):
fmap_lst = []
# no need to track gradients
# TODO: implement feature matching loss
for discriminator in self.model["discriminator"].keys():
with torch.no_grad():
p, fmap = self.model["discriminator"][discriminator](y)
fmap_lst.append(fmap)
fm_loss = 0.0
for fmap_, fmap in zip(fmap_lst, fmap_lst_):
fm_loss += self.criterion["feat_match_loss"](fmap_, fmap)
self.total_eval_loss["eval/feature_matching_loss"] += fm_loss.item()
gen_loss += fm_loss * self.criterion["feat_match_loss"].weights
dis_loss = 0.0
for discriminator in self.model["discriminator"].keys():
p, fmap = self.model["discriminator"][discriminator](y)
p_, fmap_ = self.model["discriminator"][discriminator](y_.detach())
real_loss, fake_loss = self.criterion["discriminator_adv_loss"](p_, p)
dis_loss += real_loss + fake_loss
self.total_eval_loss["eval/real_loss"] += real_loss.item()
self.total_eval_loss["eval/fake_loss"] += fake_loss.item()
self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item()
self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item()
self.total_eval_loss["eval/generator_loss"] += gen_loss.item()
    def train_step(self, batch):
        """Run one GAN training step on *batch* (a (waveform, conditioning) pair).

        The generator branch runs once ``self.steps`` reaches
        ``generator_train_start_steps``; the adversarial terms and the
        discriminator branch only activate after
        ``discriminator_train_start_steps``. Each branch accumulates its
        scalar losses into ``self.total_train_loss`` and performs its own
        backward/step/scheduler update.
        """
        y, x = batch
        y, x = y.to(self.device), x.to(self.device)
        if self.steps >= self.config.get("generator_train_start_steps", 0):
            y_ = self.model["generator"](x)
            # reconstruct the signal from multi-band signal
            if self.model.get("pqmf", None):
                y_mb_ = y_
                y_ = self.model["pqmf"].synthesis(y_mb_)
            # initialize
            gen_loss = 0.0
            # multi-resolution sfft loss
            if self.criterion.get("stft_loss", None):
                sc_loss, mag_loss = self.criterion["stft_loss"](y_, y)
                gen_loss += (sc_loss + mag_loss) * self.criterion["stft_loss"].weights
                self.total_train_loss[
                    "train/spectral_convergence_loss"
                ] += sc_loss.item()
                self.total_train_loss[
                    "train/log_stft_magnitude_loss"
                ] += mag_loss.item()
            # subband multi-resolution stft loss
            if self.criterion.get("subband_stft_loss", None):
                gen_loss *= 0.5  # for balancing with subband stft loss
                y_mb = self.model["pqmf"].analysis(y)
                sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
                # NOTE(review): the eval path scales this term by
                # self.criterion["sub_stft"].weights, but this training path
                # does not — confirm the asymmetry is intentional.
                gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
                self.total_train_loss[
                    "train/sub_spectral_convergence_loss"
                ] += sub_sc_loss.item()
                self.total_train_loss[
                    "train/sub_log_stft_magnitude_loss"
                ] += sub_mag_loss.item()
            # mel spectrogram loss
            if self.criterion.get("mel_loss", None):
                mel_loss = self.criterion["mel_loss"](y_, y)
                gen_loss += mel_loss * self.criterion["mel_loss"].weights
                self.total_train_loss["train/mel_loss"] += mel_loss.item()
            # adversarial loss
            if self.steps > self.config["discriminator_train_start_steps"]:
                adv_loss = 0.0
                fmap_lst_ = []
                for discriminator in self.model["discriminator"].keys():
                    p_, fmap_ = self.model["discriminator"][discriminator](y_)
                    fmap_lst_.append(fmap_)
                    adv_loss += self.criterion["generator_adv_loss"](p_)
                self.total_train_loss["train/adversarial_loss"] += adv_loss.item()
                gen_loss += adv_loss * self.criterion["generator_adv_loss"].weights
                # feature matching loss
                if self.criterion.get("feat_match_loss", None):
                    fmap_lst = []
                    # no need to track gradients
                    # TODO: implement feature matching loss
                    for discriminator in self.model["discriminator"].keys():
                        with torch.no_grad():
                            p, fmap = self.model["discriminator"][discriminator](y)
                        fmap_lst.append(fmap)
                    fm_loss = 0.0
                    # compare feature maps of real (no-grad) vs generated audio
                    for fmap_, fmap in zip(fmap_lst, fmap_lst_):
                        fm_loss += self.criterion["feat_match_loss"](fmap_, fmap)
                    self.total_train_loss[
                        "train/feature_matching_loss"
                    ] += fm_loss.item()
                    gen_loss += fm_loss * self.criterion["feat_match_loss"].weights
            self.total_train_loss["train/generator_loss"] += gen_loss.item()
            # update generator
            self.optimizer["generator"].zero_grad()
            gen_loss.backward()
            if self.config["generator_grad_norm"] > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model["generator"].parameters(),
                    self.config["generator_grad_norm"],
                )
            self.optimizer["generator"].step()
            self.scheduler["generator"].step()
        # update discriminator
        if self.steps > self.config["discriminator_train_start_steps"]:
            # re-compute y_ which leads better quality
            with torch.no_grad():
                y_ = self.model["generator"](x)
            if self.model.get("pqmf", None):
                y_ = self.model["pqmf"].synthesis(y_)
            # discriminator loss
            dis_loss = 0.0
            for discriminator in self.model["discriminator"].keys():
                p, fmap = self.model["discriminator"][discriminator](y)
                # detach so discriminator gradients do not flow into the generator
                p_, fmap_ = self.model["discriminator"][discriminator](y_.detach())
                real_loss, fake_loss = self.criterion["discriminator_adv_loss"](p_, p)
                dis_loss += real_loss + fake_loss
                self.total_train_loss["train/real_loss"] += real_loss.item()
                self.total_train_loss["train/fake_loss"] += fake_loss.item()
            self.total_train_loss["train/discriminator_loss"] += dis_loss.item()
            # update discriminator
            for key in self.optimizer["discriminator"].keys():
                self.optimizer["discriminator"][key].zero_grad()
            dis_loss.backward()
            if self.config["discriminator_grad_norm"] > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model["discriminator"].parameters(),
                    self.config["discriminator_grad_norm"],
                )
            for key in self.optimizer["discriminator"].keys():
                self.optimizer["discriminator"][key].step()
            for key in self.scheduler["discriminator"].keys():
                self.scheduler["discriminator"][key].step()
def save_checkpoint(self, checkpoint_path):
state_dict = {
"optimizer": {
"generator": self.optimizer["generator"].state_dict(),
"discriminator": {},
},
"scheduler": {
"generator": self.scheduler["generator"].state_dict(),
"discriminator": {},
},
"steps": self.steps,
}
for model_name in self.optimizer["discriminator"].keys():
state_dict["optimizer"]["discriminator"][model_name] = self.optimizer[
"discriminator"
][model_name].state_dict()
for model_name in self.scheduler["discriminator"].keys():
state_dict["scheduler"]["discriminator"][model_name] = self.scheduler[
"discriminator"
][model_name].state_dict()
if not self.distributed:
model_state = self.model["generator"].state_dict()
else:
model_state = self.model["generator"].module.state_dict()
state_dict["model"] = {
"generator": model_state,
"discriminator": {},
}
for model_name in self.model["discriminator"].keys():
if not self.distributed:
model_state = self.model["discriminator"][model_name].state_dict()
else:
model_state = self.model["discriminator"][
model_name
].module.state_dict()
state_dict["model"]["discriminator"][model_name] = model_state
if not os.path.exists(os.path.dirname(checkpoint_path)):
os.makedirs(os.path.dirname(checkpoint_path))
torch.save(state_dict, checkpoint_path)
def load_checkpoint(
self, checkpoint_path, restore_training_state=False, strict=True
):
state_dict = torch.load(checkpoint_path, map_location="cpu")
if not self.distributed:
self.model["generator"].load_state_dict(
state_dict["model"]["generator"], strict=strict
)
else:
self.model["generator"].module.load_state_dict(
state_dict["model"]["generator"], strict=strict
)
for model_name in state_dict["model"]["discriminator"]:
if not self.distributed:
self.model["discriminator"][model_name].load_state_dict(
state_dict["model"]["discriminator"][model_name], strict=strict
)
else:
self.model["discriminator"][model_name].module.load_state_dict(
state_dict["model"]["discriminator"][model_name], strict=strict
)
if restore_training_state:
if "steps" in state_dict:
self.steps = state_dict["steps"]
if "optimizer" in state_dict:
self.optimizer["generator"].load_state_dict(
state_dict["optimizer"]["generator"]
)
for model_name in state_dict["optimizer"]["discriminator"].keys():
self.optimizer["discriminator"][model_name].load_state_dict(
state_dict["optimizer"]["discriminator"][model_name]
)
if "scheduler" in state_dict:
for model_name in state_dict["scheduler"]["discriminator"].keys():
self.scheduler["discriminator"][model_name].load_state_dict(
state_dict["scheduler"]["discriminator"][model_name]
)
self.scheduler["generator"].load_state_dict(
state_dict["scheduler"]["generator"]
)
class Sambert_Trainer(Trainer):
    """Trainer for the KanTtsSAMBERT acoustic model.

    Extends the base Trainer with SAMBERT-specific loss computation
    (mel reconstruction, prosody reconstruction, optional filled-pause CE
    loss and MAS attention CTC/KL losses) plus checkpoint (de)serialization
    for the single "KanTtsSAMBERT" model entry.
    """

    def __init__(
        self,
        config,
        model,
        optimizer,
        scheduler,
        criterion,
        device,
        sampler,
        train_loader,
        valid_loader,
        max_epochs=None,
        max_steps=None,
        save_dir=None,
        save_interval=1,
        valid_interval=1,
        log_interval=10,
        grad_clip=None,
    ):
        super().__init__(
            config,
            model,
            optimizer,
            scheduler,
            criterion,
            device,
            sampler,
            train_loader,
            valid_loader,
            max_epochs,
            max_steps,
            save_dir,
            save_interval,
            valid_interval,
            log_interval,
            grad_clip,
        )
        # Optional model features, read from the model config:
        #   MAS: monotonic alignment search -> adds attention CTC/KL losses
        #   FP:  filled-pause prediction    -> adds a CE loss on fp labels
        self.with_MAS = config["Model"]["KanTtsSAMBERT"]["params"].get("MAS", False)
        self.fp_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("FP", False)

    @torch.no_grad()
    # NOTE: name kept (including its historical misspelling) because the base
    # Trainer invokes this hook by name.
    def genearete_and_save_intermediate_result(self, batch):
        """Synthesize the first sample of *batch* and dump attention plots,
        mel spectrograms (coarse/postnet/target) under the log directory."""
        inputs_ling = batch["input_lings"].to(self.device)
        inputs_emotion = batch["input_emotions"].to(self.device)
        inputs_speaker = batch["input_speakers"].to(self.device)
        valid_input_lengths = batch["valid_input_lengths"].to(self.device)
        mel_targets = batch["mel_targets"].to(self.device)
        # generate mel spectrograms (inference mode: only the first sample)
        res = self.model["KanTtsSAMBERT"](
            inputs_ling[0:1],
            inputs_emotion[0:1],
            inputs_speaker[0:1],
            valid_input_lengths[0:1],
        )
        x_band_width = res["x_band_width"]
        h_band_width = res["h_band_width"]
        enc_slf_attn_lst = res["enc_slf_attn_lst"]
        pnca_x_attn_lst = res["pnca_x_attn_lst"]
        pnca_h_attn_lst = res["pnca_h_attn_lst"]
        dec_outputs = res["dec_outputs"]
        postnet_outputs = res["postnet_outputs"]
        dirname = os.path.join(self.log_dir, f"predictions/{self.steps}steps")
        os.makedirs(dirname, exist_ok=True)
        # encoder self-attention maps, one figure per layer/head
        for layer_id, slf_attn in enumerate(enc_slf_attn_lst):
            for head_id in range(
                self.config["Model"]["KanTtsSAMBERT"]["params"]["encoder_num_heads"]
            ):
                fig = plot_alignment(
                    slf_attn[
                        head_id, : valid_input_lengths[0], : valid_input_lengths[0]
                    ]
                    .cpu()
                    .numpy(),
                    info="valid_len_{}".format(valid_input_lengths[0].item()),
                )
                fig.savefig(
                    os.path.join(
                        dirname,
                        "enc_slf_attn_dev_layer{}_head{}".format(layer_id, head_id),
                    )
                )
        # decoder PNCA attention maps (x- and h- streams), per layer/head
        for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate(
            zip(pnca_x_attn_lst, pnca_h_attn_lst)
        ):
            for head_id in range(
                self.config["Model"]["KanTtsSAMBERT"]["params"]["decoder_num_heads"]
            ):
                fig = plot_alignment(
                    pnca_x_attn[head_id, :, :].cpu().numpy(),
                    info="x_band_width_{}".format(x_band_width),
                )
                fig.savefig(
                    os.path.join(
                        dirname,
                        "pnca_x_attn_dev_layer{}_head{}".format(layer_id, head_id),
                    )
                )
                fig = plot_alignment(
                    pnca_h_attn[head_id, :, :].cpu().numpy(),
                    info="h_band_width_{}".format(h_band_width),
                )
                fig.savefig(
                    os.path.join(
                        dirname,
                        "pnca_h_attn_dev_layer{}_head{}".format(layer_id, head_id),
                    )
                )
        target_mel = mel_targets[0].cpu().numpy()
        coarse_mel = dec_outputs.squeeze(0).cpu().numpy()
        output_mel = postnet_outputs.squeeze(0).cpu().numpy()
        np.save(os.path.join(dirname, "coarse_mel.npy"), coarse_mel)
        np.save(os.path.join(dirname, "output_mel.npy"), output_mel)
        np.save(os.path.join(dirname, "target_mel.npy"), target_mel)
        fig = plot_spectrogram(coarse_mel.T)
        fig.savefig(os.path.join(dirname, "mel_dec_outputs"))
        fig = plot_spectrogram(output_mel.T)
        fig.savefig(os.path.join(dirname, "mel_postnet_outputs"))

    def _run_model_and_losses(self, batch, loss_log, prefix):
        """Forward KanTtsSAMBERT on *batch* and accumulate losses into *loss_log*.

        Shared by eval_step and train_step (which previously duplicated this
        logic line-for-line).

        Args:
            batch: collated batch dict from the data loader.
            loss_log: running-loss accumulator (self.total_train_loss or
                self.total_eval_loss).
            prefix: "train" or "eval"; used to build loss_log keys.

        Returns:
            Tensor: total (differentiable) loss for this batch.
        """
        inputs_ling = batch["input_lings"].to(self.device)
        inputs_emotion = batch["input_emotions"].to(self.device)
        inputs_speaker = batch["input_speakers"].to(self.device)
        valid_input_lengths = batch["valid_input_lengths"].to(self.device)
        valid_output_lengths = batch["valid_output_lengths"].to(self.device)
        mel_targets = batch["mel_targets"].to(self.device)
        # durations/attn_priors may legitimately be absent (e.g. MAS mode)
        durations = (
            batch["durations"].to(self.device)
            if batch["durations"] is not None
            else None
        )
        pitch_contours = batch["pitch_contours"].to(self.device)
        energy_contours = batch["energy_contours"].to(self.device)
        attn_priors = (
            batch["attn_priors"].to(self.device)
            if batch["attn_priors"] is not None
            else None
        )
        fp_label = batch["fp_label"].to(self.device) if self.fp_enable else None

        # generate mel spectrograms (teacher-forced)
        res = self.model["KanTtsSAMBERT"](
            inputs_ling,
            inputs_emotion,
            inputs_speaker,
            valid_input_lengths,
            output_lengths=valid_output_lengths,
            mel_targets=mel_targets,
            duration_targets=durations,
            pitch_targets=pitch_contours,
            energy_targets=energy_contours,
            attn_priors=attn_priors,
            fp_label=fp_label,
        )

        mel_loss_, mel_loss = self.criterion["MelReconLoss"](
            valid_output_lengths,
            mel_targets,
            res["dec_outputs"],
            res["postnet_outputs"],
        )
        # prosody targets come from the model output (they may be derived from
        # the alignment rather than taken verbatim from the batch)
        dur_loss, pitch_loss, energy_loss = self.criterion["ProsodyReconLoss"](
            res["valid_inter_lengths"],
            res["duration_targets"],
            res["pitch_targets"],
            res["energy_targets"],
            res["log_duration_predictions"],
            res["pitch_predictions"],
            res["energy_predictions"],
        )
        loss_total = mel_loss_ + mel_loss + dur_loss + pitch_loss + energy_loss

        if self.fp_enable:
            fp_loss = self.criterion["FpCELoss"](
                valid_input_lengths, res["fp_predictions"], fp_label
            )
            loss_total = loss_total + fp_loss
            loss_log[f"{prefix}/fp_loss"] += fp_loss.item()

        if self.with_MAS:
            attn_ctc_loss = self.criterion["AttentionCTCLoss"](
                res["attn_logprob"], valid_input_lengths, valid_output_lengths
            )
            attn_kl_loss = self.criterion["AttentionBinarizationLoss"](
                self.epoch, res["attn_hard"], res["attn_soft"]
            )
            loss_total = loss_total + attn_ctc_loss + attn_kl_loss
            loss_log[f"{prefix}/attn_ctc_loss"] += attn_ctc_loss.item()
            loss_log[f"{prefix}/attn_kl_loss"] += attn_kl_loss.item()

        loss_log[f"{prefix}/TotalLoss"] += loss_total.item()
        loss_log[f"{prefix}/mel_loss_"] += mel_loss_.item()
        loss_log[f"{prefix}/mel_loss"] += mel_loss.item()
        loss_log[f"{prefix}/dur_loss"] += dur_loss.item()
        loss_log[f"{prefix}/pitch_loss"] += pitch_loss.item()
        loss_log[f"{prefix}/energy_loss"] += energy_loss.item()
        loss_log[f"{prefix}/batch_size"] += mel_targets.size(0)
        loss_log[f"{prefix}/x_band_width"] += res["x_band_width"]
        loss_log[f"{prefix}/h_band_width"] += res["h_band_width"]
        return loss_total

    @torch.no_grad()
    def eval_step(self, batch):
        """Compute and log validation losses for one batch (no update)."""
        self._run_model_and_losses(batch, self.total_eval_loss, "eval")

    def train_step(self, batch):
        """Run one optimization step of KanTtsSAMBERT on *batch*."""
        loss_total = self._run_model_and_losses(batch, self.total_train_loss, "train")
        self.optimizer["KanTtsSAMBERT"].zero_grad()
        loss_total.backward()
        if self.grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(
                self.model["KanTtsSAMBERT"].parameters(), self.grad_clip
            )
        self.optimizer["KanTtsSAMBERT"].step()
        self.scheduler["KanTtsSAMBERT"].step()

    def save_checkpoint(self, checkpoint_path):
        """Write model/optimizer/scheduler/steps state to *checkpoint_path*.

        FIX: the original guarded directory creation on the existence of the
        checkpoint *file* (``os.path.exists(checkpoint_path)``); we now ensure
        the parent directory itself exists.
        """
        if not self.distributed:
            model_state = self.model["KanTtsSAMBERT"].state_dict()
        else:
            # DDP wraps the real module; unwrap before snapshotting weights.
            model_state = self.model["KanTtsSAMBERT"].module.state_dict()
        state_dict = {
            "optimizer": self.optimizer["KanTtsSAMBERT"].state_dict(),
            "scheduler": self.scheduler["KanTtsSAMBERT"].state_dict(),
            "steps": self.steps,
            "model": model_state,
        }
        ckpt_dir = os.path.dirname(checkpoint_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(
        self, checkpoint_path, restore_training_state=False, strict=True
    ):
        """Load model weights and, optionally, optimizer/scheduler/step state.

        FIX: loads with ``map_location="cpu"`` (the original relied on the
        checkpoint's saved device), consistent with the vocoder trainer and
        safe on CPU-only hosts.
        """
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        target = (
            self.model["KanTtsSAMBERT"].module
            if self.distributed
            else self.model["KanTtsSAMBERT"]
        )
        target.load_state_dict(state_dict["model"], strict=strict)
        if restore_training_state:
            if "optimizer" in state_dict:
                self.optimizer["KanTtsSAMBERT"].load_state_dict(
                    state_dict["optimizer"]
                )
            if "scheduler" in state_dict:
                self.scheduler["KanTtsSAMBERT"].load_state_dict(
                    state_dict["scheduler"]
                )
            if "steps" in state_dict:
                self.steps = state_dict["steps"]
class Textsy_BERT_Trainer(Trainer):
    """Trainer for the KanTtsTextsyBERT text-symbol prediction model."""

    def __init__(
        self,
        config,
        model,
        optimizer,
        scheduler,
        criterion,
        device,
        sampler,
        train_loader,
        valid_loader,
        max_epochs=None,
        max_steps=None,
        save_dir=None,
        save_interval=1,
        valid_interval=1,
        log_interval=10,
        grad_clip=None,
    ):
        super().__init__(
            config,
            model,
            optimizer,
            scheduler,
            criterion,
            device,
            sampler,
            train_loader,
            valid_loader,
            max_epochs,
            max_steps,
            save_dir,
            save_interval,
            valid_interval,
            log_interval,
            grad_clip,
        )

    @torch.no_grad()
    # NOTE: name kept (including its historical misspelling) because the base
    # Trainer invokes this hook by name.
    def genearete_and_save_intermediate_result(self, batch):
        """Predict symbols for the first sample of *batch* and dump attention
        plots plus prediction/target/mask arrays under the log directory."""
        inputs_ling = batch["input_lings"].to(self.device)
        valid_input_lengths = batch["valid_input_lengths"].to(self.device)
        bert_masks = batch["bert_masks"].to(self.device)
        targets = batch["targets"].to(self.device)
        res = self.model["KanTtsTextsyBERT"](
            inputs_ling[0:1],
            valid_input_lengths[0:1],
        )
        logits = res["logits"]
        enc_slf_attn_lst = res["enc_slf_attn_lst"]
        preds = torch.argmax(logits, dim=-1).contiguous().view(-1)
        dirname = os.path.join(self.log_dir, f"predictions/{self.steps}steps")
        os.makedirs(dirname, exist_ok=True)
        # encoder self-attention maps, one figure per layer/head
        for layer_id, slf_attn in enumerate(enc_slf_attn_lst):
            for head_id in range(
                self.config["Model"]["KanTtsTextsyBERT"]["params"]["encoder_num_heads"]
            ):
                fig = plot_alignment(
                    slf_attn[
                        head_id, : valid_input_lengths[0], : valid_input_lengths[0]
                    ]
                    .cpu()
                    .numpy(),
                    info="valid_len_{}".format(valid_input_lengths[0].item()),
                )
                fig.savefig(
                    os.path.join(
                        dirname,
                        "enc_slf_attn_dev_layer{}_head{}".format(layer_id, head_id),
                    )
                )
        target = targets[0].cpu().numpy()
        bert_mask = bert_masks[0].cpu().numpy()
        pred = preds.cpu().numpy()
        np.save(os.path.join(dirname, "pred.npy"), pred)
        np.save(os.path.join(dirname, "target.npy"), target)
        np.save(os.path.join(dirname, "bert_mask.npy"), bert_mask)

    def _forward_loss(self, batch):
        """Forward KanTtsTextsyBERT on *batch*.

        Shared by eval_step and train_step (previously duplicated).

        Returns:
            tuple: (scaled CE loss tensor, error tensor, batch size).
        """
        inputs_ling = batch["input_lings"].to(self.device)
        valid_input_lengths = batch["valid_input_lengths"].to(self.device)
        bert_masks = batch["bert_masks"].to(self.device)
        targets = batch["targets"].to(self.device)
        res = self.model["KanTtsTextsyBERT"](
            inputs_ling,
            valid_input_lengths,
        )
        logits = res["logits"]
        loss_total, err = self.criterion["SeqCELoss"](
            logits,
            targets,
            bert_masks,
        )
        # NOTE(review): loss is divided by the logits' last dimension
        # (number of classes) — presumably to keep its magnitude stable;
        # confirm against SeqCELoss's reduction.
        loss_total = loss_total / logits.size(-1)
        return loss_total, err, targets.size(0)

    @torch.no_grad()
    def eval_step(self, batch):
        """Compute and log validation loss/error for one batch (no update)."""
        loss_total, err, batch_size = self._forward_loss(batch)
        self.total_eval_loss["eval/TotalLoss"] += loss_total.item()
        self.total_eval_loss["eval/Error"] += err.item()
        self.total_eval_loss["eval/batch_size"] += batch_size

    def train_step(self, batch):
        """Run one optimization step of KanTtsTextsyBERT on *batch*."""
        loss_total, err, batch_size = self._forward_loss(batch)
        self.optimizer["KanTtsTextsyBERT"].zero_grad()
        loss_total.backward()
        if self.grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(
                self.model["KanTtsTextsyBERT"].parameters(), self.grad_clip
            )
        self.optimizer["KanTtsTextsyBERT"].step()
        self.scheduler["KanTtsTextsyBERT"].step()
        self.total_train_loss["train/TotalLoss"] += loss_total.item()
        self.total_train_loss["train/Error"] += err.item()
        self.total_train_loss["train/batch_size"] += batch_size

    def save_checkpoint(self, checkpoint_path):
        """Write model/optimizer/scheduler/steps state to *checkpoint_path*.

        FIX: the original guarded directory creation on the existence of the
        checkpoint *file*; we now ensure the parent directory itself exists.
        """
        if not self.distributed:
            model_state = self.model["KanTtsTextsyBERT"].state_dict()
        else:
            # DDP wraps the real module; unwrap before snapshotting weights.
            model_state = self.model["KanTtsTextsyBERT"].module.state_dict()
        state_dict = {
            "optimizer": self.optimizer["KanTtsTextsyBERT"].state_dict(),
            "scheduler": self.scheduler["KanTtsTextsyBERT"].state_dict(),
            "steps": self.steps,
            "model": model_state,
        }
        ckpt_dir = os.path.dirname(checkpoint_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(
        self, checkpoint_path, restore_training_state=False, strict=True
    ):
        """Load model weights and, optionally, optimizer/scheduler/step state.

        FIX: loads with ``map_location="cpu"`` for CPU-safe restore, and
        guards each training-state key with an ``in`` check (the original
        raised KeyError on partial checkpoints), matching the sibling
        trainers' behavior.
        """
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        target = (
            self.model["KanTtsTextsyBERT"].module
            if self.distributed
            else self.model["KanTtsTextsyBERT"]
        )
        target.load_state_dict(state_dict["model"], strict=strict)
        if restore_training_state:
            if "optimizer" in state_dict:
                self.optimizer["KanTtsTextsyBERT"].load_state_dict(
                    state_dict["optimizer"]
                )
            if "scheduler" in state_dict:
                self.scheduler["KanTtsTextsyBERT"].load_state_dict(
                    state_dict["scheduler"]
                )
            if "steps" in state_dict:
                self.steps = state_dict["steps"]
import torch
import librosa
# NOTE(review): distutils is deprecated and removed in Python 3.12; consider
# packaging.version when the dependency policy allows it.
from distutils.version import LooseVersion
# torch.stft only accepts the `return_complex` keyword from PyTorch 1.7 on;
# callers below branch on this flag.
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to a magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor passed through to ``torch.stft``.

    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    stft_kwargs = {}
    if is_pytorch_17plus:
        # PyTorch >= 1.7 requires an explicit choice of output layout.
        stft_kwargs["return_complex"] = False
    spec = torch.stft(x, fft_size, hop_size, win_length, window, **stft_kwargs)
    power = spec[..., 0] ** 2 + spec[..., 1] ** 2
    # clamp keeps sqrt's gradient finite at (near-)zero magnitude bins
    return torch.sqrt(torch.clamp(power, min=1e-7)).transpose(2, 1)
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Map linear amplitudes to a dB-like scale: ``20 * log10(max(x, clip_val) * C)``.

    ``clip_val`` bounds the input away from zero so the log stays finite.
    """
    clamped = torch.clamp(x, min=clip_val)
    return torch.log10(clamped * C) * 20
def dynamic_range_decompression_torch(x, C=1):
    """Inverse of :func:`dynamic_range_compression_torch`: ``10 ** (x / 20) / C``."""
    amplitude = torch.pow(10.0, 0.05 * x)
    return amplitude / C
def spectral_normalize_torch(
    magnitudes,
    min_level_db=-100.0,
    ref_level_db=20.0,
    norm_abs_value=4.0,
    symmetric=True,
):
    """Compress magnitudes to dB, subtract the reference level, then rescale.

    Symmetric mode maps [min_level_db, 0] dB to [-norm_abs_value,
    +norm_abs_value]; otherwise to [0, norm_abs_value]. Output is clipped to
    that range either way.
    """
    db = dynamic_range_compression_torch(magnitudes) - ref_level_db
    # relative position of the level within [min_level_db, 0]
    rel = (db - min_level_db) / (-min_level_db)
    if symmetric:
        scaled = 2 * norm_abs_value * rel - norm_abs_value
        return torch.clamp(scaled, min=-norm_abs_value, max=norm_abs_value)
    scaled = norm_abs_value * rel
    return torch.clamp(scaled, min=0.0, max=norm_abs_value)
def spectral_de_normalize_torch(
    magnitudes,
    min_level_db=-100.0,
    ref_level_db=20.0,
    norm_abs_value=4.0,
    symmetric=True,
):
    """Invert :func:`spectral_normalize_torch` back to linear amplitude."""
    if symmetric:
        clipped = torch.clamp(magnitudes, min=-norm_abs_value, max=norm_abs_value)
        # undo the [-norm, +norm] -> [min_level_db, 0] mapping
        db = (clipped + norm_abs_value) * (-min_level_db) / (
            2 * norm_abs_value
        ) + min_level_db
    else:
        clipped = torch.clamp(magnitudes, min=0.0, max=norm_abs_value)
        db = (clipped) * (-min_level_db) / (norm_abs_value) + min_level_db
    return dynamic_range_decompression_torch(db + ref_level_db)
class MelSpectrogram(torch.nn.Module):
    """Calculate Mel-spectrograms from raw waveforms.

    The output is normalized with :func:`spectral_normalize_torch` instead of
    a plain log transform (the ``log_base`` machinery is kept but unused by
    ``forward``).
    """

    def __init__(
        self,
        fs=22050,
        fft_size=1024,
        hop_size=256,
        win_length=None,
        window="hann",
        num_mels=80,
        fmin=80,
        fmax=7600,
        center=True,
        normalized=False,
        onesided=True,
        eps=1e-10,
        log_base=10.0,
        pad_mode="constant",
    ):
        """Initialize MelSpectrogram module."""
        super().__init__()
        self.fft_size = fft_size
        self.win_length = fft_size if win_length is None else win_length
        self.hop_size = hop_size
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
        self.window = window
        self.eps = eps
        self.pad_mode = pad_mode

        # Mel filterbank, stored transposed so forward() can right-multiply.
        melmat = librosa.filters.mel(
            sr=fs,
            n_fft=fft_size,
            n_mels=num_mels,
            fmin=0 if fmin is None else fmin,
            fmax=fs / 2 if fmax is None else fmax,
        )
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

        self.stft_params = {
            "n_fft": self.fft_size,
            "win_length": self.win_length,
            "hop_length": self.hop_size,
            "center": self.center,
            "normalized": self.normalized,
            "onesided": self.onesided,
            "pad_mode": self.pad_mode,
        }
        if is_pytorch_17plus:
            # PyTorch >= 1.7 requires an explicit choice of output layout.
            self.stft_params["return_complex"] = False

        self.log_base = log_base
        log_fns = {None: torch.log, 2.0: torch.log2, 10.0: torch.log10}
        if self.log_base not in log_fns:
            raise ValueError(f"log_base: {log_base} is not supported.")
        self.log = log_fns[self.log_base]

    def forward(self, x):
        """Calculate Mel-spectrogram.

        Args:
            x (Tensor): Input waveform tensor (B, T) or (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram (B, #mels, #frames).
        """
        if x.dim() == 3:
            # (B, C, T) -> (B*C, T)
            x = x.reshape(-1, x.size(2))

        if self.window is None:
            window = None
        else:
            # build the window on the input's device/dtype at call time
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(self.win_length, dtype=x.dtype, device=x.device)

        spec = torch.stft(x, window=window, **self.stft_params)
        # (B, #freqs, #frames, 2) -> (B, #frames, #freqs, 2)
        spec = spec.transpose(1, 2)
        power = spec[..., 0] ** 2 + spec[..., 1] ** 2
        amplitude = torch.sqrt(torch.clamp(power, min=self.eps))

        mel = torch.matmul(amplitude, self.melmat)
        mel = torch.clamp(mel, min=self.eps)
        # pipeline-wide normalization replaces self.log(mel) here
        mel = spectral_normalize_torch(mel)
        return mel.transpose(1, 2)
import ttsfrd
# Maps KanTTS language identifiers to the language codes understood by the
# ttsfrd frontend engine (used by text_to_mit_symbols below).
ENG_LANG_MAPPING = {
    "PinYin": "zh-cn",
    "English": "en-us",
    "British": "en-gb",
    "ZhHK": "hk_cantonese",
    "Sichuan": "sichuan",
    "Japanese": "japanese",
    "WuuShangHai": "shanghai",
    "Indonesian": "indonesian",
    "Malay": "malay",
    "Filipino": "filipino",
    "Vietnamese": "vietnamese",
    "Korean": "korean",
    "Russian": "russian",
}
def text_to_mit_symbols(texts, resources_dir, speaker, lang="PinYin"):
    """Convert raw text lines to MIT-style TTS symbol sequences via ttsfrd.

    Args:
        texts (list[str]): Input sentences, one per element.
        resources_dir (str): Path to the ttsfrd resource directory.
        speaker (str): Speaker name substituted for the frontend's default
            "F7" tag in the generated symbols.
        lang (str): Language key into ENG_LANG_MAPPING (default "PinYin").

    Returns:
        list[str]: Lines of the form "<idx>_<sub_index>\\t<symbols>\\n".

    Raises:
        KeyError: If *lang* is not present in ENG_LANG_MAPPING.
    """
    fe = ttsfrd.TtsFrontendEngine()
    fe.initialize(resources_dir)
    fe.set_lang_type(ENG_LANG_MAPPING[lang])

    symbols_lst = []
    for idx, text in enumerate(texts):
        text = text.strip()
        res = fe.gen_tacotron_symbols(text)
        res = res.replace("F7", speaker)
        sentences = res.split("\n")
        for sentence in sentences:
            arr = sentence.split("\t")
            # skip the empty line
            if len(arr) != 2:
                continue
            # FIX: reuse the split already computed above instead of
            # splitting the sentence a second time.
            sub_index, symbols = arr
            symbol_str = "{}_{}\t{}\n".format(idx, sub_index, symbols)
            symbols_lst.append(symbol_str)
    return symbols_lst
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment