import os
import torch
import glob
import logging
from multiprocessing import Manager
import librosa
import numpy as np
import random
import functools
from tqdm import tqdm
import math
from kantts.utils.ling_unit.ling_unit import KanTtsLinguisticUnit, emotion_types
from scipy.stats import betabinom

DATASET_RANDOM_SEED = 1234

torch.multiprocessing.set_sharing_strategy("file_system")


def _feature_files_exist(index, feature_dirs):
    """Return True if ``<dir>/<index>.npy`` exists in every directory of *feature_dirs*."""
    return all(os.path.exists(os.path.join(d, index + ".npy")) for d in feature_dirs)


def _ensure_newline(line):
    """Return *line* terminated by a single trailing newline.

    Metafile lines are shuffled before being written back out; a last input
    line without its newline would otherwise be glued onto the next entry.
    """
    return line.rstrip("\n") + "\n"


@functools.lru_cache(maxsize=256)
def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0):
    """Build a beta-binomial alignment prior of shape (mel_count, phoneme_count).

    Row ``i - 1`` (for mel frame ``i`` in ``1 .. mel_count``) is the pmf, evaluated
    at phoneme positions ``0 .. phoneme_count - 1``, of
    ``BetaBinom(phoneme_count, scaling * i, scaling * (mel_count + 1 - i))``.

    Results are memoized, so repeated calls return the *same* tensor object;
    callers must not modify it in place.
    """
    P = phoneme_count
    M = mel_count
    x = np.arange(0, P)
    mel_text_probs = []
    for i in range(1, M + 1):
        a, b = scaling * i, scaling * (M + 1 - i)
        rv = betabinom(P, a, b)
        mel_text_probs.append(rv.pmf(x))
    return torch.tensor(np.array(mel_text_probs))


class Padder(object):
    """Helpers that right-pad per-utterance numpy arrays and stack them into tensors."""

    def __init__(self):
        super().__init__()

    def _pad1D(self, x, length, pad):
        """Right-pad 1-D array *x* with the constant *pad* up to *length* entries."""
        return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=pad)

    def _pad2D(self, x, length, pad):
        """Right-pad the first axis of 2-D array *x* with *pad* up to *length* rows."""
        return np.pad(
            x, [(0, length - x.shape[0]), (0, 0)], mode="constant", constant_values=pad
        )

    def _pad_durations(self, duration, max_in_len, max_out_len):
        """Pad a per-symbol duration vector to *max_in_len* entries.

        If the durations sum to fewer than *max_out_len* frames, the missing
        frames are assigned to one extra entry appended right after the real
        symbols, so the padded vector accounts for all *max_out_len* frames.
        Any remaining entries are zero.
        """
        framenum = np.sum(duration)
        symbolnum = duration.shape[0]
        if framenum < max_out_len:
            padframenum = max_out_len - framenum
            duration = np.insert(duration, symbolnum, values=padframenum, axis=0)
            duration = np.insert(
                duration,
                symbolnum + 1,
                values=[0] * (max_in_len - symbolnum - 1),
                axis=0,
            )
        elif symbolnum < max_in_len:
            duration = np.insert(
                duration, symbolnum, values=[0] * (max_in_len - symbolnum), axis=0
            )
        return duration

    def _round_up(self, x, multiple):
        """Round *x* up to the nearest multiple of *multiple*."""
        remainder = x % multiple
        return x if remainder == 0 else x + multiple - remainder

    def _prepare_scalar_inputs(self, inputs, max_len, pad):
        """Pad each 1-D array in *inputs* to *max_len* and stack into a (B, max_len) tensor."""
        return torch.from_numpy(
            np.stack([self._pad1D(x, max_len, pad) for x in inputs])
        )

    def _prepare_targets(self, targets, max_len, pad):
        """Pad each 2-D array along axis 0 to *max_len* and stack into a float tensor."""
        return torch.from_numpy(
            np.stack([self._pad2D(t, max_len, pad) for t in targets])
        ).float()

    def _prepare_durations(self, durations, max_in_len, max_out_len):
        """Pad every duration vector (see ``_pad_durations``) and stack into a long tensor."""
        return torch.from_numpy(
            np.stack(
                [self._pad_durations(t, max_in_len, max_out_len) for t in durations]
            )
        ).long()


class Voc_Dataset(torch.utils.data.Dataset):
    """Provide (wav, mel) data pairs for vocoder training.

    For an utterance name ``n`` listed in a metafile, with data dir ``d``, this reads:
    ``d/wav/n.wav`` and ``d/mel/n.npy``. When NSF is enabled it also reads
    ``d/frame_f0/n.npy``, ``d/frame_uv/n.npy`` and ``d/f0/f0_{mean,std}.txt``.
    """

    def __init__(
        self,
        metafile,
        root_dir,
        config,
    ):
        self.meta = []
        self.config = config
        self.sampling_rate = config["audio_config"]["sampling_rate"]
        self.n_fft = config["audio_config"]["n_fft"]
        self.hop_length = config["audio_config"]["hop_length"]
        self.batch_max_steps = config["batch_max_steps"]
        self.batch_max_frames = self.batch_max_steps // self.hop_length
        self.aux_context_window = 0  # TODO: make it configurable
        self.start_offset = self.aux_context_window
        self.end_offset = -(self.batch_max_frames + self.aux_context_window)

        nsf_params = config["Model"]["Generator"]["params"].get("nsf_params", None)
        self.nsf_enable = nsf_params is not None
        if self.nsf_enable:
            # The default used to be the malformed string '"mean_std'.
            self.nsf_norm_type = nsf_params.get("nsf_norm_type", "mean_std")
            if self.nsf_norm_type == "global":
                self.nsf_f0_global_minimum = nsf_params.get(
                    "nsf_f0_global_minimum", 30.0
                )
                self.nsf_f0_global_maximum = nsf_params.get(
                    "nsf_f0_global_maximum", 730.0
                )

        if not isinstance(metafile, list):
            metafile = [metafile]
        if not isinstance(root_dir, list):
            root_dir = [root_dir]

        for meta_file, data_dir in zip(metafile, root_dir):
            if not os.path.exists(meta_file):
                logging.error("meta file not found: {}".format(meta_file))
                raise ValueError(
                    "[Voc_Dataset] meta file: {} not found".format(meta_file)
                )
            if not os.path.exists(data_dir):
                logging.error("data directory not found: {}".format(data_dir))
                raise ValueError(
                    "[Voc_Dataset] data dir: {} not found".format(data_dir)
                )
            self.meta.extend(self.load_meta(meta_file, data_dir))

        # Fall back to scanning <root>/wav and <root>/mel when the metafiles
        # produced no items. root_dir is always a list at this point.
        if len(self.meta) == 0:
            for d in root_dir:
                wav_dir = os.path.join(d, "wav")
                mel_dir = os.path.join(d, "mel")
                if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
                    raise ValueError("wav or mel directory not found")
                self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))

        self.allow_cache = config["allow_cache"]
        if self.allow_cache:
            # Manager-backed list so cached items are visible across worker processes.
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(self.meta))]

    @staticmethod
    def gen_metafile(wav_dir, out_dir, split_ratio=0.98):
        """Write ``train.lst`` / ``valid.lst`` into *out_dir* from the wavs in *wav_dir*.

        Utterances are shuffled with a fixed seed and then split by *split_ratio*.
        Only names whose frame_f0, frame_uv and mel ``.npy`` files exist under
        *out_dir* are written.
        """
        # Sort first: glob order depends on the filesystem, which would make the
        # seeded shuffle (and therefore the split) machine-dependent.
        wav_files = sorted(glob.glob(os.path.join(wav_dir, "*.wav")))
        feature_dirs = [
            os.path.join(out_dir, "frame_f0"),
            os.path.join(out_dir, "frame_uv"),
            os.path.join(out_dir, "mel"),
        ]

        # Local RNG: same permutation as seeding the global RNG, without side effects.
        random.Random(DATASET_RANDOM_SEED).shuffle(wav_files)
        num_train = int(len(wav_files) * split_ratio) - 1

        splits = (
            ("train.lst", wav_files[:num_train]),
            ("valid.lst", wav_files[num_train:]),
        )
        for list_name, split_files in splits:
            with open(os.path.join(out_dir, list_name), "w", encoding="utf-8") as f:
                for wav_file in split_files:
                    index = os.path.splitext(os.path.basename(wav_file))[0]
                    if _feature_files_exist(index, feature_dirs):
                        f.write("{}\n".format(index))

    def load_meta(self, metafile, data_dir):
        """Return (wav, mel, frame_f0, frame_uv) path tuples for every name in *metafile*.

        Raises:
            ValueError: if ``data_dir/wav`` or ``data_dir/mel`` does not exist.
        """
        with open(metafile, "r", encoding="utf-8") as f:
            lines = f.readlines()

        wav_dir = os.path.join(data_dir, "wav")
        mel_dir = os.path.join(data_dir, "mel")
        frame_f0_dir = os.path.join(data_dir, "frame_f0")
        frame_uv_dir = os.path.join(data_dir, "frame_uv")
        if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
            raise ValueError("wav or mel directory not found")

        items = []
        logging.info("Loading metafile...")
        for name in tqdm(lines):
            name = name.strip()
            if not name:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            mel_file = os.path.join(mel_dir, name + ".npy")
            wav_file = os.path.join(wav_dir, name + ".wav")
            frame_f0_file = os.path.join(frame_f0_dir, name + ".npy")
            frame_uv_file = os.path.join(frame_uv_dir, name + ".npy")
            items.append((wav_file, mel_file, frame_f0_file, frame_uv_file))
        return items

    def load_meta_from_dir(self, wav_dir, mel_dir):
        """Build meta tuples by pairing ``wav_dir/<n>.wav`` with ``mel_dir/<n>.npy``.

        Returns the same 4-tuple layout as ``load_meta``. The frame_f0/frame_uv
        paths are taken from the parent directory of *wav_dir*.
        """
        data_dir = os.path.dirname(os.path.normpath(wav_dir))
        frame_f0_dir = os.path.join(data_dir, "frame_f0")
        frame_uv_dir = os.path.join(data_dir, "frame_uv")
        items = []
        for wav_file in sorted(glob.glob(os.path.join(wav_dir, "*.wav"))):
            name = os.path.splitext(os.path.basename(wav_file))[0]
            # Mel features are stored as .npy, not under the wav's own basename.
            mel_file = os.path.join(mel_dir, name + ".npy")
            if os.path.exists(mel_file):
                items.append(
                    (
                        wav_file,
                        mel_file,
                        os.path.join(frame_f0_dir, name + ".npy"),
                        os.path.join(frame_uv_dir, name + ".npy"),
                    )
                )
        return items

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        """Load one utterance and return ``(wav_data, mel_data)``.

        mel_data has frames on axis 0. When NSF is enabled, the de-normalised
        frame f0 and the uv flag are appended as two extra columns. The result
        always has at least ``batch_max_frames + 1`` frames, and
        ``len(wav_data) == len(mel_data) * hop_length``.
        """
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]

        wav_file, mel_file, frame_f0_file, frame_uv_file = self.meta[idx]

        wav_data = librosa.core.load(wav_file, sr=self.sampling_rate)[0]
        mel_data = np.load(mel_file)

        if self.nsf_enable:
            # frame_f0 is stored mean/std normalised; undo that with the
            # per-dataset statistics stored in <data_dir>/f0/.
            f0_dir = os.path.join(
                os.path.dirname(os.path.dirname(frame_f0_file)), "f0"
            )
            frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
            f0_mean = np.loadtxt(os.path.join(f0_dir, "f0_mean.txt"))
            f0_std = np.loadtxt(os.path.join(f0_dir, "f0_std.txt"))
            frame_f0_data = frame_f0_data * f0_std + f0_mean
            frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
            mel_data = np.concatenate((mel_data, frame_f0_data, frame_uv_data), axis=1)

        if mel_data.shape[0] <= self.batch_max_frames:
            # Make mel_data longer than batch_max_frames by at least one frame,
            # so collate_fn always has a valid random start frame.
            pad_frames = self.batch_max_frames - mel_data.shape[0] + 1
            mel_data = np.concatenate(
                (mel_data, np.zeros((pad_frames, mel_data.shape[1]))), axis=0
            )
            wav_cache = np.zeros(mel_data.shape[0] * self.hop_length, dtype=np.float32)
            # The wav can be longer than the padded mel covers; copy only what fits.
            n_samples = min(len(wav_data), len(wav_cache))
            wav_cache[:n_samples] = wav_data[:n_samples]
            wav_data = wav_cache
        else:
            # Make sure the audio length and feature length are matched.
            wav_data = np.pad(wav_data, (0, self.n_fft), mode="reflect")
            wav_data = wav_data[: len(mel_data) * self.hop_length]

        assert len(mel_data) * self.hop_length == len(wav_data)

        if self.allow_cache:
            self.caches[idx] = (wav_data, mel_data)
        return (wav_data, mel_data)

    def collate_fn(self, batch):
        """Cut one random, aligned window from every (wav, mel) pair and batch them.

        Returns:
            wav_batch: float32 tensor of shape (B, 1, batch_max_steps).
            mel_batch: float32 tensor of shape
                (B, C, batch_max_frames + 2 * aux_context_window).
        """
        wav_data = [item[0] for item in batch]
        mel_data = [item[1] for item in batch]
        mel_lengths = [len(mel) for mel in mel_data]

        start_frames = np.array(
            [
                np.random.randint(self.start_offset, length + self.end_offset)
                for length in mel_lengths
            ]
        )

        wav_start = start_frames * self.hop_length
        wav_end = wav_start + self.batch_max_steps

        # The aux window is extra mel context on both sides of the audio window.
        mel_start = start_frames - self.aux_context_window
        mel_end = start_frames + self.batch_max_frames + self.aux_context_window

        wav_batch = [x[s:e] for x, s, e in zip(wav_data, wav_start, wav_end)]
        mel_batch = [c[s:e] for c, s, e in zip(mel_data, mel_start, mel_end)]

        # (B, 1, T)
        wav_batch = torch.tensor(np.asarray(wav_batch), dtype=torch.float32).unsqueeze(
            1
        )
        # (B, C, T)
        mel_batch = torch.tensor(np.asarray(mel_batch), dtype=torch.float32).transpose(
            2, 1
        )
        return wav_batch, mel_batch


def get_voc_datasets(
    config,
    root_dir,
    split_ratio=0.98,
):
    """Build (train, valid) ``Voc_Dataset`` objects for one or more data dirs.

    Missing ``train.lst``/``valid.lst`` files are generated first.
    """
    if isinstance(root_dir, str):
        root_dir = [root_dir]
    train_meta_lst = []
    valid_meta_lst = []
    for data_dir in root_dir:
        train_meta = os.path.join(data_dir, "train.lst")
        valid_meta = os.path.join(data_dir, "valid.lst")
        if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
            Voc_Dataset.gen_metafile(
                os.path.join(data_dir, "wav"), data_dir, split_ratio
            )
        train_meta_lst.append(train_meta)
        valid_meta_lst.append(valid_meta)
    train_dataset = Voc_Dataset(
        train_meta_lst,
        root_dir,
        config,
    )
    # NOTE(review): this slices the list of metafiles (one per data dir), not
    # the number of validation samples. Confirm that this is the intended cap.
    valid_dataset = Voc_Dataset(
        valid_meta_lst[:50],
        root_dir,
        config,
    )
    return train_dataset, valid_dataset


# TODO(Yuxuan): refine the logic, you'd better not use emotion tag, it's ambiguous.
def get_fp_label(aug_ling_txt):
    """Derive per-token fp labels from an augmented linguistic string.

    *aug_ling_txt* is a space-separated list of ``{f0$f1$...}`` tokens. Field 0
    is read as the syllable and field 4 as the emotion tag. An EOS token is
    appended before labelling.

    Label mapping:
    - emotion_types[0] -> 0
    - emotion_types[1] -> 1
    - emotion_types[2] -> 2
    - emotion_types[4] -> 3
    - tokens tagged emotion_types[3], and any unknown tag, are dropped from
      the output.

    NOTE(review): emotion_types[3] appears to mark inserted filled-pause
    tokens (see the "ga"/"ge en_c" syllable checks). Confirm this against
    ling_unit.
    """
    # Strip so a trailing newline cannot end up inside the last token's fields.
    token_lst = aug_ling_txt.strip().split(" ")
    fields = [token.strip("{}").split("$") for token in token_lst]
    emo_lst = [f[4] for f in fields]
    syllable_lst = [f[0] for f in fields]
    # EOS token append
    emo_lst.append(emotion_types[0])
    syllable_lst.append("EOS")

    # According to the original emotion tag, set each token's fp label.
    if emo_lst[0] != emotion_types[3]:
        emo_lst[0] = emotion_types[0]
        emo_lst[1] = emotion_types[0]
    # Walk backwards so emo_lst[i - 1] is still the original tag when read.
    for i in range(len(emo_lst) - 2, 1, -1):
        cur_is_marker = emo_lst[i] == emotion_types[3]
        prev_is_marker = emo_lst[i - 1] == emotion_types[3]
        if not cur_is_marker and not prev_is_marker:
            emo_lst[i] = emotion_types[0]
        elif not cur_is_marker and prev_is_marker:
            emo_lst[i] = emotion_types[3]
            if syllable_lst[i - 2] == "ga":
                emo_lst[i + 1] = emotion_types[1]
            elif syllable_lst[i - 2] == "ge" and syllable_lst[i - 1] == "en_c":
                emo_lst[i + 1] = emotion_types[2]
            else:
                emo_lst[i + 1] = emotion_types[4]

    label_of = {
        emotion_types[0]: 0,
        emotion_types[1]: 1,
        emotion_types[2]: 2,
        emotion_types[4]: 3,
    }
    return np.array([label_of[emo] for emo in emo_lst if emo in label_of])


class AM_Dataset(torch.utils.data.Dataset):
    """Provide (ling, emo, speaker, mel, ...) samples for acoustic-model training.

    ``__getitem__`` returns the 8-tuple
    ``(ling_data, mel_data, dur_data, f0_data, energy_data, attn_prior,
    fp_label, se_data)``, and ``collate_fn`` batches a list of those into a dict.
    """

    def __init__(
        self,
        config,
        metafile,
        root_dir,
        allow_cache=False,
    ):
        self.meta = []
        self.config = config
        self.with_duration = True
        params = self.config["Model"]["KanTtsSAMBERT"]["params"]
        self.nsf_enable = params.get("NSF", False)
        if self.nsf_enable:
            self.nsf_norm_type = params.get("nsf_norm_type", "mean_std")
            if self.nsf_norm_type == "global":
                self.nsf_f0_global_minimum = params.get("nsf_f0_global_minimum", 30.0)
                self.nsf_f0_global_maximum = params.get("nsf_f0_global_maximum", 730.0)

        self.se_enable = params.get("SE", False)
        self.fp_enable = params.get("FP", False)
        self.mas_enable = params.get("MAS", False)

        if not isinstance(metafile, list):
            metafile = [metafile]
        if not isinstance(root_dir, list):
            root_dir = [root_dir]

        for meta_file, data_dir in zip(metafile, root_dir):
            if not os.path.exists(meta_file):
                logging.error("meta file not found: {}".format(meta_file))
                raise ValueError(
                    "[AM_Dataset] meta file: {} not found".format(meta_file)
                )
            if not os.path.exists(data_dir):
                logging.error("data dir not found: {}".format(data_dir))
                raise ValueError("[AM_Dataset] data dir: {} not found".format(data_dir))
            self.meta.extend(self.load_meta(meta_file, data_dir))

        self.allow_cache = allow_cache

        self.ling_unit = KanTtsLinguisticUnit(config)
        self.padder = Padder()

        self.r = params["outputs_per_step"]
        # TODO: feat window

        if allow_cache:
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(self.meta))]

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        """Load and return the 8-tuple sample for item *idx* (see class docstring)."""
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]

        (
            ling_txt,
            mel_file,
            dur_file,
            f0_file,
            energy_file,
            frame_f0_file,
            frame_uv_file,
            aug_ling_txt,
            se_path,
        ) = self.meta[idx]

        ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
        mel_data = np.load(mel_file)
        dur_data = np.load(dur_file) if dur_file is not None else None
        f0_data = np.load(f0_file)
        energy_data = np.load(energy_file)
        se_data = np.load(se_path) if self.se_enable else None

        # Generate fp position labels from the fpadd meta.
        if self.fp_enable and aug_ling_txt is not None:
            fp_label = get_fp_label(aug_ling_txt)
        else:
            fp_label = None

        # Without ground-truth durations, an alignment prior is supplied instead.
        if self.with_duration:
            attn_prior = None
        else:
            attn_prior = beta_binomial_prior_distribution(
                len(ling_data[0]), mel_data.shape[0]
            )

        # Concat frame-level f0 and uv to mel_data.
        if self.nsf_enable:
            # The stored frame f0 is mean/std normalised.
            frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
            if self.nsf_norm_type == "global":
                # Undo the mean/std norm, then rescale to the global min/max range.
                f0_dir = os.path.join(
                    os.path.dirname(os.path.dirname(frame_f0_file)), "f0"
                )
                f0_mean = np.loadtxt(os.path.join(f0_dir, "f0_mean.txt"))
                f0_std = np.loadtxt(os.path.join(f0_dir, "f0_std.txt"))
                f0_origin = frame_f0_data * f0_std + f0_mean
                frame_f0_data = (f0_origin - self.nsf_f0_global_minimum) / (
                    self.nsf_f0_global_maximum - self.nsf_f0_global_minimum
                )
            frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
            mel_data = np.concatenate([mel_data, frame_f0_data, frame_uv_data], axis=1)

        sample = (
            ling_data,
            mel_data,
            dur_data,
            f0_data,
            energy_data,
            attn_prior,
            fp_label,
            se_data,
        )
        if self.allow_cache:
            self.caches[idx] = sample
        return sample

    def load_meta(self, metafile, data_dir):
        """Parse a ``<index>\\t<ling_txt>`` metafile into per-item path tuples.

        When FP is enabled, the companion "fpadd" metafile is read. Its name is
        the basename of *metafile* with "fprm" replaced by "fpadd". Items without
        an fpadd entry are skipped.

        Side effect: sets ``self.with_duration``. It is False when MAS is
        enabled, otherwise it reflects whether ``data_dir/duration`` exists.
        """
        with open(metafile, "r", encoding="utf-8") as f:
            lines = f.readlines()

        aug_ling_dict = {}
        if self.fp_enable:
            # Replace only in the file name, never in directory components.
            add_fp_metafile = os.path.join(
                os.path.dirname(metafile),
                os.path.basename(metafile).replace("fprm", "fpadd"),
            )
            with open(add_fp_metafile, "r", encoding="utf-8") as f:
                for line in f:
                    # Strip the trailing newline so it cannot leak into the last token.
                    line = line.strip()
                    if not line:
                        continue
                    index, aug_ling_txt = line.split("\t")
                    aug_ling_dict[index] = aug_ling_txt

        mel_dir = os.path.join(data_dir, "mel")
        dur_dir = os.path.join(data_dir, "duration")
        f0_dir = os.path.join(data_dir, "f0")
        energy_dir = os.path.join(data_dir, "energy")
        frame_f0_dir = os.path.join(data_dir, "frame_f0")
        frame_uv_dir = os.path.join(data_dir, "frame_uv")
        se_path = os.path.join(data_dir, "se", "se.npy")

        if self.mas_enable:
            self.with_duration = False
        else:
            self.with_duration = os.path.exists(dur_dir)

        items = []
        logging.info("Loading metafile...")
        for line in tqdm(lines):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            index, ling_txt = line.split("\t")
            mel_file = os.path.join(mel_dir, index + ".npy")
            if self.with_duration:
                dur_file = os.path.join(dur_dir, index + ".npy")
            else:
                dur_file = None
            f0_file = os.path.join(f0_dir, index + ".npy")
            energy_file = os.path.join(energy_dir, index + ".npy")
            frame_f0_file = os.path.join(frame_f0_dir, index + ".npy")
            frame_uv_file = os.path.join(frame_uv_dir, index + ".npy")
            aug_ling_txt = aug_ling_dict.get(index, None)
            if self.fp_enable and aug_ling_txt is None:
                logging.warning(f"Missing fpadd meta for {index}")
                continue

            if self.se_enable and not os.path.exists(se_path):
                logging.warning("Missing se meta")
                continue

            items.append(
                (
                    ling_txt,
                    mel_file,
                    dur_file,
                    f0_file,
                    energy_file,
                    frame_f0_file,
                    frame_uv_file,
                    aug_ling_txt,
                    se_path,
                )
            )

        return items

    def load_fpadd_meta(self, metafile):
        """Return a list of 1-tuples ``(ling_txt,)`` from a ``<index>\\t<ling_txt>`` file."""
        with open(metafile, "r", encoding="utf-8") as f:
            lines = f.readlines()

        items = []
        logging.info("Loading fpadd metafile...")
        for line in tqdm(lines):
            line = line.strip()
            if not line:
                continue
            _index, ling_txt = line.split("\t")
            items.append((ling_txt,))

        return items

    @staticmethod
    def gen_metafile(
        raw_meta_file,
        out_dir,
        train_meta_file,
        valid_meta_file,
        badlist=None,
        split_ratio=0.98,
        se_enable=False,
    ):
        """Split *raw_meta_file* into train/valid metafiles after a seeded shuffle.

        A line is kept only if all of the following hold:
        - its index is not in *badlist*;
        - the index has frame_f0, frame_uv and mel ``.npy`` files under *out_dir*;
        - if ``out_dir/duration`` exists, the index has a duration file there;
        - if *se_enable* is set and ``out_dir/se`` exists, ``se/se.npy`` exists.
        """
        with open(raw_meta_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        se_dir = os.path.join(out_dir, "se")
        feature_dirs = [
            os.path.join(out_dir, "frame_f0"),
            os.path.join(out_dir, "frame_uv"),
            os.path.join(out_dir, "mel"),
        ]
        duration_dir = os.path.join(out_dir, "duration")
        has_duration_dir = os.path.exists(duration_dir)
        se_missing = (
            se_enable
            and os.path.exists(se_dir)
            and not os.path.exists(os.path.join(se_dir, "se.npy"))
        )

        def keep(line):
            index = line.split("\t")[0]
            if badlist is not None and index in badlist:
                return False
            if not _feature_files_exist(index, feature_dirs):
                return False
            if has_duration_dir and not os.path.exists(
                os.path.join(duration_dir, index + ".npy")
            ):
                return False
            return not se_missing

        # Local RNG: same permutation as seeding the global RNG, without side effects.
        random.Random(DATASET_RANDOM_SEED).shuffle(lines)
        num_train = int(len(lines) * split_ratio) - 1

        for out_file, split_lines in (
            (train_meta_file, lines[:num_train]),
            (valid_meta_file, lines[num_train:]),
        ):
            with open(out_file, "w", encoding="utf-8") as f:
                for line in split_lines:
                    if keep(line):
                        f.write(_ensure_newline(line))

    # TODO: implement collate_fn
    def collate_fn(self, batch):
        """Pad and stack a list of ``__getitem__`` samples into a dict of tensors."""
        data_dict = {}

        max_input_length = max(len(x[0][0]) for x in batch)
        if self.with_duration:
            # +1 leaves room for the padding entry added by Padder._pad_durations.
            max_dur_length = max(x[2].shape[0] for x in batch) + 1

        def pad_lfeat(lfeat_index):
            """Pad linguistic feature stream *lfeat_index* of every sample into a long tensor."""
            lfeat_type = self.ling_unit._lfeat_type_list[lfeat_index]
            return self.padder._prepare_scalar_inputs(
                [x[0][lfeat_index] for x in batch],
                max_input_length,
                self.ling_unit._sub_unit_pad[lfeat_type],
            ).long()

        if self.ling_unit.using_byte():
            # For byte-based models only.
            lfeat_type_index = 0
            data_dict["input_lings"] = torch.stack([pad_lfeat(0)], dim=2)
        else:
            # Pure linguistic info: sy | tone | syllable_flag | word_segment.
            lfeat_type_index = 3
            data_dict["input_lings"] = torch.stack(
                [pad_lfeat(i) for i in range(4)], dim=2
            )

        # emotion category
        lfeat_type_index += 1
        data_dict["input_emotions"] = pad_lfeat(lfeat_type_index)

        # speaker category
        lfeat_type_index += 1
        if self.se_enable:
            # Assumes se_data is a (1, D) embedding repeated per input symbol;
            # TODO confirm the stored shape.
            data_dict["input_speakers"] = self.padder._prepare_targets(
                [x[7].repeat(len(x[0][0]), axis=0) for x in batch],
                max_input_length,
                0.0,
            )
        else:
            data_dict["input_speakers"] = pad_lfeat(lfeat_type_index)

        # fp label category
        if self.fp_enable:
            data_dict["fp_label"] = self.padder._prepare_scalar_inputs(
                [x[6] for x in batch],
                max_input_length,
                0,
            ).long()

        # A "~" is appended to the input symbol sequence, which would skew the
        # duration computation, so the valid length is reduced by one.
        data_dict["valid_input_lengths"] = torch.as_tensor(
            [len(x[0][0]) - 1 for x in batch], dtype=torch.long
        )
        data_dict["valid_output_lengths"] = torch.as_tensor(
            [len(x[1]) for x in batch], dtype=torch.long
        )

        max_output_length = torch.max(data_dict["valid_output_lengths"]).item()
        max_output_round_length = self.padder._round_up(max_output_length, self.r)

        data_dict["mel_targets"] = self.padder._prepare_targets(
            [x[1] for x in batch], max_output_round_length, 0.0
        )
        if self.with_duration:
            data_dict["durations"] = self.padder._prepare_durations(
                [x[2] for x in batch], max_dur_length, max_output_round_length
            )
        else:
            data_dict["durations"] = None

        if self.with_duration:
            if self.fp_enable:
                feats_padding_length = max_dur_length
            else:
                feats_padding_length = max_input_length
        else:
            feats_padding_length = max_output_round_length

        data_dict["pitch_contours"] = self.padder._prepare_scalar_inputs(
            [x[3] for x in batch], feats_padding_length, 0.0
        ).float()
        data_dict["energy_contours"] = self.padder._prepare_scalar_inputs(
            [x[4] for x in batch], feats_padding_length, 0.0
        ).float()

        if self.with_duration:
            data_dict["attn_priors"] = None
        else:
            data_dict["attn_priors"] = torch.zeros(
                len(batch), max_output_round_length, max_input_length
            )
            for i, sample in enumerate(batch):
                attn_prior = sample[5]
                data_dict["attn_priors"][
                    i, : attn_prior.shape[0], : attn_prior.shape[1]
                ] = attn_prior
        return data_dict


# TODO: implement get_am_datasets
def get_am_datasets(
    metafile,
    root_dir,
    config,
    allow_cache,
    split_ratio=0.98,
    se_enable=False,
):
    """Build (train, valid) ``AM_Dataset`` objects, generating split metafiles if missing."""
    if not isinstance(root_dir, list):
        root_dir = [root_dir]
    if not isinstance(metafile, list):
        metafile = [metafile]

    train_meta_lst = []
    valid_meta_lst = []

    fp_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("FP", False)

    if fp_enable:
        am_train_fn = "am_fprm_train.lst"
        am_valid_fn = "am_fprm_valid.lst"
    else:
        am_train_fn = "am_train.lst"
        am_valid_fn = "am_valid.lst"

    for raw_metafile, data_dir in zip(metafile, root_dir):
        train_meta = os.path.join(data_dir, am_train_fn)
        valid_meta = os.path.join(data_dir, am_valid_fn)
        if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
            # Keyword arguments: passed positionally, split_ratio would land in
            # `badlist` and se_enable in `split_ratio`.
            AM_Dataset.gen_metafile(
                raw_metafile,
                data_dir,
                train_meta,
                valid_meta,
                split_ratio=split_ratio,
                se_enable=se_enable,
            )
        train_meta_lst.append(train_meta)
        valid_meta_lst.append(valid_meta)

    train_dataset = AM_Dataset(config, train_meta_lst, root_dir, allow_cache)

    # NOTE(review): this caps the number of metafiles (data dirs), not samples.
    valid_dataset = AM_Dataset(config, valid_meta_lst[:50], root_dir, allow_cache)

    return train_dataset, valid_dataset


class MaskingActor(object):
    """Random BERT-style masking of a symbol sequence."""

    def __init__(self, mask_ratio=0.15):
        super().__init__()
        self.mask_ratio = mask_ratio

    def _get_random_mask(self, length, p1=0.15):
        """Return a float array of *length* entries, each 1.0 with probability *p1*, else 0.0."""
        return (np.random.uniform(0, 1, length) < p1).astype(np.float64)

    def _input_bert_masking(
        self,
        sequence_array,
        nb_symbol_category,
        mask_symbol_id,
        mask,
        p2=0.8,
        p3=0.1,
        p4=0.1,
    ):
        """Return a copy of *sequence_array* with the positions where ``mask == 1`` corrupted.

        The selected positions are shuffled and then split as follows:
        - ``floor(n * p2)`` positions get *mask_symbol_id*;
        - the next ``floor(n * p3)`` positions each get an independently drawn
          random id in ``[0, nb_symbol_category - 1]``;
        - the rest (nominally *p4*) keep their original symbol.
        """
        sequence_array_mask = sequence_array.copy()
        mask_id = np.where(mask == 1)[0]
        mask_len = len(mask_id)
        order = np.arange(mask_len)
        np.random.shuffle(order)

        n_mask = int(math.floor(mask_len * p2))
        n_rand = int(math.floor(mask_len * p3))

        # [MASK]
        mask_id_p2 = mask_id[order[:n_mask]]
        if len(mask_id_p2) > 0:
            sequence_array_mask[mask_id_p2] = mask_symbol_id

        # Random token, drawn per position (previously one shared draw for all).
        mask_id_p3 = mask_id[order[n_mask : n_mask + n_rand]]
        if len(mask_id_p3) > 0:
            sequence_array_mask[mask_id_p3] = np.random.randint(
                0, nb_symbol_category, size=len(mask_id_p3)
            )

        # Original symbol: nothing to do for the remaining selected positions.
        return sequence_array_mask


class BERT_Text_Dataset(torch.utils.data.Dataset):
    """Provide (ling, ling_sy_masked, bert_mask) samples for text BERT pre-training."""

    def __init__(
        self,
        config,
        metafile,
        root_dir,
        allow_cache=False,
    ):
        self.meta = []
        self.config = config

        if not isinstance(metafile, list):
            metafile = [metafile]
        if not isinstance(root_dir, list):
            root_dir = [root_dir]

        for meta_file, data_dir in zip(metafile, root_dir):
            if not os.path.exists(meta_file):
                logging.error("meta file not found: {}".format(meta_file))
                raise ValueError(
                    "[BERT_Text_Dataset] meta file: {} not found".format(meta_file)
                )
            if not os.path.exists(data_dir):
                logging.error("data dir not found: {}".format(data_dir))
                raise ValueError(
                    "[BERT_Text_Dataset] data dir: {} not found".format(data_dir)
                )
            self.meta.extend(self.load_meta(meta_file, data_dir))

        self.allow_cache = allow_cache

        self.ling_unit = KanTtsLinguisticUnit(config)
        self.padder = Padder()
        self.masking_actor = MaskingActor(
            self.config["Model"]["KanTtsTextsyBERT"]["params"]["mask_ratio"]
        )

        if allow_cache:
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(self.meta))]

    def __len__(self):
        return len(self.meta)

    # TODO: implement __getitem__
    def __getitem__(self, idx):
        """Return ``(ling_data, ling_sy_masked_data, bert_mask)``.

        Only the encoded ling data is cached; a fresh mask is drawn on every call.
        """
        if self.allow_cache and len(self.caches[idx]) != 0:
            ling_data = self.caches[idx][0]
        else:
            ling_txt = self.meta[idx]
            ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
            if self.allow_cache:
                self.caches[idx] = (ling_data,)

        bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
        return (ling_data, ling_sy_masked_data, bert_mask)

    def load_meta(self, metafile, data_dir):
        """Return the ling-text column (a list of str) of a ``<index>\\t<ling_txt>`` file."""
        with open(metafile, "r", encoding="utf-8") as f:
            lines = f.readlines()

        items = []
        logging.info("Loading metafile...")
        for line in tqdm(lines):
            line = line.strip()
            if not line:
                continue
            _index, ling_txt = line.split("\t")
            items.append(ling_txt)

        return items

    @staticmethod
    def gen_metafile(raw_meta_file, out_dir, split_ratio=0.98):
        """Shuffle *raw_meta_file* with a fixed seed and write bert_train/bert_valid lists."""
        with open(raw_meta_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        # Local RNG: same permutation as seeding the global RNG, without side effects.
        random.Random(DATASET_RANDOM_SEED).shuffle(lines)
        num_train = int(len(lines) * split_ratio) - 1
        for list_name, split_lines in (
            ("bert_train.lst", lines[:num_train]),
            ("bert_valid.lst", lines[num_train:]),
        ):
            with open(os.path.join(out_dir, list_name), "w", encoding="utf-8") as f:
                for line in split_lines:
                    f.write(_ensure_newline(line))

    def bert_masking(self, ling_data):
        """Draw a random mask over the sy stream and return ``(mask, masked_sy)``.

        The last position is never masked.
        """
        length = len(ling_data[0])
        mask = self.masking_actor._get_random_mask(
            length, p1=self.masking_actor.mask_ratio
        )
        mask[-1] = 0

        # sy_masked
        sy_mask_symbol_id = self.ling_unit.encode_sy([self.ling_unit._mask])[0]
        ling_sy_masked_data = self.masking_actor._input_bert_masking(
            ling_data[0],
            self.ling_unit.get_unit_size()["sy"],
            sy_mask_symbol_id,
            mask,
            p2=0.8,
            p3=0.1,
            p4=0.1,
        )

        return (mask, ling_sy_masked_data)

    # TODO: implement collate_fn
    def collate_fn(self, batch):
        """Pad and stack ``__getitem__`` samples into a dict of tensors."""
        data_dict = {}

        max_input_length = max(len(x[0][0]) for x in batch)

        def pad_stream(seqs, lfeat_index):
            """Pad *seqs* with the pad id of linguistic stream *lfeat_index* into a long tensor."""
            lfeat_type = self.ling_unit._lfeat_type_list[lfeat_index]
            return self.padder._prepare_scalar_inputs(
                seqs,
                max_input_length,
                self.ling_unit._sub_unit_pad[lfeat_type],
            ).long()

        # Pure linguistic info: sy | tone | syllable_flag | word_segment.
        targets_sy = pad_stream([x[0][0] for x in batch], 0)
        # The masked sy stream replaces sy in the inputs.
        inputs_sy = pad_stream([x[1] for x in batch], 0)
        inputs_tone = pad_stream([x[0][1] for x in batch], 1)
        inputs_syllable_flag = pad_stream([x[0][2] for x in batch], 2)
        inputs_ws = pad_stream([x[0][3] for x in batch], 3)

        data_dict["input_lings"] = torch.stack(
            [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2
        )
        # A "~" is appended to the input symbol sequence, which would skew the
        # duration computation, so the valid length is reduced by one.
        data_dict["valid_input_lengths"] = torch.as_tensor(
            [len(x[0][0]) - 1 for x in batch], dtype=torch.long
        )

        data_dict["targets"] = targets_sy
        data_dict["bert_masks"] = self.padder._prepare_scalar_inputs(
            [x[2] for x in batch], max_input_length, 0.0
        )

        return data_dict


def get_bert_text_datasets(
    metafile,
    root_dir,
    config,
    allow_cache,
    split_ratio=0.98,
):
    """Build (train, valid) ``BERT_Text_Dataset`` objects, generating split lists if missing."""
    if not isinstance(root_dir, list):
        root_dir = [root_dir]
    if not isinstance(metafile, list):
        metafile = [metafile]

    train_meta_lst = []
    valid_meta_lst = []

    for raw_metafile, data_dir in zip(metafile, root_dir):
        train_meta = os.path.join(data_dir, "bert_train.lst")
        valid_meta = os.path.join(data_dir, "bert_valid.lst")
        if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
            BERT_Text_Dataset.gen_metafile(raw_metafile, data_dir, split_ratio)
        train_meta_lst.append(train_meta)
        valid_meta_lst.append(valid_meta)

    train_dataset = BERT_Text_Dataset(config, train_meta_lst, root_dir, allow_cache)

    valid_dataset = BERT_Text_Dataset(config, valid_meta_lst, root_dir, allow_cache)

    return train_dataset, valid_dataset