#!/usr/bin/env python3
import onnxruntime
import zipfile
from glob import glob

try:
    from kantts.utils.ling_unit import text_to_mit_symbols as text_to_symbols
    from kantts.utils.ling_unit.ling_unit import KanTtsLinguisticUnit
    from kantts.models.sambert.kantts_sambert_divide import VarianceAdaptor2, MelPNCADecoder
except ImportError:
    raise ImportError("Please install kantts.")

try:
    from kantts.utils.log import logging_to_file
except ImportError:
    raise ImportError("Please install kantts.")

import os
import sys
import argparse
import torch
import soundfile as sf
import yaml
import logging
import numpy as np
import time

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH))  # NOQA: E402

logging.basicConfig(
    # filename=os.path.join(stage_dir, 'stdout.log'),
    format="%(asctime)s, %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)


def denorm_f0(mel, f0_threshold=30, uv_threshold=0.6, norm_type='mean_std', f0_feature=None):
    if norm_type == 'mean_std':
        f0_mvn = f0_feature
        f0 = mel[:, -2]
        uv = mel[:, -1]
        uv[uv < uv_threshold] = 0.0
        uv[uv >= uv_threshold] = 1.0
        f0 = f0 * f0_mvn[1:, :] + f0_mvn[0:1, :]
        f0[f0 < f0_threshold] = f0_threshold
        mel[:, -2] = f0
        mel[:, -1] = uv
    else:  # global
        f0_global_max_min = f0_feature
        f0 = mel[:, -2]
        uv = mel[:, -1]
        uv[uv < uv_threshold] = 0.0
        uv[uv >= uv_threshold] = 1.0
        f0 = f0 * (f0_global_max_min[0] - f0_global_max_min[1]) + f0_global_max_min[1]
        f0[f0 < f0_threshold] = f0_threshold
        mel[:, -2] = f0
        mel[:, -1] = uv
    return mel


def get_mask_from_lengths(lengths, max_len=None):
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()
    ids = (
        torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device)
    )
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
    return mask
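
# Example: get_mask_from_lengths(torch.tensor([2, 4])) returns
#   tensor([[False, False,  True,  True],
#           [False, False, False, False]])
# i.e. True marks the padding positions at or beyond each sequence's length.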

def hifigan_infer(input_mel, onnx_file, output_dir, config=None):
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        torch.backends.cudnn.benchmark = True
        device = torch.device("cuda", 0)
    # device = torch.device("cpu")

    if config is not None:
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.Loader)
    else:
        # No config given: fall back to config.yaml two directories above the
        # model file (assumed layout <model_root>/ckpt/<model>.onnx).
        config_path = os.path.join(
            os.path.dirname(os.path.dirname(onnx_file)), "config.yaml"
        )
        if not os.path.exists(config_path):
            raise ValueError("config file not found: {}".format(config_path))
        with open(config_path, "r") as f:
            config = yaml.load(f, Loader=yaml.Loader)

    # for key, value in config.items():
    #     logging.info(f"{key} = {value}")

    # check directory existence
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logging_to_file(os.path.join(output_dir, "stdout.log"))

    if os.path.isfile(input_mel):
        mel_lst = [input_mel]
    elif os.path.isdir(input_mel):
        mel_lst = glob(os.path.join(input_mel, "*.npy"))
    else:
        raise ValueError("input_mel should be a file or a directory")

    # providers = ['CUDAExecutionProvider', {'device_id': 1}]
    # providers = ['CPUExecutionProvider']  # the onnxruntime default
    providers = ['ROCMExecutionProvider']
    ort_session = onnxruntime.InferenceSession(onnx_file, providers=providers)
    print(ort_session.get_providers())

    start = time.time()
    pcm_len = 0
    for mel in mel_lst:
        start1 = time.time()
        utt_id = os.path.splitext(os.path.basename(mel))[0]
        logging.info("Inference sentence: {}".format(utt_id))
        mel_data = np.load(mel)

        # generate: (T, C) -> (B, C, T)
        mel_data = torch.tensor(mel_data, dtype=torch.float).to(device)
        mel_data = mel_data.transpose(1, 0).unsqueeze(0)
        ort_inputs = {'mel_data': mel_data.cpu().numpy()}

        # GPU warm-up
        for _ in range(50):
            _ = ort_session.run(['y'], ort_inputs)

        # timing: 100 iterations measured with CUDA events
        # (NOTE: this assumes a GPU is present)
        iterations = 100
        times = torch.zeros(iterations)  # per-iteration latency in ms
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        for it in range(iterations):
            starter.record()
            ort_list = ort_session.run(['y'], ort_inputs)
            ender.record()
            torch.cuda.synchronize()  # wait for the GPU before reading the timer
            times[it] = starter.elapsed_time(ender)
        mean_time = times.mean().item()
        print("hifigan-onnx infer single time: {:.6f} ms".format(mean_time))

        # How the ONNX file was exported from the PyTorch vocoder (for reference):
        # torch.onnx.export(
        #     model,
        #     mel_data,
        #     "hifigan.onnx",
        #     opset_version=11,
        #     input_names=['mel_data'],
        #     output_names=['y'],
        # )
        # if hasattr(model, "pqmf"):
        #     y = model.pqmf.synthesis(y)

        ort_y = torch.from_numpy(ort_list[0]).view(-1).cpu().numpy()
        pcm_len += len(ort_y)

        # save as a PCM 16-bit wav file
        sf.write(
            os.path.join(output_dir, f"{utt_id}_gen.wav"),
            ort_y,
            config["audio_config"]["sampling_rate"],
            "PCM_16",
        )
        total_elapsed = time.time() - start1
        print(f'Vocoder infer single time: {total_elapsed} seconds')

    # report average RTF
    rtf = (time.time() - start) / (
        pcm_len / config["audio_config"]["sampling_rate"]
    )
    logging.info(
        f"Finished generation of {len(mel_lst)} utterances (RTF = {rtf:.03f})."
    )
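
# A minimal sketch (an addition, not part of the original flow): choose the
# ROCm execution provider only when this onnxruntime build offers it, and
# fall back to CPU otherwise. hifigan_infer and am_infer_divide hard-code
# ['ROCMExecutionProvider'], which assumes a ROCm build of onnxruntime.
def select_providers(preferred='ROCMExecutionProvider'):
    available = onnxruntime.get_available_providers()
    return [preferred] if preferred in available else ['CPUExecutionProvider']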

def am_infer_divide(sentence, text_encoder_onnx, variance_adaptor_ckpt, mel_decoder_ckpt,
                    mel_postnet_onnx, output_dir, target_rate=1.0, se_file=None, config=None):
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        torch.backends.cudnn.benchmark = True
        device = torch.device("cuda", 0)
    # device = torch.device("cpu")

    if config is not None:
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.Loader)
    # else:
    #     am_config_file = os.path.join(
    #         os.path.dirname(os.path.dirname(ckpt)), "config.yaml"
    #     )
    #     with open(am_config_file, "r") as f:
    #         config = yaml.load(f, Loader=yaml.Loader)

    ling_unit = KanTtsLinguisticUnit(config)
    ling_unit_size = ling_unit.get_unit_size()
    config["Model"]["KanTtsSAMBERT"]["params"].update(ling_unit_size)

    se_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("SE", False)
    se = np.load(se_file) if se_enable else None

    # NSF (neural source-filter) f0 handling
    nsf_enable = config["Model"]["KanTtsSAMBERT"]["params"].get("NSF", False)
    if nsf_enable:
        nsf_norm_type = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_norm_type", "mean_std")
        if nsf_norm_type == "mean_std":
            # Assumes mvn.npy sits two levels above the checkpoint file
            # (<model_root>/ckpt/<ckpt_file> with mvn.npy in <model_root>).
            f0_mvn_file = os.path.join(
                os.path.dirname(os.path.dirname(variance_adaptor_ckpt)), "mvn.npy"
            )
            f0_feature = np.load(f0_mvn_file)
        else:  # global
            nsf_f0_global_minimum = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_f0_global_minimum", 30.0)
            nsf_f0_global_maximum = config["Model"]["KanTtsSAMBERT"]["params"].get("nsf_f0_global_maximum", 730.0)
            f0_feature = [nsf_f0_global_maximum, nsf_f0_global_minimum]

    # model, _, _ = model_builder(config, device)
    # fsnet = model["KanTtsSAMBERT"]

    logging.info("ort_sess is building...")
    providers = ['ROCMExecutionProvider']

    logging.info("text_encoder_ort_sess is building...")
    text_encoder_ort_sess = onnxruntime.InferenceSession(text_encoder_onnx, providers=providers)
    print(text_encoder_ort_sess.get_providers())
    # variance_adaptor_ort_sess = onnxruntime.InferenceSession(variance_adaptor_onnx, providers=providers)
    # mel_decoder_ort_sess = onnxruntime.InferenceSession(mel_decoder_onnx, providers=providers)

    logging.info("mel_postnet_ort_sess is building...")
    mel_postnet_ort_sess = onnxruntime.InferenceSession(mel_postnet_onnx, providers=providers)

    # The variance adaptor runs as a PyTorch module instead of ONNX.
    variance_adaptor = VarianceAdaptor2(config["Model"]["KanTtsSAMBERT"]["params"]).to(device)
    logging.info("Loading checkpoint: {}".format(variance_adaptor_ckpt))
    variance_adaptor_state_dict = torch.load(variance_adaptor_ckpt)
    variance_adaptor.load_state_dict(variance_adaptor_state_dict, strict=False)
    variance_adaptor.eval()

    # The mel decoder also runs as a PyTorch module instead of ONNX.
    mel_decoder = MelPNCADecoder(config["Model"]["KanTtsSAMBERT"]["params"]).to(device)
    logging.info("Loading checkpoint: {}".format(mel_decoder_ckpt))
    mel_decoder_state_dict = torch.load(mel_decoder_ckpt)
    mel_decoder.load_state_dict(mel_decoder_state_dict, strict=False)
    mel_decoder.eval()

    results_dir = os.path.join(output_dir, "feat")
    os.makedirs(results_dir, exist_ok=True)
    # fsnet.eval()
    # i = 0  # counter to run the loop only once when exporting to ONNX

    with open(sentence, encoding="utf-8") as f:
        for line in f:
            # if i > 0:
            #     break
            # i = i + 1
            start = time.time()
            line = line.strip().split("\t")
            logging.info("Inference sentence: {}".format(line[0]))
            mel_path = "%s/%s_mel.npy" % (results_dir, line[0])
            dur_path = "%s/%s_dur.txt" % (results_dir, line[0])
            f0_path = "%s/%s_f0.txt" % (results_dir, line[0])
            energy_path = "%s/%s_energy.txt" % (results_dir, line[0])

            with torch.no_grad():
                # mel, mel_post, dur, f0, energy = am_synthesis(line[1], fsnet, ling_unit, device, se=se)
                inputs_feat_lst = ling_unit.encode_symbol_sequence(line[1])
                inputs_feat_index = 0
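
                # encode_symbol_sequence returns a list of per-token numpy
                # arrays; in the non-byte branch below they are consumed in the
                # order sy, tone, syllable, word-segment (ws), emotion,
                # speaker. A trailing "~" token is stripped via [:-1].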
                if ling_unit.using_byte():
                    inputs_byte_index = (
                        torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
                    )
                    inputs_ling = torch.stack([inputs_byte_index], dim=-1).unsqueeze(0)
                else:
                    inputs_sy = (
                        torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
                    )
                    inputs_feat_index = inputs_feat_index + 1
                    inputs_tone = (
                        torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
                    )
                    inputs_feat_index = inputs_feat_index + 1
                    inputs_syllable = (
                        torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
                    )
                    inputs_feat_index = inputs_feat_index + 1
                    inputs_ws = (
                        torch.from_numpy(inputs_feat_lst[inputs_feat_index]).long().to(device)
                    )
                    inputs_ling = torch.stack(
                        [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], dim=-1
                    ).unsqueeze(0)

                inputs_feat_index = inputs_feat_index + 1
                inputs_emo = (
                    torch.from_numpy(inputs_feat_lst[inputs_feat_index])
                    .long()
                    .to(device)
                    .unsqueeze(0)
                )

                inputs_feat_index = inputs_feat_index + 1
                se_enable = False if se is None else True
                if se_enable:
                    inputs_spk = (
                        torch.from_numpy(se.repeat(len(inputs_feat_lst[inputs_feat_index]), axis=0))
                        .float()
                        .to(device)
                        .unsqueeze(0)[:, :-1, :]
                    )
                else:
                    inputs_spk = (
                        torch.from_numpy(inputs_feat_lst[inputs_feat_index])
                        .long()
                        .to(device)
                        .unsqueeze(0)[:, :-1]
                    )
                inputs_len = (
                    torch.zeros(1).long().to(device) + inputs_emo.size(1) - 1
                )  # minus 1 for "~"

                inputs_ling = inputs_ling[:, :-1, :]
                inputs_emotion = inputs_emo[:, :-1]
                inputs_speaker = inputs_spk
                inputs_lengths = inputs_len
                batch_size = inputs_ling.size(0)
                inputs_ling_masks = get_mask_from_lengths(inputs_lengths, max_len=inputs_ling.size(1))

                text_encoder_inputs = {
                    'inputs_ling': inputs_ling.cpu().numpy(),
                    'inputs_emotion': inputs_emotion.cpu().numpy(),
                    'inputs_speaker': inputs_speaker.cpu().numpy(),
                    'inputs_ling_masks': inputs_ling_masks.cpu().numpy(),
                }

                # run the text encoder (ONNX)
                (
                    text_hid,
                    ling_embedding,
                    emo_hid,
                    spk_hid,
                ) = text_encoder_ort_sess.run(
                    ['text_hid', 'ling_embedding', 'emo_hid', 'spk_hid'],
                    text_encoder_inputs,
                )
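
                # InferenceSession.run(output_names, input_feed) returns plain
                # numpy arrays in the order of output_names; they are wrapped
                # back into torch tensors before the PyTorch variance adaptor.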
                inter_lengths = inputs_lengths
                inter_masks = get_mask_from_lengths(inter_lengths, max_len=text_hid.shape[1])
                # output_masks = None

                # run the variance adaptor (PyTorch);
                # speaking-rate control: scale = 1 / target_rate
                (
                    LR_text_outputs,
                    LR_emo_outputs,
                    LR_spk_outputs,
                    LR_length_rounded,
                    log_duration_predictions,
                    pitch_predictions,
                    energy_predictions,
                ) = variance_adaptor(
                    torch.from_numpy(text_hid).to(device),
                    torch.from_numpy(emo_hid).to(device),
                    torch.from_numpy(spk_hid).to(device),
                    scale=1 / target_rate,
                    masks=inter_masks,
                    # output_masks=output_masks,
                    # duration_targets=None,
                    # pitch_targets=None,
                    # energy_targets=None,
                )
                # ONNX alternative (disabled; see variance_adaptor_ort_sess above):
                # variance_adaptor_inputs = {'text_hid': text_hid, 'emo_hid': emo_hid,
                #                            'spk_hid': spk_hid,
                #                            'inter_masks': inter_masks.cpu().numpy()}
                # (LR_text_outputs, LR_emo_outputs, LR_spk_outputs, LR_length_rounded,
                #  log_duration_predictions, pitch_predictions, energy_predictions,
                #  ) = variance_adaptor_ort_sess.run(
                #      ['LR_text_outputs', 'LR_emo_outputs', 'LR_spk_outputs',
                #       'LR_length_rounded', 'log_duration_predictions',
                #       'pitch_predictions', 'energy_predictions'],
                #      variance_adaptor_inputs)

                output_masks = get_mask_from_lengths(LR_length_rounded, max_len=LR_text_outputs.shape[1])
                # lfr_masks = None
                outputs_per_step = config["Model"]["KanTtsSAMBERT"]["params"]["outputs_per_step"]
                r = outputs_per_step

                # LFR (low frame rate) with the factor of outputs_per_step:
                # group r consecutive frames into one decoder step
                LFR_text_inputs = LR_text_outputs.contiguous().view(
                    batch_size, -1, r * text_hid.shape[-1]
                )  # e.g. [1, 153, 32] -> [1, 51, 96]
                LFR_emo_inputs = LR_emo_outputs.contiguous().view(
                    batch_size, -1, r * emo_hid.shape[-1]
                )[:, :, : emo_hid.shape[-1]]
                LFR_spk_inputs = LR_spk_outputs.contiguous().view(
                    batch_size, -1, r * spk_hid.shape[-1]
                )[:, :, : spk_hid.shape[-1]]  # e.g. [1, 153, 192] -> [1, 51, 192]

                memory = torch.cat([LFR_text_inputs, LFR_spk_inputs, LFR_emo_inputs], dim=2)
                x_band_width = int((torch.exp(log_duration_predictions) - 1).max() / r + 0.5)
                h_band_width = x_band_width
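
                # Band widths for the PNCA decoder: derived from the longest
                # predicted duration in reduced (LFR) frames; presumably they
                # bound the decoder's attention window over the input (x) and
                # history (h) states.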
                # run the mel decoder (PyTorch)
                dec_outputs = mel_decoder(
                    memory,
                    x_band_width,
                    h_band_width,
                    # target=None,
                    # mask=lfr_masks,
                    # return_attns=True,
                )
                # ONNX alternative (disabled; see mel_decoder_ort_sess above):
                # mel_decoder_inputs = {'memory': memory.cpu().numpy(),
                #                       'x_band_width': np.array(x_band_width),
                #                       'h_band_width': np.array(h_band_width)}
                # dec_outputs = mel_decoder_ort_sess.run(['dec_outputs'], mel_decoder_inputs)

                d_mel = config["Model"]["KanTtsSAMBERT"]["params"]["num_mels"]
                # De-LFR with the factor of outputs_per_step:
                # split each decoder step back into r mel frames
                dec_outputs = dec_outputs[0].contiguous().view(batch_size, -1, d_mel)  # e.g. [1, 51, 246] -> [1, 153, 82]
                if output_masks is not None:
                    dec_outputs = dec_outputs.masked_fill(output_masks.unsqueeze(-1), 0)

                # run the mel postnet (ONNX); the residual connection
                # (postnet output + decoder output) is applied on the torch side
                # PyTorch alternative:
                # postnet_outputs = mel_postnet(dec_outputs, output_masks) + dec_outputs
                mel_postnet_inputs = {
                    'dec_outputs': dec_outputs.cpu().numpy(),
                    'output_masks': output_masks.cpu().numpy(),
                }
                postnet_outputs = mel_postnet_ort_sess.run(
                    ['postnet_outputs'], mel_postnet_inputs
                )
                postnet_outputs = torch.from_numpy(postnet_outputs[0]).to(device) + dec_outputs
                if output_masks is not None:
                    postnet_outputs = postnet_outputs.masked_fill(output_masks.unsqueeze(-1), 0)

                # At this point the divided SAMBERT forward pass is complete; the
                # monolithic model would return x_band_width, h_band_width,
                # dec_outputs, postnet_outputs, LR_length_rounded,
                # log_duration_predictions, pitch_predictions, energy_predictions.

                valid_length = int(LR_length_rounded[0].item())
                dec_outputs = dec_outputs[0, :valid_length, :].cpu().numpy()
                postnet_outputs = postnet_outputs[0, :valid_length, :].cpu().numpy()
                duration_predictions = (
                    (torch.exp(log_duration_predictions) - 1 + 0.5).long().squeeze().cpu().numpy()
                )
                pitch_predictions = pitch_predictions.squeeze().cpu().numpy()
                energy_predictions = energy_predictions.squeeze().cpu().numpy()

                logging.info("x_band_width:{}, h_band_width: {}".format(x_band_width, h_band_width))

                # corresponds to mel, mel_post, dur, f0, energy from am_synthesis
                mel, mel_post, dur, f0, energy = (
                    dec_outputs,
                    postnet_outputs,
                    duration_predictions,
                    pitch_predictions,
                    energy_predictions,
                )

            if nsf_enable:
                mel_post = denorm_f0(mel_post, norm_type=nsf_norm_type, f0_feature=f0_feature)

            np.save(mel_path, mel_post)
            np.savetxt(dur_path, dur)
            np.savetxt(f0_path, f0)
            np.savetxt(energy_path, energy)

            total_elapsed = time.time() - start
            print(f'AM infer single time: {total_elapsed} seconds')
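
# am_infer_divide writes four artifacts per sentence into <output_dir>/feat:
# <id>_mel.npy, <id>_dur.txt, <id>_f0.txt and <id>_energy.txt; hifigan_infer
# later picks up the *_mel.npy files from that directory.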

def concat_process(chunked_dir, output_dir):
    # Expects chunk files named "<main_id>_<sub_id>_mel_gen.wav": sub-sentences
    # of one paragraph are joined with sentence_sil seconds of silence, and each
    # finished paragraph gets end_sil seconds of trailing silence.
    wav_files = sorted(glob(os.path.join(chunked_dir, "*.wav")))
    sentence_sil = 0.28  # seconds
    end_sil = 0.05  # seconds
    cnt = 0
    wav_concat = None
    main_id, sub_id = 0, 0
    while cnt < len(wav_files):
        wav_file = os.path.join(
            chunked_dir, "{}_{}_mel_gen.wav".format(main_id, sub_id)
        )
        if os.path.exists(wav_file):
            wav, sr = sf.read(wav_file)
            sentence_sil_samples = int(sentence_sil * sr)
            end_sil_samples = int(end_sil * sr)
            if sub_id == 0:
                wav_concat = wav
            else:
                wav_concat = np.concatenate(
                    (wav_concat, np.zeros(sentence_sil_samples), wav), axis=0
                )
            sub_id += 1
            cnt += 1
        else:
            # no more chunks for this paragraph: flush it and move on
            if wav_concat is not None:
                wav_concat = np.concatenate(
                    (wav_concat, np.zeros(end_sil_samples)), axis=0
                )
                sf.write(os.path.join(output_dir, f"{main_id}.wav"), wav_concat, sr)
            main_id += 1
            sub_id = 0
            wav_concat = None
    if cnt == len(wav_files):
        # flush the final paragraph
        wav_concat = np.concatenate((wav_concat, np.zeros(end_sil_samples)), axis=0)
        sf.write(os.path.join(output_dir, f"{main_id}.wav"), wav_concat, sr)


def text_to_wav_onnx(
    text_file,
    output_dir,
    resources_zip_file,
    text_encoder_onnx,
    variance_adaptor_pt,
    mel_decoder_pt,
    mel_postnet_onnx,
    am_config_file,
    voc_onnx,
    voc_config_file,
    target_rate=1.0,
    speaker=None,
    se_file=None,
    lang="PinYin",
):
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "res_wavs"), exist_ok=True)

    resource_root_dir = os.path.dirname(resources_zip_file)
    resource_dir = os.path.join(resource_root_dir, "resource")
    if not os.path.exists(resource_dir):
        logging.info("Extracting resources...")
        with zipfile.ZipFile(resources_zip_file, "r") as zip_ref:
            zip_ref.extractall(resource_root_dir)

    with open(text_file, "r") as text_data:
        texts = text_data.readlines()

    logging.info("Converting text to symbols...")
    # am_config = os.path.join(os.path.dirname(os.path.dirname(am_ckpt)), "config.yaml")
    with open(am_config_file, "r") as f:
        am_config = yaml.load(f, Loader=yaml.Loader)
    if speaker is None:
        speaker = am_config["linguistic_unit"]["speaker_list"].split(",")[0]
    symbols_lst = text_to_symbols(texts, resource_dir, speaker, lang)

    symbols_file = os.path.join(output_dir, "symbols.lst")
    with open(symbols_file, "w") as symbol_data:
        for symbol in symbols_lst:
            symbol_data.write(symbol)

    logging.info("AM is inferring...")
    start = time.time()
    # am_infer(symbols_file, am_ckpt, output_dir, se_file)
    am_infer_divide(
        symbols_file,
        text_encoder_onnx,
        variance_adaptor_pt,
        mel_decoder_pt,
        mel_postnet_onnx,
        output_dir,
        target_rate=target_rate,
        se_file=se_file,
        config=am_config_file,
    )
    total_elapsed = time.time() - start
    print(f'AM infer time: {total_elapsed} seconds')

    logging.info("Vocoder is inferring...")
    start = time.time()
    hifigan_infer(os.path.join(output_dir, "feat"), voc_onnx, output_dir, config=voc_config_file)
    total_elapsed = time.time() - start
    print(f'Vocoder infer time: {total_elapsed} seconds')

    concat_process(output_dir, os.path.join(output_dir, "res_wavs"))
    logging.info("Text to wav finished!")
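
# Pipeline overview: text -> symbol sequences (kantts frontend) -> acoustic
# model split into text_encoder (ONNX), variance_adaptor (PyTorch),
# mel_decoder (PyTorch) and mel_postnet (ONNX) -> mel spectrograms ->
# HiFi-GAN vocoder (ONNX) -> per-sentence wav chunks -> silence-joined
# paragraph wavs under <output_dir>/res_wavs.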

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Text2wav_onnx")
    parser.add_argument("--txt", type=str, required=True, help="Path to text file")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to output directory")
    parser.add_argument("--res_zip", type=str, required=True, help="Path to resource zip file")
    # parser.add_argument("--am_ckpt", type=str, required=True, help="Path to am ckpt file")
    parser.add_argument("--text_encoder_onnx", type=str, required=True,
                        help="Path to the text encoder ONNX file (AM part 1)")
    parser.add_argument("--variance_adaptor_pt", type=str, required=True,
                        help="Path to the variance adaptor PyTorch checkpoint (AM part 2)")
    parser.add_argument("--mel_decoder_pt", type=str, required=True,
                        help="Path to the mel decoder PyTorch checkpoint (AM part 3)")
    parser.add_argument("--mel_postnet_onnx", type=str, required=True,
                        help="Path to the mel postnet ONNX file (AM part 4)")
    parser.add_argument("--am_config", type=str, required=True, help="Path to am config file")
    parser.add_argument("--voc_onnx", type=str, required=True, help="Path to voc onnx file")
    parser.add_argument("--voc_config", type=str, required=True,
                        help="Path to voc config file")
    parser.add_argument("--target_rate", type=float, required=False, default=1.0,
                        choices=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
                        help="Speaking-rate factor for the final wav; "
                             "options: 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0")
    parser.add_argument("--speaker", type=str, required=False, default=None,
                        help="The speaker name; default is the first speaker in the AM config")
    parser.add_argument("--se_file", type=str, required=False, default=None,
                        help="The speaker embedding file; default is None")
    parser.add_argument("--lang", type=str, default="PinYin",
                        help="The language of the text; default is PinYin, other options are: "
                             "English, British, ZhHK, WuuShanghai, Sichuan, Indonesian, "
                             "Malay, Filipino, Vietnamese, Korean, Russian")

    args = parser.parse_args()

    start = time.time()
    text_to_wav_onnx(
        args.txt,
        args.output_dir,
        args.res_zip,
        # args.am_ckpt,
        args.text_encoder_onnx,
        args.variance_adaptor_pt,
        args.mel_decoder_pt,
        args.mel_postnet_onnx,
        args.am_config,
        args.voc_onnx,
        args.voc_config,
        args.target_rate,
        args.speaker,
        args.se_file,
        args.lang,
    )
    total_elapsed = time.time() - start
    print(f'text to wave time: {total_elapsed} seconds')
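
# Example invocation (all paths, including the script name, are illustrative
# placeholders):
#
#   python text_to_wav_onnx.py \
#       --txt sentences.txt \
#       --output_dir out \
#       --res_zip resource.zip \
#       --text_encoder_onnx am/text_encoder.onnx \
#       --variance_adaptor_pt am/variance_adaptor.pt \
#       --mel_decoder_pt am/mel_decoder.pt \
#       --mel_postnet_onnx am/mel_postnet.onnx \
#       --am_config am/config.yaml \
#       --voc_onnx voc/hifigan.onnx \
#       --voc_config voc/config.yaml \
#       --target_rate 1.0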