Commit 51782715 authored by liugh5

update

parent 8b4e9acd
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 192
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 82
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
NSF: True
nsf_norm_type: global
nsf_f0_global_minimum: 30.0
nsf_f0_global_maximum: 730.0
SE: True
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting this > 0 may hang on macOS
remove_short_samples: False
allow_cache: False
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1760101 # Number of training steps.
save_interval_steps: 100 # Interval steps to save checkpoint.
eval_interval_steps: 1000000000000 # Interval steps to evaluate the network.
log_interval_steps: 10 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 192
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 82
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
NSF: True
nsf_norm_type: global
nsf_f0_global_minimum: 30.0
nsf_f0_global_maximum: 730.0
SE: True
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting this > 0 may hang on macOS
remove_short_samples: False
allow_cache: False
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 2500000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 1000000000000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sambert
Model:
#########################################################
# SAMBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsSAMBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
speaker_units: 32
emotion_units: 32
predictor_filter_size: 41
predictor_fsmn_num_layers: 3
predictor_num_memory_units: 128
predictor_ffn_inner_dim: 256
predictor_dropout: 0.1
predictor_shift: 0
predictor_lstm_units: 128
dur_pred_prenet_units: [128, 128]
dur_pred_lstm_units: 128
decoder_prenet_units: [256, 256]
decoder_num_layers: 12
decoder_num_heads: 8
decoder_num_units: 128
decoder_ffn_inner_dim: 1024
decoder_dropout: 0.1
decoder_attention_dropout: 0.1
decoder_relu_dropout: 0.1
outputs_per_step: 3
num_mels: 80
postnet_filter_size: 41
postnet_fsmn_num_layers: 4
postnet_num_memory_units: 256
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: False
optimizer:
type: Adam
params:
lr: 0.001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 4000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: xiaoyue
language: Sichuan
####################################################
# LOSS SETTING #
####################################################
Loss:
MelReconLoss:
enable: True
params:
loss_type: mae
ProsodyReconLoss:
enable: True
params:
loss_type: mae
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting this > 0 may hang on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
model_type: sybert
Model:
#########################################################
# TextsyBERT NETWORK ARCHITECTURE SETTING #
#########################################################
KanTtsTextsyBERT:
params:
max_len: 800
embedding_dim: 512
encoder_num_layers: 8
encoder_num_heads: 8
encoder_num_units: 128
encoder_ffn_inner_dim: 1024
encoder_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_relu_dropout: 0.1
encoder_projection_units: 32
mask_ratio: 0.3
optimizer:
type: Adam
params:
lr: 0.0001
betas: [0.9, 0.98]
eps: 1.0e-9
weight_decay: 0.0
scheduler:
type: NoamLR
params:
warmup_steps: 10000
linguistic_unit:
cleaners: english_cleaners
lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
speaker_list: F7
####################################################
# LOSS SETTING #
####################################################
Loss:
SeqCELoss:
enable: True
params:
loss_type: ce
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 32
pin_memory: False
num_workers: 4 # FIXME: setting this > 0 may hang on macOS
remove_short_samples: False
allow_cache: True
grad_norm: 1.0
###########################################################
# INTERVAL SETTING #
###########################################################
train_max_steps: 1000000 # Number of training steps.
save_interval_steps: 20000 # Interval steps to save checkpoint.
eval_interval_steps: 10000 # Interval steps to evaluate the network.
log_interval_steps: 1000 # Interval steps to record the training log.
###########################################################
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
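# --- Usage sketch (not part of the original configs): how a config like the
# ones above might be loaded and dispatched on model_type. The filename
# "config.yaml" is illustrative.
import yaml

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

if config["model_type"] == "sambert":
    params = config["Model"]["KanTtsSAMBERT"]["params"]
elif config["model_type"] == "sybert":
    params = config["Model"]["KanTtsTextsyBERT"]["params"]
print(config["model_type"], params["embedding_dim"])  # embedding_dim is 512 above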
import numpy as np
from scipy.io import wavfile
# TODO: add your own data types here as needed.
DATA_TYPE_DICT = {
"txt": {
"load_func": np.loadtxt,
"desc": "plain txt file or readable by np.loadtxt",
},
"wav": {
"load_func": lambda x: wavfile.read(x)[1],
"desc": "wav file or readable by soundfile.read",
},
"npy": {
"load_func": np.load,
"desc": "any .npy format file",
},
# raw PCM data can be loaded as binary
"bin_f32": {
"load_func": lambda x: np.fromfile(x, dtype=np.float32),
"desc": "binary file with float32 format",
},
"bin_f64": {
"load_func": lambda x: np.fromfile(x, dtype=np.float64),
"desc": "binary file with float64 format",
},
"bin_i32": {
"load_func": lambda x: np.fromfile(x, dtype=np.int32),
"desc": "binary file with int32 format",
},
"bin_i16": {
"load_func": lambda x: np.fromfile(x, dtype=np.int16),
"desc": "binary file with int16 format",
},
}
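# Usage sketch for DATA_TYPE_DICT above: dispatch a loader by data type key.
# The path "feats/0001.npy" is purely illustrative.
def load_by_type(path, data_type):
    entry = DATA_TYPE_DICT[data_type]
    return entry["load_func"](path)

arr = load_by_type("feats/0001.npy", "npy")  # same as np.load("feats/0001.npy")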
import os
import torch
import glob
import logging
from multiprocessing import Manager
import librosa
import numpy as np
import random
import functools
from tqdm import tqdm
import math
from kantts.utils.ling_unit.ling_unit import KanTtsLinguisticUnit, emotion_types
from scipy.stats import betabinom
DATASET_RANDOM_SEED = 1234
torch.multiprocessing.set_sharing_strategy("file_system")
@functools.lru_cache(maxsize=256)
def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0):
P = phoneme_count
M = mel_count
x = np.arange(0, P)
mel_text_probs = []
for i in range(1, M + 1):
a, b = scaling * i, scaling * (M + 1 - i)
rv = betabinom(P, a, b)
mel_i_prob = rv.pmf(x)
mel_text_probs.append(mel_i_prob)
return torch.tensor(np.array(mel_text_probs))
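# Illustration (assumed shapes, not original code): the returned prior has
# shape (mel_count, phoneme_count); each row is a beta-binomial pmf over
# phoneme positions, used below as a soft attention prior when there are no
# ground-truth durations.
prior = beta_binomial_prior_distribution(10, 50)  # 10 phonemes, 50 mel frames
print(prior.shape)     # torch.Size([50, 10])
print(prior[0].sum())  # each row sums to roughly 1.0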
class Padder(object):
def __init__(self):
super(Padder, self).__init__()
pass
def _pad1D(self, x, length, pad):
return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=pad)
def _pad2D(self, x, length, pad):
return np.pad(
x, [(0, length - x.shape[0]), (0, 0)], mode="constant", constant_values=pad
)
def _pad_durations(self, duration, max_in_len, max_out_len):
framenum = np.sum(duration)
symbolnum = duration.shape[0]
if framenum < max_out_len:
padframenum = max_out_len - framenum
duration = np.insert(duration, symbolnum, values=padframenum, axis=0)
duration = np.insert(
duration,
symbolnum + 1,
values=[0] * (max_in_len - symbolnum - 1),
axis=0,
)
else:
if symbolnum < max_in_len:
duration = np.insert(
duration, symbolnum, values=[0] * (max_in_len - symbolnum), axis=0
)
return duration
def _round_up(self, x, multiple):
remainder = x % multiple
return x if remainder == 0 else x + multiple - remainder
def _prepare_scalar_inputs(self, inputs, max_len, pad):
return torch.from_numpy(
np.stack([self._pad1D(x, max_len, pad) for x in inputs])
)
def _prepare_targets(self, targets, max_len, pad):
return torch.from_numpy(
np.stack([self._pad2D(t, max_len, pad) for t in targets])
).float()
def _prepare_durations(self, durations, max_in_len, max_out_len):
return torch.from_numpy(
np.stack(
[self._pad_durations(t, max_in_len, max_out_len) for t in durations]
)
).long()
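# Quick illustration of the Padder helpers above (they are private helpers of
# the datasets below; shown here only as a sketch).
padder = Padder()
print(padder._pad1D(np.array([1, 2, 3]), 5, 0))      # [1 2 3 0 0]
print(padder._pad2D(np.ones((2, 4)), 4, 0.0).shape)  # (4, 4)
print(padder._round_up(7, 3))                        # 9, next multiple of 3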
class Voc_Dataset(torch.utils.data.Dataset):
"""
provide (mel, audio) data pair
"""
def __init__(
self,
metafile,
root_dir,
config,
):
self.meta = []
self.config = config
self.sampling_rate = config["audio_config"]["sampling_rate"]
self.n_fft = config["audio_config"]["n_fft"]
self.hop_length = config["audio_config"]["hop_length"]
self.batch_max_steps = config["batch_max_steps"]
self.batch_max_frames = self.batch_max_steps // self.hop_length
self.aux_context_window = 0 # TODO: make it configurable
self.start_offset = self.aux_context_window
self.end_offset = -(self.batch_max_frames + self.aux_context_window)
self.nsf_enable = (
config["Model"]["Generator"]["params"].get("nsf_params", None) is not None
)
if self.nsf_enable:
self.nsf_norm_type = config["Model"]["Generator"]["params"][
"nsf_params"
].get("nsf_norm_type", '"mean_std')
if self.nsf_norm_type == "global":
self.nsf_f0_global_minimum = config["Model"]["Generator"]["params"][
"nsf_params"
].get("nsf_f0_global_minimum", 30.0)
self.nsf_f0_global_maximum = config["Model"]["Generator"]["params"][
"nsf_params"
].get("nsf_f0_global_maximum", 730.0)
if not isinstance(metafile, list):
metafile = [metafile]
if not isinstance(root_dir, list):
root_dir = [root_dir]
for meta_file, data_dir in zip(metafile, root_dir):
if not os.path.exists(meta_file):
logging.error("meta file not found: {}".format(meta_file))
raise ValueError(
"[Voc_Dataset] meta file: {} not found".format(meta_file)
)
if not os.path.exists(data_dir):
logging.error("data directory not found: {}".format(data_dir))
raise ValueError(
"[Voc_Dataset] data dir: {} not found".format(data_dir)
)
self.meta.extend(self.load_meta(meta_file, data_dir))
# Load from training data directory
if len(self.meta) == 0 and isinstance(root_dir, str):
wav_dir = os.path.join(root_dir, "wav")
mel_dir = os.path.join(root_dir, "mel")
if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
raise ValueError("wav or mel directory not found")
self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))
elif len(self.meta) == 0 and isinstance(root_dir, list):
for d in root_dir:
wav_dir = os.path.join(d, "wav")
mel_dir = os.path.join(d, "mel")
if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
raise ValueError("wav or mel directory not found")
self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))
self.allow_cache = config["allow_cache"]
if self.allow_cache:
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [() for _ in range(len(self.meta))]
@staticmethod
def gen_metafile(wav_dir, out_dir, split_ratio=0.98):
wav_files = glob.glob(os.path.join(wav_dir, "*.wav"))
frame_f0_dir = os.path.join(out_dir, "frame_f0")
frame_uv_dir = os.path.join(out_dir, "frame_uv")
mel_dir = os.path.join(out_dir, "mel")
random.seed(DATASET_RANDOM_SEED)
random.shuffle(wav_files)
num_train = int(len(wav_files) * split_ratio) - 1
with open(os.path.join(out_dir, "train.lst"), "w") as f:
for wav_file in wav_files[:num_train]:
index = os.path.splitext(os.path.basename(wav_file))[0]
if (
not os.path.exists(os.path.join(frame_f0_dir, index + ".npy"))
or not os.path.exists(os.path.join(frame_uv_dir, index + ".npy"))
or not os.path.exists(os.path.join(mel_dir, index + ".npy"))
):
continue
f.write("{}\n".format(index))
with open(os.path.join(out_dir, "valid.lst"), "w") as f:
for wav_file in wav_files[num_train:]:
index = os.path.splitext(os.path.basename(wav_file))[0]
if (
not os.path.exists(os.path.join(frame_f0_dir, index + ".npy"))
or not os.path.exists(os.path.join(frame_uv_dir, index + ".npy"))
or not os.path.exists(os.path.join(mel_dir, index + ".npy"))
):
continue
f.write("{}\n".format(index))
def load_meta(self, metafile, data_dir):
with open(metafile, "r") as f:
lines = f.readlines()
wav_dir = os.path.join(data_dir, "wav")
mel_dir = os.path.join(data_dir, "mel")
frame_f0_dir = os.path.join(data_dir, "frame_f0")
frame_uv_dir = os.path.join(data_dir, "frame_uv")
if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
raise ValueError("wav or mel directory not found")
items = []
logging.info("Loading metafile...")
for name in tqdm(lines):
name = name.strip()
mel_file = os.path.join(mel_dir, name + ".npy")
wav_file = os.path.join(wav_dir, name + ".wav")
frame_f0_file = os.path.join(frame_f0_dir, name + ".npy")
frame_uv_file = os.path.join(frame_uv_dir, name + ".npy")
items.append((wav_file, mel_file, frame_f0_file, frame_uv_file))
return items
def load_meta_from_dir(self, wav_dir, mel_dir):
wav_files = glob.glob(os.path.join(wav_dir, "*.wav"))
items = []
for wav_file in wav_files:
mel_file = os.path.join(mel_dir, os.path.basename(wav_file))
if os.path.exists(mel_file):
items.append((wav_file, mel_file))
return items
def __len__(self):
return len(self.meta)
def __getitem__(self, idx):
if self.allow_cache and len(self.caches[idx]) != 0:
return self.caches[idx]
wav_file, mel_file, frame_f0_file, frame_uv_file = self.meta[idx]
f0_mean_file = os.path.join(
os.path.dirname(os.path.dirname(frame_f0_file)), "f0", "f0_mean.txt"
)
f0_std_file = os.path.join(
os.path.dirname(os.path.dirname(frame_f0_file)), "f0", "f0_std.txt"
)
wav_data = librosa.core.load(wav_file, sr=self.sampling_rate)[0]
mel_data = np.load(mel_file)
if self.nsf_enable:
# denormalize f0; frame_f0_data is mean/std-normalized by default
frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
f0_mean = np.loadtxt(f0_mean_file)
f0_std = np.loadtxt(f0_std_file)
frame_f0_data = frame_f0_data * f0_std + f0_mean
frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
mel_data = np.concatenate((mel_data, frame_f0_data, frame_uv_data), axis=1)
# ensure mel_data is at least one frame longer than batch_max_frames
if mel_data.shape[0] <= self.batch_max_frames:
mel_data = np.concatenate(
(
mel_data,
np.zeros(
(
self.batch_max_frames - mel_data.shape[0] + 1,
mel_data.shape[1],
)
),
),
axis=0,
)
wav_cache = np.zeros(mel_data.shape[0] * self.hop_length, dtype=np.float32)
wav_cache[: len(wav_data)] = wav_data
wav_data = wav_cache
else:
# make sure the audio length and the feature length match
wav_data = np.pad(wav_data, (0, self.n_fft), mode="reflect")
wav_data = wav_data[: len(mel_data) * self.hop_length]
assert len(mel_data) * self.hop_length == len(wav_data)
if self.allow_cache:
self.caches[idx] = (wav_data, mel_data)
return (wav_data, mel_data)
def collate_fn(self, batch):
wav_data, mel_data = [item[0] for item in batch], [item[1] for item in batch]
mel_lengths = [len(mel) for mel in mel_data]
start_frames = np.array(
[
np.random.randint(self.start_offset, length + self.end_offset)
for length in mel_lengths
]
)
wav_start = start_frames * self.hop_length
wav_end = wav_start + self.batch_max_steps
# aux window works as padding
mel_start = start_frames - self.aux_context_window
mel_end = mel_start + self.batch_max_frames + self.aux_context_window
wav_batch = [
x[start:end] for x, start, end in zip(wav_data, wav_start, wav_end)
]
mel_batch = [
c[start:end] for c, start, end in zip(mel_data, mel_start, mel_end)
]
# (B, 1, T)
wav_batch = torch.tensor(np.asarray(wav_batch), dtype=torch.float32).unsqueeze(
1
)
# (B, C, T)
mel_batch = torch.tensor(np.asarray(mel_batch), dtype=torch.float32).transpose(
2, 1
)
return wav_batch, mel_batch
def get_voc_datasets(
config,
root_dir,
split_ratio=0.98,
):
if isinstance(root_dir, str):
root_dir = [root_dir]
train_meta_lst = []
valid_meta_lst = []
for data_dir in root_dir:
train_meta = os.path.join(data_dir, "train.lst")
valid_meta = os.path.join(data_dir, "valid.lst")
if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
Voc_Dataset.gen_metafile(
os.path.join(data_dir, "wav"), data_dir, split_ratio
)
train_meta_lst.append(train_meta)
valid_meta_lst.append(valid_meta)
train_dataset = Voc_Dataset(
train_meta_lst,
root_dir,
config,
)
valid_dataset = Voc_Dataset(
valid_meta_lst[:50],
root_dir,
config,
)
return train_dataset, valid_dataset
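# Hypothetical vocoder data pipeline built on get_voc_datasets; "data/F7" is
# an illustrative path and `config` is a vocoder YAML dict with the DATA
# LOADER SETTING keys shown above.
from torch.utils.data import DataLoader

train_set, valid_set = get_voc_datasets(config, "data/F7")
train_loader = DataLoader(
    train_set,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=config["num_workers"],
    pin_memory=config["pin_memory"],
    collate_fn=train_set.collate_fn,  # crops aligned (wav, mel) windows
)
wav_batch, mel_batch = next(iter(train_loader))  # (B, 1, T) and (B, C, T//hop)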
# TODO(Yuxuan): refine this logic; better not to rely on the emotion tag, it is ambiguous.
def get_fp_label(aug_ling_txt):
token_lst = aug_ling_txt.split(" ")
emo_lst = [token.strip("{}").split("$")[4] for token in token_lst]
syllable_lst = [token.strip("{}").split("$")[0] for token in token_lst]
# append the EOS token
emo_lst.append(emotion_types[0])
syllable_lst.append("EOS")
# According to the original emotion tag, set each token's fp label.
if emo_lst[0] != emotion_types[3]:
emo_lst[0] = emotion_types[0]
emo_lst[1] = emotion_types[0]
for i in range(len(emo_lst) - 2, 1, -1):
if emo_lst[i] != emotion_types[3] and emo_lst[i - 1] != emotion_types[3]:
emo_lst[i] = emotion_types[0]
elif emo_lst[i] != emotion_types[3] and emo_lst[i - 1] == emotion_types[3]:
emo_lst[i] = emotion_types[3]
if syllable_lst[i - 2] == "ga":
emo_lst[i + 1] = emotion_types[1]
elif syllable_lst[i - 2] == "ge" and syllable_lst[i - 1] == "en_c":
emo_lst[i + 1] = emotion_types[2]
else:
emo_lst[i + 1] = emotion_types[4]
fp_label = []
for i in range(len(emo_lst)):
if emo_lst[i] == emotion_types[0]:
fp_label.append(0)
elif emo_lst[i] == emotion_types[1]:
fp_label.append(1)
elif emo_lst[i] == emotion_types[2]:
fp_label.append(2)
elif emo_lst[i] == emotion_types[3]:
continue
elif emo_lst[i] == emotion_types[4]:
fp_label.append(3)
else:
pass
return np.array(fp_label)
class AM_Dataset(torch.utils.data.Dataset):
"""
provide (ling, emo, speaker, mel) pair
"""
def __init__(
self,
config,
metafile,
root_dir,
allow_cache=False,
):
self.meta = []
self.config = config
self.with_duration = True
self.nsf_enable = self.config["Model"]["KanTtsSAMBERT"]["params"].get(
"NSF", False
)
if self.nsf_enable:
self.nsf_norm_type = config["Model"]["KanTtsSAMBERT"]["params"].get(
"nsf_norm_type", "mean_std"
)
if self.nsf_norm_type == "global":
self.nsf_f0_global_minimum = config["Model"]["KanTtsSAMBERT"][
"params"
].get("nsf_f0_global_minimum", 30.0)
self.nsf_f0_global_maximum = config["Model"]["KanTtsSAMBERT"][
"params"
].get("nsf_f0_global_maximum", 730.0)
self.se_enable = self.config["Model"]["KanTtsSAMBERT"]["params"].get(
"SE", False
)
self.fp_enable = self.config["Model"]["KanTtsSAMBERT"]["params"].get(
"FP", False
)
self.mas_enable = self.config["Model"]["KanTtsSAMBERT"]["params"].get(
"MAS", False
)
if not isinstance(metafile, list):
metafile = [metafile]
if not isinstance(root_dir, list):
root_dir = [root_dir]
for meta_file, data_dir in zip(metafile, root_dir):
if not os.path.exists(meta_file):
logging.error("meta file not found: {}".format(meta_file))
raise ValueError(
"[AM_Dataset] meta file: {} not found".format(meta_file)
)
if not os.path.exists(data_dir):
logging.error("data dir not found: {}".format(data_dir))
raise ValueError("[AM_Dataset] data dir: {} not found".format(data_dir))
self.meta.extend(self.load_meta(meta_file, data_dir))
self.allow_cache = allow_cache
self.ling_unit = KanTtsLinguisticUnit(config)
self.padder = Padder()
self.r = self.config["Model"]["KanTtsSAMBERT"]["params"]["outputs_per_step"]
# TODO: feat window
if allow_cache:
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [() for _ in range(len(self.meta))]
def __len__(self):
return len(self.meta)
def __getitem__(self, idx):
if self.allow_cache and len(self.caches[idx]) != 0:
return self.caches[idx]
(
ling_txt,
mel_file,
dur_file,
f0_file,
energy_file,
frame_f0_file,
frame_uv_file,
aug_ling_txt,
se_path,
) = self.meta[idx]
f0_mean_file = os.path.join(
os.path.dirname(os.path.dirname(frame_f0_file)), "f0", "f0_mean.txt"
)
f0_std_file = os.path.join(
os.path.dirname(os.path.dirname(frame_f0_file)), "f0", "f0_std.txt"
)
ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
mel_data = np.load(mel_file)
dur_data = np.load(dur_file) if dur_file is not None else None
f0_data = np.load(f0_file)
energy_data = np.load(energy_file)
se_data = np.load(se_path) if self.se_enable else None
# generate fp position label according to fpadd_meta
if self.fp_enable and aug_ling_txt is not None:
fp_label = get_fp_label(aug_ling_txt)
else:
fp_label = None
if self.with_duration:
attn_prior = None
else:
attn_prior = beta_binomial_prior_distribution(
len(ling_data[0]), mel_data.shape[0]
)
# Concat frame-level f0 and uv to mel_data
if self.nsf_enable:
# the stored frame-level f0 is mean/std-normalized
frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
# re-normalize to the global min/max range if configured
if self.nsf_norm_type == "global":
# denorm f0
f0_mean = np.loadtxt(f0_mean_file)
f0_std = np.loadtxt(f0_std_file)
f0_origin = frame_f0_data * f0_std + f0_mean
# renorm f0
frame_f0_data = (f0_origin - self.nsf_f0_global_minimum) / (
self.nsf_f0_global_maximum - self.nsf_f0_global_minimum
)
frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
mel_data = np.concatenate([mel_data, frame_f0_data, frame_uv_data], axis=1)
if self.allow_cache:
self.caches[idx] = (
ling_data,
mel_data,
dur_data,
f0_data,
energy_data,
attn_prior,
fp_label,
se_data,
)
return (
ling_data,
mel_data,
dur_data,
f0_data,
energy_data,
attn_prior,
fp_label,
se_data,
)
def load_meta(self, metafile, data_dir):
with open(metafile, "r") as f:
lines = f.readlines()
aug_ling_dict = {}
if self.fp_enable:
add_fp_metafile = metafile.replace("fprm", "fpadd")
with open(add_fp_metafile, "r") as f:
fpadd_lines = f.readlines()
for line in fpadd_lines:
index, aug_ling_txt = line.split("\t")
aug_ling_dict[index] = aug_ling_txt
mel_dir = os.path.join(data_dir, "mel")
dur_dir = os.path.join(data_dir, "duration")
f0_dir = os.path.join(data_dir, "f0")
energy_dir = os.path.join(data_dir, "energy")
frame_f0_dir = os.path.join(data_dir, "frame_f0")
frame_uv_dir = os.path.join(data_dir, "frame_uv")
se_dir = os.path.join(data_dir, "se")
if self.mas_enable:
self.with_duration = False
else:
self.with_duration = os.path.exists(dur_dir)
items = []
logging.info("Loading metafile...")
for line in tqdm(lines):
line = line.strip()
index, ling_txt = line.split("\t")
mel_file = os.path.join(mel_dir, index + ".npy")
if self.with_duration:
dur_file = os.path.join(dur_dir, index + ".npy")
else:
dur_file = None
f0_file = os.path.join(f0_dir, index + ".npy")
energy_file = os.path.join(energy_dir, index + ".npy")
frame_f0_file = os.path.join(frame_f0_dir, index + ".npy")
frame_uv_file = os.path.join(frame_uv_dir, index + ".npy")
aug_ling_txt = aug_ling_dict.get(index, None)
if self.fp_enable and aug_ling_txt is None:
logging.warning(f"Missing fpadd meta for {index}")
continue
se_path = os.path.join(se_dir, "se.npy")
if self.se_enable:
if not os.path.exists(se_path):
logging.warning("Missing se meta")
continue
items.append(
(
ling_txt,
mel_file,
dur_file,
f0_file,
energy_file,
frame_f0_file,
frame_uv_file,
aug_ling_txt,
se_path,
)
)
return items
def load_fpadd_meta(self, metafile):
with open(metafile, "r") as f:
lines = f.readlines()
items = []
logging.info("Loading fpadd metafile...")
for line in tqdm(lines):
line = line.strip()
index, ling_txt = line.split("\t")
items.append((ling_txt,))
return items
@staticmethod
def gen_metafile(
raw_meta_file,
out_dir,
train_meta_file,
valid_meta_file,
badlist=None,
split_ratio=0.98,
se_enable=False,
):
with open(raw_meta_file, "r") as f:
lines = f.readlines()
se_dir = os.path.join(out_dir, "se")
frame_f0_dir = os.path.join(out_dir, "frame_f0")
frame_uv_dir = os.path.join(out_dir, "frame_uv")
mel_dir = os.path.join(out_dir, "mel")
duration_dir = os.path.join(out_dir, "duration")
random.seed(DATASET_RANDOM_SEED)
random.shuffle(lines)
num_train = int(len(lines) * split_ratio) - 1
with open(train_meta_file, "w") as f:
for line in lines[:num_train]:
index = line.split("\t")[0]
if badlist is not None and index in badlist:
continue
if (
not os.path.exists(os.path.join(frame_f0_dir, index + ".npy"))
or not os.path.exists(os.path.join(frame_uv_dir, index + ".npy"))
or not os.path.exists(os.path.join(mel_dir, index + ".npy"))
):
continue
if os.path.exists(duration_dir) and not os.path.exists(
os.path.join(duration_dir, index + ".npy")
):
continue
if se_enable:
if os.path.exists(se_dir) and not os.path.exists(
os.path.join(se_dir, "se.npy")
):
continue
f.write(line)
with open(valid_meta_file, "w") as f:
for line in lines[num_train:]:
index = line.split("\t")[0]
if badlist is not None and index in badlist:
continue
if (
not os.path.exists(os.path.join(frame_f0_dir, index + ".npy"))
or not os.path.exists(os.path.join(frame_uv_dir, index + ".npy"))
or not os.path.exists(os.path.join(mel_dir, index + ".npy"))
):
continue
if os.path.exists(duration_dir) and not os.path.exists(
os.path.join(duration_dir, index + ".npy")
):
continue
if se_enable:
if os.path.exists(se_dir) and not os.path.exists(
os.path.join(se_dir, "se.npy")
):
continue
f.write(line)
# TODO: implement collate_fn
def collate_fn(self, batch):
data_dict = {}
max_input_length = max((len(x[0][0]) for x in batch))
if self.with_duration:
max_dur_length = max((x[2].shape[0] for x in batch)) + 1
lfeat_type_index = 0
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
if self.ling_unit.using_byte():
# for byte-based model only
inputs_byte_index = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
data_dict["input_lings"] = torch.stack([inputs_byte_index], dim=2)
else:
# pure linguistic info: sy|tone|syllable_flag|word_segment
# sy
inputs_sy = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# tone
lfeat_type_index = lfeat_type_index + 1
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
inputs_tone = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# syllable_flag
lfeat_type_index = lfeat_type_index + 1
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
inputs_syllable_flag = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# word_segment
lfeat_type_index = lfeat_type_index + 1
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
inputs_ws = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
data_dict["input_lings"] = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2
)
# emotion category
lfeat_type_index = lfeat_type_index + 1
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
data_dict["input_emotions"] = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# speaker category
lfeat_type_index = lfeat_type_index + 1
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
if self.se_enable:
data_dict["input_speakers"] = self.padder._prepare_targets(
[x[7].repeat(len(x[0][0]), axis=0) for x in batch],
max_input_length,
0.0,
)
else:
data_dict["input_speakers"] = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# fp label category
if self.fp_enable:
data_dict["fp_label"] = self.padder._prepare_scalar_inputs(
[x[6] for x in batch],
max_input_length,
0,
).long()
data_dict["valid_input_lengths"] = torch.as_tensor(
[len(x[0][0]) - 1 for x in batch], dtype=torch.long
) # the input symbol sequence gets a trailing "~" appended, which would skew the duration computation, so subtract 1 from the length
data_dict["valid_output_lengths"] = torch.as_tensor(
[len(x[1]) for x in batch], dtype=torch.long
)
max_output_length = torch.max(data_dict["valid_output_lengths"]).item()
max_output_round_length = self.padder._round_up(max_output_length, self.r)
data_dict["mel_targets"] = self.padder._prepare_targets(
[x[1] for x in batch], max_output_round_length, 0.0
)
if self.with_duration:
data_dict["durations"] = self.padder._prepare_durations(
[x[2] for x in batch], max_dur_length, max_output_round_length
)
else:
data_dict["durations"] = None
if self.with_duration:
if self.fp_enable:
feats_padding_length = max_dur_length
else:
feats_padding_length = max_input_length
else:
feats_padding_length = max_output_round_length
data_dict["pitch_contours"] = self.padder._prepare_scalar_inputs(
[x[3] for x in batch], feats_padding_length, 0.0
).float()
data_dict["energy_contours"] = self.padder._prepare_scalar_inputs(
[x[4] for x in batch], feats_padding_length, 0.0
).float()
if self.with_duration:
data_dict["attn_priors"] = None
else:
data_dict["attn_priors"] = torch.zeros(
len(batch), max_output_round_length, max_input_length
)
for i in range(len(batch)):
attn_prior = batch[i][5]
data_dict["attn_priors"][
i, : attn_prior.shape[0], : attn_prior.shape[1]
] = attn_prior
return data_dict
# TODO: implement get_am_datasets
def get_am_datasets(
metafile,
root_dir,
config,
allow_cache,
split_ratio=0.98,
se_enable=False,
):
if not isinstance(root_dir, list):
root_dir = [root_dir]
if not isinstance(metafile, list):
metafile = [metafile]
train_meta_lst = []
valid_meta_lst = []
fp_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("FP", False)
if fp_enable:
am_train_fn = "am_fprm_train.lst"
am_valid_fn = "am_fprm_valid.lst"
else:
am_train_fn = "am_train.lst"
am_valid_fn = "am_valid.lst"
for raw_metafile, data_dir in zip(metafile, root_dir):
train_meta = os.path.join(data_dir, am_train_fn)
valid_meta = os.path.join(data_dir, am_valid_fn)
if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
AM_Dataset.gen_metafile(
raw_metafile, data_dir, train_meta, valid_meta, split_ratio=split_ratio, se_enable=se_enable
)
train_meta_lst.append(train_meta)
valid_meta_lst.append(valid_meta)
train_dataset = AM_Dataset(config, train_meta_lst, root_dir, allow_cache)
valid_dataset = AM_Dataset(config, valid_meta_lst[:50], root_dir, allow_cache)
return train_dataset, valid_dataset
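# Hypothetical acoustic-model pipeline built on get_am_datasets; the paths are
# illustrative. collate_fn above returns a dict of padded batch tensors.
from torch.utils.data import DataLoader

train_set, valid_set = get_am_datasets(
    "data/F7/raw_metafile.txt", "data/F7", config, allow_cache=False
)
loader = DataLoader(
    train_set, batch_size=config["batch_size"], shuffle=True,
    collate_fn=train_set.collate_fn,
)
batch = next(iter(loader))
print(batch["input_lings"].shape)  # (B, max_input_length, 4)
print(batch["mel_targets"].shape)  # (B, T rounded up to outputs_per_step, n_mels)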
class MaskingActor(object):
def __init__(self, mask_ratio=0.15):
super(MaskingActor, self).__init__()
self.mask_ratio = mask_ratio
pass
def _get_random_mask(self, length, p1=0.15):
mask = np.random.uniform(0, 1, length)
index = 0
while index < len(mask):
if mask[index] < p1:
mask[index] = 1
else:
mask[index] = 0
index += 1
return mask
def _input_bert_masking(
self,
sequence_array,
nb_symbol_category,
mask_symbol_id,
mask,
p2=0.8,
p3=0.1,
p4=0.1,
):
sequence_array_mask = sequence_array.copy()
mask_id = np.where(mask == 1)[0]
mask_len = len(mask_id)
rand = np.arange(mask_len)
np.random.shuffle(rand)
# [MASK]
mask_id_p2 = mask_id[rand[0 : int(math.floor(mask_len * p2))]]
if len(mask_id_p2) > 0:
sequence_array_mask[mask_id_p2] = mask_symbol_id
# random symbol replacement
mask_id_p3 = mask_id[
rand[
int(math.floor(mask_len * p2)) : int(math.floor(mask_len * p2))
+ int(math.floor(mask_len * p3))
]
]
if len(mask_id_p3) > 0:
sequence_array_mask[mask_id_p3] = random.randint(0, nb_symbol_category - 1)
# remaining masked positions keep their original symbols
return sequence_array_mask
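# Toy demonstration of the BERT-style masking above: roughly mask_ratio of the
# positions are selected; of those, p2=80% become the [MASK] id, p3=10% a
# random symbol, and the remaining 10% stay unchanged. All ids here are fake.
actor = MaskingActor(mask_ratio=0.3)
seq = np.arange(20)  # a fake symbol-id sequence
mask = actor._get_random_mask(len(seq), p1=actor.mask_ratio)
masked = actor._input_bert_masking(
    seq, nb_symbol_category=100, mask_symbol_id=99, mask=mask
)
print(mask.astype(int))
print(masked)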
class BERT_Text_Dataset(torch.utils.data.Dataset):
"""
provide (ling, ling_sy_masked, bert_mask) pair
"""
def __init__(
self,
config,
metafile,
root_dir,
allow_cache=False,
):
self.meta = []
self.config = config
if not isinstance(metafile, list):
metafile = [metafile]
if not isinstance(root_dir, list):
root_dir = [root_dir]
for meta_file, data_dir in zip(metafile, root_dir):
if not os.path.exists(meta_file):
logging.error("meta file not found: {}".format(meta_file))
raise ValueError(
"[BERT_Text_Dataset] meta file: {} not found".format(meta_file)
)
if not os.path.exists(data_dir):
logging.error("data dir not found: {}".format(data_dir))
raise ValueError(
"[BERT_Text_Dataset] data dir: {} not found".format(data_dir)
)
self.meta.extend(self.load_meta(meta_file, data_dir))
self.allow_cache = allow_cache
self.ling_unit = KanTtsLinguisticUnit(config)
self.padder = Padder()
self.masking_actor = MaskingActor(
self.config["Model"]["KanTtsTextsyBERT"]["params"]["mask_ratio"]
)
if allow_cache:
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [() for _ in range(len(self.meta))]
def __len__(self):
return len(self.meta)
# TODO: implement __getitem__
def __getitem__(self, idx):
if self.allow_cache and len(self.caches[idx]) != 0:
ling_data = self.caches[idx][0]
bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
return (ling_data, ling_sy_masked_data, bert_mask)
ling_txt = self.meta[idx]
ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
if self.allow_cache:
self.caches[idx] = (ling_data,)
return (ling_data, ling_sy_masked_data, bert_mask)
def load_meta(self, metafile, data_dir):
with open(metafile, "r") as f:
lines = f.readlines()
items = []
logging.info("Loading metafile...")
for line in tqdm(lines):
line = line.strip()
index, ling_txt = line.split("\t")
items.append(ling_txt)
return items
@staticmethod
def gen_metafile(raw_meta_file, out_dir, split_ratio=0.98):
with open(raw_meta_file, "r") as f:
lines = f.readlines()
random.seed(DATASET_RANDOM_SEED)
random.shuffle(lines)
num_train = int(len(lines) * split_ratio) - 1
with open(os.path.join(out_dir, "bert_train.lst"), "w") as f:
for line in lines[:num_train]:
f.write(line)
with open(os.path.join(out_dir, "bert_valid.lst"), "w") as f:
for line in lines[num_train:]:
f.write(line)
def bert_masking(self, ling_data):
length = len(ling_data[0])
mask = self.masking_actor._get_random_mask(
length, p1=self.masking_actor.mask_ratio
)
mask[-1] = 0
# sy_masked
sy_mask_symbol_id = self.ling_unit.encode_sy([self.ling_unit._mask])[0]
ling_sy_masked_data = self.masking_actor._input_bert_masking(
ling_data[0],
self.ling_unit.get_unit_size()["sy"],
sy_mask_symbol_id,
mask,
p2=0.8,
p3=0.1,
p4=0.1,
)
return (mask, ling_sy_masked_data)
# TODO: implement collate_fn
def collate_fn(self, batch):
data_dict = {}
max_input_length = max((len(x[0][0]) for x in batch))
# pure linguistic info: sy|tone|syllable_flag|word_segment
# sy
lfeat_type = self.ling_unit._lfeat_type_list[0]
targets_sy = self.padder._prepare_scalar_inputs(
[x[0][0] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# sy masked
inputs_sy = self.padder._prepare_scalar_inputs(
[x[1] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# tone
lfeat_type = self.ling_unit._lfeat_type_list[1]
inputs_tone = self.padder._prepare_scalar_inputs(
[x[0][1] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# syllable_flag
lfeat_type = self.ling_unit._lfeat_type_list[2]
inputs_syllable_flag = self.padder._prepare_scalar_inputs(
[x[0][2] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# word_segment
lfeat_type = self.ling_unit._lfeat_type_list[3]
inputs_ws = self.padder._prepare_scalar_inputs(
[x[0][3] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
data_dict["input_lings"] = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2
)
data_dict["valid_input_lengths"] = torch.as_tensor(
[len(x[0][0]) - 1 for x in batch], dtype=torch.long
) # the input symbol sequence gets a trailing "~" appended, which would skew the duration computation, so subtract 1 from the length
data_dict["targets"] = targets_sy
data_dict["bert_masks"] = self.padder._prepare_scalar_inputs(
[x[2] for x in batch], max_input_length, 0.0
)
return data_dict
def get_bert_text_datasets(
metafile,
root_dir,
config,
allow_cache,
split_ratio=0.98,
):
if not isinstance(root_dir, list):
root_dir = [root_dir]
if not isinstance(metafile, list):
metafile = [metafile]
train_meta_lst = []
valid_meta_lst = []
for raw_metafile, data_dir in zip(metafile, root_dir):
train_meta = os.path.join(data_dir, "bert_train.lst")
valid_meta = os.path.join(data_dir, "bert_valid.lst")
if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
BERT_Text_Dataset.gen_metafile(raw_metafile, data_dir, split_ratio)
train_meta_lst.append(train_meta)
valid_meta_lst.append(valid_meta)
train_dataset = BERT_Text_Dataset(config, train_meta_lst, root_dir, allow_cache)
valid_dataset = BERT_Text_Dataset(config, valid_meta_lst, root_dir, allow_cache)
return train_dataset, valid_dataset
import torch
from torch.nn.parallel import DistributedDataParallel
from kantts.models.hifigan.hifigan import ( # NOQA
Generator, # NOQA
MultiScaleDiscriminator, # NOQA
MultiPeriodDiscriminator, # NOQA
MultiSpecDiscriminator, # NOQA
)
import kantts
import kantts.train.scheduler
from kantts.models.sambert.kantts_sambert import KanTtsSAMBERT, KanTtsTextsyBERT # NOQA
from kantts.utils.ling_unit.ling_unit import get_fpdict
from .pqmf import PQMF
def optimizer_builder(model_params, opt_name, opt_params):
opt_cls = getattr(torch.optim, opt_name)
optimizer = opt_cls(model_params, **opt_params)
return optimizer
def scheduler_builder(optimizer, sche_name, sche_params):
scheduler_cls = getattr(kantts.train.scheduler, sche_name)
scheduler = scheduler_cls(optimizer, **sche_params)
return scheduler
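# Sketch of the builder helpers on a toy module, mirroring how the optimizer
# and scheduler sections of the sambert config above are consumed (NoamLR is
# provided by kantts.train.scheduler, imported above).
toy = torch.nn.Linear(8, 8)
opt = optimizer_builder(
    toy.parameters(), "Adam",
    {"lr": 0.001, "betas": [0.9, 0.98], "eps": 1.0e-9, "weight_decay": 0.0},
)
sche = scheduler_builder(opt, "NoamLR", {"warmup_steps": 4000})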
def hifigan_model_builder(config, device, rank, distributed):
model = {}
optimizer = {}
scheduler = {}
model["discriminator"] = {}
optimizer["discriminator"] = {}
scheduler["discriminator"] = {}
for model_name in config["Model"].keys():
if model_name == "Generator":
params = config["Model"][model_name]["params"]
model["generator"] = Generator(**params).to(device)
optimizer["generator"] = optimizer_builder(
model["generator"].parameters(),
config["Model"][model_name]["optimizer"].get("type", "Adam"),
config["Model"][model_name]["optimizer"].get("params", {}),
)
scheduler["generator"] = scheduler_builder(
optimizer["generator"],
config["Model"][model_name]["scheduler"].get("type", "StepLR"),
config["Model"][model_name]["scheduler"].get("params", {}),
)
else:
params = config["Model"][model_name]["params"]
model["discriminator"][model_name] = globals()[model_name](**params).to(
device
)
optimizer["discriminator"][model_name] = optimizer_builder(
model["discriminator"][model_name].parameters(),
config["Model"][model_name]["optimizer"].get("type", "Adam"),
config["Model"][model_name]["optimizer"].get("params", {}),
)
scheduler["discriminator"][model_name] = scheduler_builder(
optimizer["discriminator"][model_name],
config["Model"][model_name]["scheduler"].get("type", "StepLR"),
config["Model"][model_name]["scheduler"].get("params", {}),
)
out_channels = config["Model"]["Generator"]["params"]["out_channels"]
if out_channels > 1:
model["pqmf"] = PQMF(subbands=out_channels, **config.get("pqmf", {})).to(device)
# FIXME: pywavelets buffer leads to gradient error in DDP training
# Solution: https://github.com/pytorch/pytorch/issues/22095
if distributed:
model["generator"] = DistributedDataParallel(
model["generator"],
device_ids=[rank],
output_device=rank,
broadcast_buffers=False,
)
for model_name in model["discriminator"].keys():
model["discriminator"][model_name] = DistributedDataParallel(
model["discriminator"][model_name],
device_ids=[rank],
output_device=rank,
broadcast_buffers=False,
)
return model, optimizer, scheduler
# TODO: some parsing
def sambert_model_builder(config, device, rank, distributed):
model = {}
optimizer = {}
scheduler = {}
model["KanTtsSAMBERT"] = KanTtsSAMBERT(
config["Model"]["KanTtsSAMBERT"]["params"]
).to(device)
fp_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("FP", False)
if fp_enable:
fp_dict = {
k: torch.from_numpy(v).long().unsqueeze(0).to(device)
for k, v in get_fpdict(config).items()
}
model["KanTtsSAMBERT"].fp_dict = fp_dict
optimizer["KanTtsSAMBERT"] = optimizer_builder(
model["KanTtsSAMBERT"].parameters(),
config["Model"]["KanTtsSAMBERT"]["optimizer"].get("type", "Adam"),
config["Model"]["KanTtsSAMBERT"]["optimizer"].get("params", {}),
)
scheduler["KanTtsSAMBERT"] = scheduler_builder(
optimizer["KanTtsSAMBERT"],
config["Model"]["KanTtsSAMBERT"]["scheduler"].get("type", "StepLR"),
config["Model"]["KanTtsSAMBERT"]["scheduler"].get("params", {}),
)
if distributed:
model["KanTtsSAMBERT"] = DistributedDataParallel(
model["KanTtsSAMBERT"], device_ids=[rank], output_device=rank
)
return model, optimizer, scheduler
def sybert_model_builder(config, device, rank, distributed):
model = {}
optimizer = {}
scheduler = {}
model["KanTtsTextsyBERT"] = KanTtsTextsyBERT(
config["Model"]["KanTtsTextsyBERT"]["params"]
).to(device)
optimizer["KanTtsTextsyBERT"] = optimizer_builder(
model["KanTtsTextsyBERT"].parameters(),
config["Model"]["KanTtsTextsyBERT"]["optimizer"].get("type", "Adam"),
config["Model"]["KanTtsTextsyBERT"]["optimizer"].get("params", {}),
)
scheduler["KanTtsTextsyBERT"] = scheduler_builder(
optimizer["KanTtsTextsyBERT"],
config["Model"]["KanTtsTextsyBERT"]["scheduler"].get("type", "StepLR"),
config["Model"]["KanTtsTextsyBERT"]["scheduler"].get("params", {}),
)
if distributed:
model["KanTtsTextsyBERT"] = DistributedDataParallel(
model["KanTtsTextsyBERT"], device_ids=[rank], output_device=rank
)
return model, optimizer, scheduler
# TODO: implement a builder for specific model
model_dict = {
"hifigan": hifigan_model_builder,
"sambert": sambert_model_builder,
"sybert": sybert_model_builder,
}
def model_builder(config, device="cpu", rank=0, distributed=False):
builder_func = model_dict[config["model_type"]]
model, optimizer, scheduler = builder_func(config, device, rank, distributed)
return model, optimizer, scheduler
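# Hypothetical single-process invocation of model_builder; `config` stands for
# a YAML dict like the sambert one above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, optimizer, scheduler = model_builder(config, device=device)
am = model["KanTtsSAMBERT"]  # present when model_type is "sambert"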
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils import weight_norm, spectral_norm
from distutils.version import LooseVersion
from pytorch_wavelets import DWT1DForward
from .layers import (
Conv1d,
CausalConv1d,
ConvTranspose1d,
CausalConvTranspose1d,
ResidualBlock,
SourceModule,
)
from kantts.utils.audio_torch import stft
import copy
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
class Generator(torch.nn.Module):
def __init__(
self,
in_channels=80,
out_channels=1,
channels=512,
kernel_size=7,
upsample_scales=(8, 8, 2, 2),
upsample_kernal_sizes=(16, 16, 4, 4),
resblock_kernel_sizes=(3, 7, 11),
resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
repeat_upsample=True,
bias=True,
causal=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
use_weight_norm=True,
nsf_params=None,
):
super(Generator, self).__init__()
# check hyperparameters are valid
assert kernel_size % 2 == 1, "Kernel size must be an odd number."
assert len(upsample_scales) == len(upsample_kernal_sizes)
assert len(resblock_dilations) == len(resblock_kernel_sizes)
self.upsample_scales = upsample_scales
self.repeat_upsample = repeat_upsample
self.num_upsamples = len(upsample_kernal_sizes)
self.num_kernels = len(resblock_kernel_sizes)
self.out_channels = out_channels
self.nsf_enable = nsf_params is not None
self.transpose_upsamples = torch.nn.ModuleList()
self.repeat_upsamples = torch.nn.ModuleList() # for repeat upsampling
self.conv_blocks = torch.nn.ModuleList()
conv_cls = CausalConv1d if causal else Conv1d
conv_transposed_cls = CausalConvTranspose1d if causal else ConvTranspose1d
self.conv_pre = conv_cls(
in_channels, channels, kernel_size, 1, padding=(kernel_size - 1) // 2
)
for i in range(len(upsample_kernal_sizes)):
self.transpose_upsamples.append(
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
conv_transposed_cls(
channels // (2 ** i),
channels // (2 ** (i + 1)),
upsample_kernal_sizes[i],
upsample_scales[i],
padding=(upsample_kernal_sizes[i] - upsample_scales[i]) // 2,
),
)
)
if repeat_upsample:
self.repeat_upsamples.append(
nn.Sequential(
nn.Upsample(mode="nearest", scale_factor=upsample_scales[i]),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
conv_cls(
channels // (2 ** i),
channels // (2 ** (i + 1)),
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
),
)
)
for j in range(len(resblock_kernel_sizes)):
self.conv_blocks.append(
ResidualBlock(
channels=channels // (2 ** (i + 1)),
kernel_size=resblock_kernel_sizes[j],
dilation=resblock_dilations[j],
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
causal=causal,
)
)
self.conv_post = conv_cls(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
)
if self.nsf_enable:
self.source_module = SourceModule(
nb_harmonics=nsf_params["nb_harmonics"],
upsample_ratio=np.cumprod(self.upsample_scales)[-1],
sampling_rate=nsf_params["sampling_rate"],
)
self.source_downs = nn.ModuleList()
self.downsample_rates = [1] + list(self.upsample_scales[::-1][:-1])
self.downsample_cum_rates = np.cumprod(self.downsample_rates)
for i, u in enumerate(self.downsample_cum_rates[::-1]):
if u == 1:
self.source_downs.append(
Conv1d(1, channels // (2 ** (i + 1)), 1, 1)
)
else:
self.source_downs.append(
conv_cls(
1,
channels // (2 ** (i + 1)),
u * 2,
u,
padding=u // 2,
)
)
def forward(self, x):
if self.nsf_enable:
mel = x[:, :-2, :]
pitch = x[:, -2:-1, :]
uv = x[:, -1:, :]
excitation = self.source_module(pitch, uv)
else:
mel = x
x = self.conv_pre(mel)
for i in range(self.num_upsamples):
# FIXME: sin function here seems to be causing issues
x = torch.sin(x) + x
rep = self.repeat_upsamples[i](x)
# transconv
up = self.transpose_upsamples[i](x)
if self.nsf_enable:
# Downsampling the excitation signal
e = self.source_downs[i](excitation)
# augment inputs with the excitation
x = rep + e + up[:, :, : rep.shape[-1]]
else:
x = rep + up[:, :, : rep.shape[-1]]
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.conv_blocks[i * self.num_kernels + j](x)
else:
xs += self.conv_blocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print("Removing weight norm...")
for layer in self.transpose_upsamples:
layer[-1].remove_weight_norm()
for layer in self.repeat_upsamples:
layer[-1].remove_weight_norm()
for layer in self.conv_blocks:
layer.remove_weight_norm()
self.conv_pre.remove_weight_norm()
self.conv_post.remove_weight_norm()
if self.nsf_enable:
self.source_module.remove_weight_norm()
for layer in self.source_downs:
layer.remove_weight_norm()
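# Shape sanity check for the Generator above (a sketch: non-causal, no NSF;
# exact lengths depend on the conv layers in .layers). With upsample_scales
# (8, 8, 2, 2), one mel frame maps to 8*8*2*2 = 256 waveform samples, so 32
# frames should yield 8192 samples.
gen = Generator(in_channels=80, causal=False)
wav = gen(torch.randn(1, 80, 32))
print(wav.shape)  # expected: torch.Size([1, 1, 8192])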
class PeriodDiscriminator(torch.nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
period=3,
kernel_sizes=[5, 3],
channels=32,
downsample_scales=[3, 3, 3, 3, 1],
max_downsample_channels=1024,
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
use_spectral_norm=False,
):
super(PeriodDiscriminator, self).__init__()
self.period = period
norm_f = weight_norm if not use_spectral_norm else spectral_norm
self.convs = nn.ModuleList()
in_chs, out_chs = in_channels, channels
for downsample_scale in downsample_scales:
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
in_chs,
out_chs,
(kernel_sizes[0], 1),
(downsample_scale, 1),
padding=((kernel_sizes[0] - 1) // 2, 0),
)
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
)
in_chs = out_chs
out_chs = min(out_chs * 4, max_downsample_channels)
self.conv_post = nn.Conv2d(
out_chs,
out_channels,
(kernel_sizes[1] - 1, 1),
1,
padding=((kernel_sizes[1] - 1) // 2, 0),
)
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
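# Illustration of the 1d-to-2d fold in PeriodDiscriminator.forward above: the
# signal is right-padded to a multiple of `period`, then viewed as
# (b, c, t // period, period). The numbers here are arbitrary.
x = torch.arange(10.0).view(1, 1, 10)  # t = 10
period = 3
n_pad = period - (10 % period)         # pad 10 -> 12
x = F.pad(x, (0, n_pad), "reflect")
print(x.view(1, 1, -1, period).shape)  # torch.Size([1, 1, 4, 3])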
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(
self,
periods=[2, 3, 5, 7, 11],
discriminator_params={
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_spectral_norm": False,
},
):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList()
for period in periods:
params = copy.deepcopy(discriminator_params)
params["period"] = period
self.discriminators += [PeriodDiscriminator(**params)]
def forward(self, y):
y_d_rs = []
fmap_rs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
return y_d_rs, fmap_rs
class ScaleDiscriminator(torch.nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_sizes=[15, 41, 5, 3],
channels=128,
max_downsample_channels=1024,
max_groups=16,
bias=True,
downsample_scales=[2, 2, 4, 4, 1],
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
use_spectral_norm=False,
):
super(ScaleDiscriminator, self).__init__()
norm_f = weight_norm if not use_spectral_norm else spectral_norm
assert len(kernel_sizes) == 4
for ks in kernel_sizes:
assert ks % 2 == 1
self.convs = nn.ModuleList()
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_channels,
channels,
kernel_sizes[0],
bias=bias,
padding=(kernel_sizes[0] - 1) // 2,
)
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
)
in_chs = channels
out_chs = channels
groups = 4
for downsample_scale in downsample_scales:
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[1],
stride=downsample_scale,
padding=(kernel_sizes[1] - 1) // 2,
groups=groups,
bias=bias,
)
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
)
in_chs = out_chs
out_chs = min(in_chs * 2, max_downsample_channels)
groups = min(groups * 4, max_groups)
out_chs = min(in_chs * 2, max_downsample_channels)
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[2],
stride=1,
padding=(kernel_sizes[2] - 1) // 2,
bias=bias,
)
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
)
self.conv_post = norm_f(
nn.Conv1d(
out_chs,
out_channels,
kernel_size=kernel_sizes[3],
stride=1,
padding=(kernel_sizes[3] - 1) // 2,
bias=bias,
)
)
def forward(self, x):
fmap = []
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(torch.nn.Module):
def __init__(
self,
scales=3,
downsample_pooling="DWT",
# follow the official implementation setting
downsample_pooling_params={
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
discriminator_params={
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
},
follow_official_norm=False,
):
super(MultiScaleDiscriminator, self).__init__()
self.discriminators = torch.nn.ModuleList()
# add discriminators
for i in range(scales):
params = copy.deepcopy(discriminator_params)
if follow_official_norm:
params["use_spectral_norm"] = True if i == 0 else False
self.discriminators += [ScaleDiscriminator(**params)]
if downsample_pooling == "DWT":
self.meanpools = nn.ModuleList(
[DWT1DForward(wave="db3", J=1), DWT1DForward(wave="db3", J=1)]
)
self.aux_convs = nn.ModuleList(
[
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
]
)
else:
self.meanpools = nn.ModuleList(
[nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
)
self.aux_convs = None
def forward(self, y):
y_d_rs = []
fmap_rs = []
for i, d in enumerate(self.discriminators):
if i != 0:
if self.aux_convs is None:
y = self.meanpools[i - 1](y)
else:
yl, yh = self.meanpools[i - 1](y)
y = torch.cat([yl, yh[0]], dim=1)
y = self.aux_convs[i - 1](y)
y = F.leaky_relu(y, 0.1)
y_d_r, fmap_r = d(y)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
return y_d_rs, fmap_rs
class SpecDiscriminator(torch.nn.Module):
def __init__(
self,
channels=32,
init_kernel=15,
kernel_size=11,
stride=2,
use_spectral_norm=False,
fft_size=1024,
shift_size=120,
win_length=600,
window="hann_window",
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
):
super(SpecDiscriminator, self).__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
# input channels of the first conv = fft_size // 2 + 1 (number of frequency bins)
norm_f = weight_norm if not use_spectral_norm else spectral_norm
final_kernel = 5
post_conv_kernel = 3
blocks = 3 # TODO: remove hard code here
self.convs = nn.ModuleList()
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
fft_size // 2 + 1,
channels,
(init_kernel, 1),
(1, 1),
padding=(init_kernel - 1) // 2,
)
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
)
for i in range(blocks):
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
channels,
channels,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size - 1) // 2,
)
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
)
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
channels,
channels,
(final_kernel, 1),
(1, 1),
padding=(final_kernel - 1) // 2,
)
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
)
self.conv_post = norm_f(
nn.Conv2d(
channels,
1,
(post_conv_kernel, 1),
(1, 1),
padding=((post_conv_kernel - 1) // 2, 0),
)
)
self.register_buffer("window", getattr(torch, window)(win_length))
def forward(self, wav):
with torch.no_grad():
wav = torch.squeeze(wav, 1)
x_mag = stft(
wav, self.fft_size, self.shift_size, self.win_length, self.window
)
x = torch.transpose(x_mag, 2, 1).unsqueeze(-1)
fmap = []
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = x.squeeze(-1)
return x, fmap
class MultiSpecDiscriminator(torch.nn.Module):
def __init__(
self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
discriminator_params={
"channels": 15,
"init_kernel": 1,
"kernel_sizes": 11,
"stride": 2,
"use_spectral_norm": False,
"window": "hann_window",
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
},
):
super(MultiSpecDiscriminator, self).__init__()
self.discriminators = nn.ModuleList()
for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths):
params = copy.deepcopy(discriminator_params)
params["fft_size"] = fft_size
params["shift_size"] = hop_size
params["win_length"] = win_length
self.discriminators += [SpecDiscriminator(**params)]
def forward(self, y):
y_d = []
fmap = []
for i, d in enumerate(self.discriminators):
x, x_map = d(y)
y_d.append(x)
fmap.append(x_map)
return y_d, fmap
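# Illustrative usage sketch (not part of the original source). Assumes `stft`
# is the magnitude-STFT helper this file imports elsewhere (it is called in
# SpecDiscriminator.forward above but defined outside this excerpt).
mfd = MultiSpecDiscriminator()
wav = torch.randn(2, 1, 8192)
scores, feature_maps = mfd(wav)  # one (score, fmap) pair per STFT resolution
assert len(scores) == len(mfd.discriminators) == 3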
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, remove_weight_norm
from torch.distributions.uniform import Uniform
from torch.distributions.normal import Normal
from kantts.models.utils import init_weights
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
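# Quick sanity check (sketch): this is "same" padding for odd kernels, so a
# stride-1 convolution keeps the time axis unchanged.
assert get_padding(3) == 1
assert get_padding(3, dilation=2) == 2
assert get_padding(7, dilation=3) == 9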
class Conv1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
):
super(Conv1d, self).__init__()
self.conv1d = weight_norm(
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
)
)
self.conv1d.apply(init_weights)
def forward(self, x):
x = self.conv1d(x)
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv1d)
class CausalConv1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
):
super(CausalConv1d, self).__init__()
self.pad = (kernel_size - 1) * dilation
self.conv1d = weight_norm(
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
)
)
self.conv1d.apply(init_weights)
def forward(self, x): # bdt
x = F.pad(
x, (self.pad, 0, 0, 0, 0, 0), "constant"
)  # F.pad's padding tuple is specified starting from the last dimension and moving forward
# x = F.pad(x, (self.pad, self.pad, 0, 0, 0, 0), "constant")
x = self.conv1d(x)[:, :, : x.size(2)]
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv1d)
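# Illustrative causality check (sketch, not in the original source): output
# frame t must not change when only future inputs change.
causal = CausalConv1d(1, 1, kernel_size=3, dilation=2)
x = torch.randn(1, 1, 16)
x_future = x.clone()
x_future[..., -1] += 1.0  # perturb the last frame only
y, y_future = causal(x), causal(x_future)
assert torch.allclose(y[..., :-1], y_future[..., :-1])  # earlier outputs unchanged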
class ConvTranspose1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
):
super(ConvTranspose1d, self).__init__()
self.deconv = weight_norm(
nn.ConvTranspose1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding,
output_padding=0,
)
)
self.deconv.apply(init_weights)
def forward(self, x):
return self.deconv(x)
def remove_weight_norm(self):
remove_weight_norm(self.deconv)
# FIXME: HACK to get shape right
class CausalConvTranspose1d(torch.nn.Module):
"""CausalConvTranspose1d module with customized initialization."""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
):
"""Initialize CausalConvTranspose1d module."""
super(CausalConvTranspose1d, self).__init__()
self.deconv = weight_norm(
nn.ConvTranspose1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
)
)
self.stride = stride
self.deconv.apply(init_weights)
self.pad = kernel_size - stride
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_in).
Returns:
Tensor: Output tensor (B, out_channels, T_out).
"""
# x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
return self.deconv(x)[:, :, : -self.pad]
# return self.deconv(x)
def remove_weight_norm(self):
remove_weight_norm(self.deconv)
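# Note (reader assumption, not from the original source): with
# self.pad = kernel_size - stride, the slice in forward() is only valid when
# kernel_size > stride; if kernel_size == stride, `[:, :, : -0]` evaluates to
# an empty tensor, so such configurations would need a separate code path.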
class ResidualBlock(torch.nn.Module):
def __init__(
self,
channels,
kernel_size=3,
dilation=(1, 3, 5),
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
causal=False,
):
super(ResidualBlock, self).__init__()
assert kernel_size % 2 == 1, "Kernel size must be an odd number."
conv_cls = CausalConv1d if causal else Conv1d
self.convs1 = nn.ModuleList(
[
conv_cls(
channels,
channels,
kernel_size,
1,
dilation=dilation[i],
padding=get_padding(kernel_size, dilation[i]),
)
for i in range(len(dilation))
]
)
self.convs2 = nn.ModuleList(
[
conv_cls(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
for i in range(len(dilation))
]
)
self.activation = getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = self.activation(x)
xt = c1(xt)
xt = self.activation(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for layer in self.convs1:
layer.remove_weight_norm()
for layer in self.convs2:
layer.remove_weight_norm()
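# Illustrative usage sketch (not in the original source): the dilated residual
# stack preserves the (B, C, T) shape.
block = ResidualBlock(channels=64, kernel_size=3, dilation=(1, 3, 5))
x = torch.randn(2, 64, 120)
assert block(x).shape == x.shape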
class SourceModule(torch.nn.Module):
def __init__(
self, nb_harmonics, upsample_ratio, sampling_rate, alpha=0.1, sigma=0.003
):
super(SourceModule, self).__init__()
self.nb_harmonics = nb_harmonics
self.upsample_ratio = upsample_ratio
self.sampling_rate = sampling_rate
self.alpha = alpha
self.sigma = sigma
self.ffn = nn.Sequential(
weight_norm(nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
nn.Tanh(),
)
def forward(self, pitch, uv):
"""
:param pitch: [B, 1, frame_len], Hz
:param uv: [B, 1, frame_len] vuv flag
:return: [B, 1, sample_len]
"""
with torch.no_grad():
pitch_samples = F.interpolate(
pitch, scale_factor=(self.upsample_ratio), mode="nearest"
)
uv_samples = F.interpolate(
uv, scale_factor=(self.upsample_ratio), mode="nearest"
)
F_mat = torch.zeros(
(pitch_samples.size(0), self.nb_harmonics + 1, pitch_samples.size(-1))
).to(pitch_samples.device)
for i in range(self.nb_harmonics + 1):
F_mat[:, i : i + 1, :] = pitch_samples * (i + 1) / self.sampling_rate
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
u_dist = Uniform(low=-np.pi, high=np.pi)
phase_vec = u_dist.sample(
sample_shape=(pitch.size(0), self.nb_harmonics + 1, 1)
).to(F_mat.device)
phase_vec[:, 0, :] = 0
n_dist = Normal(loc=0.0, scale=self.sigma)
noise = n_dist.sample(
sample_shape=(
pitch_samples.size(0),
self.nb_harmonics + 1,
pitch_samples.size(-1),
)
).to(F_mat.device)
e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
e_unvoice = self.alpha / 3 / self.sigma * noise
e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)
return self.ffn(e)
def remove_weight_norm(self):
remove_weight_norm(self.ffn[0])
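# Illustrative usage sketch (not in the original source): frame-level F0 and
# voiced/unvoiced flags are upsampled to sample level and turned into a
# harmonic-plus-noise excitation signal.
source = SourceModule(nb_harmonics=7, upsample_ratio=240, sampling_rate=24000)
pitch = torch.full((2, 1, 50), 220.0)  # (B, 1, frames), Hz
uv = torch.ones(2, 1, 50)  # 1 = voiced, 0 = unvoiced
excitation = source(pitch, uv)  # upsampled to (B, 1, frames * upsample_ratio)
assert excitation.shape == (2, 1, 12000)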
# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Pseudo QMF modules."""
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import kaiser
def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
"""Design prototype filter for PQMF.
This method is based on `A Kaiser window approach for the design of prototype
filters of cosine modulated filterbanks`_.
Args:
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
Returns:
ndarray: Impulse response of the prototype filter (taps + 1,).
.. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
https://ieeexplore.ieee.org/abstract/document/681427
"""
# check the arguments are valid
assert taps % 2 == 0, "The number of taps must be an even number."
assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
# make initial filter
omega_c = np.pi * cutoff_ratio
with np.errstate(invalid="ignore"):
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
np.pi * (np.arange(taps + 1) - 0.5 * taps)
)
h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
# apply kaiser window
w = kaiser(taps + 1, beta)
h = h_i * w
return h
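# Quick check (sketch, not in the original source): the prototype filter has
# taps + 1 coefficients.
h_proto = design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0)
assert h_proto.shape == (63,)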
class PQMF(torch.nn.Module):
"""PQMF module.
This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
.. _`Near-perfect-reconstruction pseudo-QMF banks`:
https://ieeexplore.ieee.org/document/258122
"""
def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
"""Initilize PQMF module.
The cutoff_ratio and beta parameters are optimized for #subbands = 4.
See dicussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
Args:
subbands (int): The number of subbands.
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
"""
super(PQMF, self).__init__()
# build analysis & synthesis filter coefficients
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
h_analysis = np.zeros((subbands, len(h_proto)))
h_synthesis = np.zeros((subbands, len(h_proto)))
for k in range(subbands):
h_analysis[k] = (
2
* h_proto
* np.cos(
(2 * k + 1)
* (np.pi / (2 * subbands))
* (np.arange(taps + 1) - (taps / 2))
+ (-1) ** k * np.pi / 4
)
)
h_synthesis[k] = (
2
* h_proto
* np.cos(
(2 * k + 1)
* (np.pi / (2 * subbands))
* (np.arange(taps + 1) - (taps / 2))
- (-1) ** k * np.pi / 4
)
)
# convert to tensor
analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
# register coefficients as buffer
self.register_buffer("analysis_filter", analysis_filter)
self.register_buffer("synthesis_filter", synthesis_filter)
# filter for downsampling & upsampling
updown_filter = torch.zeros((subbands, subbands, subbands)).float()
for k in range(subbands):
updown_filter[k, k, 0] = 1.0
self.register_buffer("updown_filter", updown_filter)
self.subbands = subbands
# keep padding info
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
def analysis(self, x):
"""Analysis with PQMF.
Args:
x (Tensor): Input tensor (B, 1, T).
Returns:
Tensor: Output tensor (B, subbands, T // subbands).
"""
x = F.conv1d(self.pad_fn(x), self.analysis_filter)
return F.conv1d(x, self.updown_filter, stride=self.subbands)
def synthesis(self, x):
"""Synthesis with PQMF.
Args:
x (Tensor): Input tensor (B, subbands, T // subbands).
Returns:
Tensor: Output tensor (B, 1, T).
"""
# NOTE(kan-bayashi): Power will be decreased, so multiply by the number of subbands here.
# Not sure this is the correct way, it is better to check again.
# TODO(kan-bayashi): Understand the reconstruction procedure
x = F.conv_transpose1d(
x, self.updown_filter * self.subbands, stride=self.subbands
)
return F.conv1d(self.pad_fn(x), self.synthesis_filter)
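# Illustrative round-trip sketch (not in the original source): since this is a
# *near*-perfect-reconstruction bank, analysis followed by synthesis only
# approximately reproduces the input.
pqmf = PQMF(subbands=4)
x = torch.randn(1, 1, 8000)
subband = pqmf.analysis(x)  # (1, 4, 2000)
x_hat = pqmf.synthesis(subband)  # (1, 1, 8000)
print((x - x_hat).abs().mean())  # small but nonzero reconstruction error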
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class ScaledDotProductAttention(nn.Module):
""" Scaled Dot-Product Attention """
def __init__(self, temperature, dropatt=0.0):
super().__init__()
self.temperature = temperature
self.softmax = nn.Softmax(dim=2)
self.dropatt = nn.Dropout(dropatt)
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
attn = self.dropatt(attn)
output = torch.bmm(attn, v)
return output, attn
class Prenet(nn.Module):
def __init__(self, in_units, prenet_units, out_units=0):
super(Prenet, self).__init__()
self.fcs = nn.ModuleList()
for in_dim, out_dim in zip([in_units] + prenet_units[:-1], prenet_units):
self.fcs.append(nn.Linear(in_dim, out_dim))
self.fcs.append(nn.ReLU())
self.fcs.append(nn.Dropout(0.5))
if out_units:
self.fcs.append(nn.Linear(prenet_units[-1], out_units))
def forward(self, input):
output = input
for layer in self.fcs:
output = layer(output)
return output
class MultiHeadSelfAttention(nn.Module):
""" Multi-Head SelfAttention module """
def __init__(self, n_head, d_in, d_model, d_head, dropout, dropatt=0.0):
super().__init__()
self.n_head = n_head
self.d_head = d_head
self.d_in = d_in
self.d_model = d_model
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.w_qkv = nn.Linear(d_in, 3 * n_head * d_head)
self.attention = ScaledDotProductAttention(
temperature=np.power(d_head, 0.5), dropatt=dropatt
)
self.fc = nn.Linear(n_head * d_head, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, input, mask=None):
d_head, n_head = self.d_head, self.n_head
sz_b, len_in, _ = input.size()
residual = input
x = self.layer_norm(input)
qkv = self.w_qkv(x)
q, k, v = qkv.chunk(3, -1)
q = q.view(sz_b, len_in, n_head, d_head)
k = k.view(sz_b, len_in, n_head, d_head)
v = v.view(sz_b, len_in, n_head, d_head)
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_in, d_head) # (n*b) x l x d
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_in, d_head) # (n*b) x l x d
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_in, d_head) # (n*b) x l x d
if mask is not None:
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
output = output.view(n_head, sz_b, len_in, d_head)
output = (
output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_in, -1)
) # b x l x (n*d)
output = self.dropout(self.fc(output))
if output.size(-1) == residual.size(-1):
output = output + residual
return output, attn
class PositionwiseConvFeedForward(nn.Module):
""" A two-feed-forward-layer module """
def __init__(self, d_in, d_hid, kernel_size=(3, 1), dropout_inner=0.1, dropout=0.1):
super().__init__()
# Use Conv1D
# position-wise
self.w_1 = nn.Conv1d(
d_in,
d_hid,
kernel_size=kernel_size[0],
padding=(kernel_size[0] - 1) // 2,
)
# position-wise
self.w_2 = nn.Conv1d(
d_hid,
d_in,
kernel_size=kernel_size[1],
padding=(kernel_size[1] - 1) // 2,
)
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.dropout_inner = nn.Dropout(dropout_inner)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
residual = x
x = self.layer_norm(x)
output = x.transpose(1, 2)
output = F.relu(self.w_1(output))
if mask is not None:
output = output.masked_fill(mask.unsqueeze(1), 0)
output = self.dropout_inner(output)
output = self.w_2(output)
output = output.transpose(1, 2)
output = self.dropout(output)
output = output + residual
return output
class FFTBlock(nn.Module):
"""FFT Block"""
def __init__(
self,
d_in,
d_model,
n_head,
d_head,
d_inner,
kernel_size,
dropout,
dropout_attn=0.0,
dropout_relu=0.0,
):
super(FFTBlock, self).__init__()
self.slf_attn = MultiHeadSelfAttention(
n_head, d_in, d_model, d_head, dropout=dropout, dropatt=dropout_attn
)
self.pos_ffn = PositionwiseConvFeedForward(
d_model, d_inner, kernel_size, dropout_inner=dropout_relu, dropout=dropout
)
def forward(self, input, mask=None, slf_attn_mask=None):
output, slf_attn = self.slf_attn(input, mask=slf_attn_mask)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
output = self.pos_ffn(output, mask=mask)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
return output, slf_attn
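# Illustrative usage sketch (not in the original source): one encoder FFT
# block; masks are boolean with True marking padded positions.
blk = FFTBlock(d_in=512, d_model=512, n_head=8, d_head=64, d_inner=1024, kernel_size=(3, 1), dropout=0.1)
x = torch.randn(2, 100, 512)
lengths = torch.tensor([100, 60])
mask = torch.arange(100).unsqueeze(0) >= lengths.unsqueeze(1)  # (B, T), True = padding
slf_attn_mask = mask.unsqueeze(1).expand(-1, 100, -1)  # (B, T, T) key-side attention mask
out, attn = blk(x, mask=mask, slf_attn_mask=slf_attn_mask)
assert out.shape == x.shape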
class MultiHeadPNCAAttention(nn.Module):
""" Multi-Head Attention PNCA module """
def __init__(self, n_head, d_model, d_mem, d_head, dropout, dropatt=0.0):
super().__init__()
self.n_head = n_head
self.d_head = d_head
self.d_model = d_model
self.d_mem = d_mem
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.w_x_qkv = nn.Linear(d_model, 3 * n_head * d_head)
self.fc_x = nn.Linear(n_head * d_head, d_model)
self.w_h_kv = nn.Linear(d_mem, 2 * n_head * d_head)
self.fc_h = nn.Linear(n_head * d_head, d_model)
self.attention = ScaledDotProductAttention(
temperature=np.power(d_head, 0.5), dropatt=dropatt
)
self.dropout = nn.Dropout(dropout)
def update_x_state(self, x):
d_head, n_head = self.d_head, self.n_head
sz_b, len_x, _ = x.size()
x_qkv = self.w_x_qkv(x)
x_q, x_k, x_v = x_qkv.chunk(3, -1)
x_q = x_q.view(sz_b, len_x, n_head, d_head)
x_k = x_k.view(sz_b, len_x, n_head, d_head)
x_v = x_v.view(sz_b, len_x, n_head, d_head)
x_q = x_q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
if self.x_state_size:
self.x_k = torch.cat([self.x_k, x_k], dim=1)
self.x_v = torch.cat([self.x_v, x_v], dim=1)
else:
self.x_k = x_k
self.x_v = x_v
self.x_state_size += len_x
return x_q, x_k, x_v
def update_h_state(self, h):
if self.h_state_size == h.size(1):
return None, None
d_head, n_head = self.d_head, self.n_head
# H
sz_b, len_h, _ = h.size()
h_kv = self.w_h_kv(h)
h_k, h_v = h_kv.chunk(2, -1)
h_k = h_k.view(sz_b, len_h, n_head, d_head)
h_v = h_v.view(sz_b, len_h, n_head, d_head)
self.h_k = h_k.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
self.h_v = h_v.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
self.h_state_size += len_h
return h_k, h_v
def reset_state(self):
self.h_k = None
self.h_v = None
self.h_state_size = 0
self.x_k = None
self.x_v = None
self.x_state_size = 0
def forward(self, x, h, mask_x=None, mask_h=None):
residual = x
self.update_h_state(h)
x_q, x_k, x_v = self.update_x_state(self.layer_norm(x))
d_head, n_head = self.d_head, self.n_head
sz_b, len_in, _ = x.size()
# X
if mask_x is not None:
mask_x = mask_x.repeat(n_head, 1, 1) # (n*b) x .. x ..
output_x, attn_x = self.attention(x_q, self.x_k, self.x_v, mask=mask_x)
output_x = output_x.view(n_head, sz_b, len_in, d_head)
output_x = (
output_x.permute(1, 2, 0, 3).contiguous().view(sz_b, len_in, -1)
) # b x l x (n*d)
output_x = self.fc_x(output_x)
# H
if mask_h is not None:
mask_h = mask_h.repeat(n_head, 1, 1)
output_h, attn_h = self.attention(x_q, self.h_k, self.h_v, mask=mask_h)
output_h = output_h.view(n_head, sz_b, len_in, d_head)
output_h = (
output_h.permute(1, 2, 0, 3).contiguous().view(sz_b, len_in, -1)
) # b x l x (n*d)
output_h = self.fc_h(output_h)
output = output_x + output_h
output = self.dropout(output)
output = output + residual
return output, attn_x, attn_h
class PNCABlock(nn.Module):
"""PNCA Block"""
def __init__(
self,
d_model,
d_mem,
n_head,
d_head,
d_inner,
kernel_size,
dropout,
dropout_attn=0.0,
dropout_relu=0.0,
):
super(PNCABlock, self).__init__()
self.pnca_attn = MultiHeadPNCAAttention(
n_head, d_model, d_mem, d_head, dropout=dropout, dropatt=dropout_attn
)
self.pos_ffn = PositionwiseConvFeedForward(
d_model, d_inner, kernel_size, dropout_inner=dropout_relu, dropout=dropout
)
def forward(
self, input, memory, mask=None, pnca_x_attn_mask=None, pnca_h_attn_mask=None
):
output, pnca_attn_x, pnca_attn_h = self.pnca_attn(
input, memory, pnca_x_attn_mask, pnca_h_attn_mask
)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
output = self.pos_ffn(output, mask=mask)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
return output, pnca_attn_x, pnca_attn_h
def reset_state(self):
self.pnca_attn.reset_state()
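# Illustrative sketch of incremental (stepwise) use (not in the original
# source): PNCA attention caches its own keys/values across calls, so
# reset_state() must be called before each new utterance.
blk = PNCABlock(d_model=256, d_mem=512, n_head=8, d_head=32, d_inner=1024, kernel_size=(3, 1), dropout=0.1)
blk.reset_state()
memory = torch.randn(2, 80, 512)  # encoder memory (B, T_enc, d_mem)
step = torch.randn(2, 1, 256)  # one decoder frame (B, 1, d_model)
out, attn_x, attn_h = blk(step, memory)  # caches this step's keys/values
step2 = torch.randn(2, 1, 256)
out2, _, _ = blk(step2, memory)  # x-attention now spans both cached steps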