import sys
import torch
import os
import numpy as np
import argparse
import yaml
import logging
import time


# ROOT_PATH is the parent of the directory containing this file. Its own parent
# is put first on sys.path, presumably so `kantts` can be imported from a source
# checkout without installation -- confirm against the repository layout.
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(ROOT_PATH))

try:
    from kantts.models import model_builder  # noqa: F401  (not used by the visible code; kept for compatibility)
    from kantts.utils.ling_unit.ling_unit import KanTtsLinguisticUnit
except ImportError as err:
    # Chain the original error so the module that actually failed to import
    # is still reported in the traceback.
    raise ImportError("Please install kantts.") from err

logging.basicConfig(
    format="%(asctime)s, %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)


from kantts.models.sambert.kantts_sambert_divide import TextEncoder, VarianceAdaptor, MelPNCADecoder, PostNet
from collections import OrderedDict


def denorm_f0(mel, f0_threshold=30, uv_threshold=0.6, norm_type='mean_std', f0_feature=None):
    """Undo F0 normalisation and binarise the voiced/unvoiced (UV) channel.

    The last two columns of ``mel`` hold normalised F0 (``mel[:, -2]``) and a
    UV score (``mel[:, -1]``). Both columns are rewritten in place, and ``mel``
    is also returned for convenience.

    Args:
        mel: 2-D numpy array; only its last two columns are modified.
        f0_threshold: lower bound applied to the de-normalised F0.
        uv_threshold: UV scores ``>=`` this become 1.0; all others become 0.0.
        norm_type: ``'mean_std'`` undoes per-feature normalisation. Any other
            value undoes global min/max normalisation.
        f0_feature: for ``'mean_std'``, a 2-D array whose row 0 is added and
            whose row 1 scales F0 (presumably mean and std -- confirm against
            the writer of mvn.npy). Otherwise, a ``[max, min]`` pair.

    Returns:
        ``mel`` itself, modified in place.

    Raises:
        ValueError: if ``f0_feature`` is None.
    """
    if f0_feature is None:
        raise ValueError("f0_feature is required to de-normalise F0")

    # Binarise UV in a single pass. Two sequential masked writes (``< th -> 0``
    # then ``>= th -> 1``) would flip the new zeros back to 1 when th <= 0.
    uv = np.where(mel[:, -1] >= uv_threshold, 1.0, 0.0)

    f0 = mel[:, -2]
    if norm_type == 'mean_std':
        f0 = f0 * f0_feature[1:, :] + f0_feature[0:1, :]
    else:  # global min/max normalisation, f0_feature == [max, min]
        f0_max, f0_min = f0_feature[0], f0_feature[1]
        f0 = f0 * (f0_max - f0_min) + f0_min
    f0 = np.maximum(f0, f0_threshold)

    mel[:, -2] = f0
    mel[:, -1] = uv
    return mel


def get_mask_from_lengths(lengths, max_len=None):
    """Build a padding mask from per-sequence lengths.

    Args:
        lengths: 1-D tensor with one valid length per batch item.
        max_len: width of the mask; defaults to ``lengths.max()``.

    Returns:
        Bool tensor of shape ``(len(lengths), max_len)`` on ``lengths.device``.
        It is True at padded positions (index >= length) and False at valid
        positions.
    """
    if max_len is None:
        max_len = torch.max(lengths).item()

    # Create the position ids directly on the target device and let
    # broadcasting compare them with every length. This avoids a host->device
    # copy and two explicitly expanded views.
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) >= lengths.unsqueeze(1)


def _feat_to_long_tensor(feat, device):
    """Convert one numpy feature array from the linguistic unit to a long tensor on ``device``."""
    return torch.from_numpy(feat).long().to(device)


def _build_text_encoder_inputs(symbol_seq, ling_unit, device, se=None):
    """Encode ``symbol_seq`` into the four text-encoder inputs.

    Returns:
        ``(inputs_ling, inputs_emotion, inputs_speaker, inputs_lengths)``.
        The first three carry a leading batch axis of size 1 and have their
        final position (the trailing "~" symbol) removed. ``inputs_lengths``
        holds the resulting sequence length.
    """
    feats = ling_unit.encode_symbol_sequence(symbol_seq)

    # Byte-level units use a single linguistic stream. Otherwise four streams
    # (sy, tone, syllable, ws) are stacked along the last axis.
    n_ling = 1 if ling_unit.using_byte() else 4
    inputs_ling = torch.stack(
        [_feat_to_long_tensor(feats[i], device) for i in range(n_ling)], dim=-1
    ).unsqueeze(0)

    # The emotion stream and then the speaker stream follow the linguistic ones.
    inputs_emo = _feat_to_long_tensor(feats[n_ling], device).unsqueeze(0)
    spk_feat = feats[n_ling + 1]
    if se is not None:
        # Speaker embedding: repeat ``se`` along axis 0 once per speaker-stream
        # position (presumably se is (1, D) -- confirm), then drop the "~" slot.
        inputs_spk = (
            torch.from_numpy(se.repeat(len(spk_feat), axis=0))
            .float()
            .to(device)
            .unsqueeze(0)[:, :-1, :]
        )
    else:
        inputs_spk = _feat_to_long_tensor(spk_feat, device).unsqueeze(0)[:, :-1]

    # Sequence length excludes the trailing "~".
    inputs_lengths = torch.zeros(1).to(device).long() + inputs_emo.size(1) - 1

    return inputs_ling[:, :-1, :], inputs_emo[:, :-1], inputs_spk, inputs_lengths


def _export_onnx(module, args, path, input_names, output_names, dynamic_axes):
    """Trace ``module`` with ``args`` and write it to ``path`` as ONNX (opset 13).

    The parent directory of ``path`` is created if needed.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.onnx.export(
        module,
        args,
        path,
        opset_version=13,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )


def am_synthesis(
    symbol_seq,
    text_encoder,
    variance_adaptor,
    mel_decoder,
    mel_postnet,
    ling_unit,
    device,
    se=None,
    onnx_dir="sambert_onnx",
):
    """Run the split SAMBERT pipeline on one symbol sequence and export ONNX graphs.

    Runs text_encoder -> variance_adaptor -> mel_decoder -> mel_postnet in
    PyTorch. As a side effect, ``text_encoder.onnx`` and ``mel_postnet.onnx``
    are written into ``onnx_dir``. The variance adaptor and mel decoder run in
    PyTorch only. ``mel_decoder.r`` and ``mel_decoder.d_mel`` are read from the
    PyTorch module because they cannot be recovered from an ONNX model.

    Args:
        symbol_seq: symbol sequence string passed to ``ling_unit``.
        text_encoder, variance_adaptor, mel_decoder, mel_postnet: the four
            SAMBERT sub-modules, already on ``device``.
        ling_unit: linguistic unit used to encode ``symbol_seq``.
        device: torch device for all input tensors.
        se: optional numpy speaker embedding. When given, it replaces the
            speaker-id stream.
        onnx_dir: directory for the exported ONNX files (created if missing).

    Returns:
        Tuple of numpy arrays ``(dec_outputs, postnet_outputs,
        duration_predictions, pitch_predictions, energy_predictions)``. The mel
        outputs are trimmed to the predicted length of batch item 0.
    """
    inputs_ling, inputs_emotion, inputs_speaker, inputs_lengths = _build_text_encoder_inputs(
        symbol_seq, ling_unit, device, se=se
    )
    batch_size = inputs_ling.size(0)
    inputs_ling_masks = get_mask_from_lengths(inputs_lengths, max_len=inputs_ling.size(1))

    # --- Text encoder -------------------------------------------------------
    text_hid, ling_embedding, emo_hid, spk_hid = text_encoder(
        inputs_ling,
        inputs_emotion,
        inputs_speaker,
        inputs_ling_masks=inputs_ling_masks,
    )

    logging.info("text_encoder converting to onnx")
    _export_onnx(
        text_encoder,
        (inputs_ling, inputs_emotion, inputs_speaker, inputs_ling_masks),
        os.path.join(onnx_dir, "text_encoder.onnx"),
        input_names=['inputs_ling', 'inputs_emotion', 'inputs_speaker', 'inputs_ling_masks'],
        output_names=['text_hid', 'ling_embedding', 'emo_hid', 'spk_hid'],
        dynamic_axes={
            'inputs_ling': {1: 'text_encoder_input_dim1'},
            'inputs_emotion': {1: 'text_encoder_input_dim1'},
            'inputs_speaker': {1: 'text_encoder_input_dim1'},
            'inputs_ling_masks': {1: 'text_encoder_input_dim1'},
            'text_hid': {1: 'text_encoder_output_dim1'},
            'ling_embedding': {1: 'text_encoder_output_dim1'},
            'emo_hid': {1: 'text_encoder_output_dim1'},
            'spk_hid': {1: 'text_encoder_output_dim1'},
        },
    )

    # --- Variance adaptor (length regulation + prosody predictions) ---------
    inter_masks = get_mask_from_lengths(inputs_lengths, max_len=text_hid.size(1))
    (
        lr_text,
        lr_emo,
        lr_spk,
        lr_length_rounded,
        log_duration_predictions,
        pitch_predictions,
        energy_predictions,
    ) = variance_adaptor(text_hid, emo_hid, spk_hid, masks=inter_masks)

    output_masks = get_mask_from_lengths(lr_length_rounded, max_len=lr_text.size(1))

    # --- Mel decoder ---------------------------------------------------------
    # LFR: fold every r consecutive frames into one decoder step. The text
    # features are concatenated across the r frames. For emotion and speaker
    # only the first frame's features of each group are kept.
    r = mel_decoder.r
    text_dim, emo_dim, spk_dim = text_hid.shape[-1], emo_hid.shape[-1], spk_hid.shape[-1]
    lfr_text = lr_text.contiguous().view(batch_size, -1, r * text_dim)
    lfr_emo = lr_emo.contiguous().view(batch_size, -1, r * emo_dim)[:, :, :emo_dim]
    lfr_spk = lr_spk.contiguous().view(batch_size, -1, r * spk_dim)[:, :, :spk_dim]
    memory = torch.cat([lfr_text, lfr_spk, lfr_emo], dim=2)

    # Attention band width: the longest predicted duration in decoder steps, rounded.
    x_band_width = int((torch.exp(log_duration_predictions) - 1).max() / r + 0.5)
    h_band_width = x_band_width

    dec_outputs = mel_decoder(memory, x_band_width, h_band_width)

    # De-LFR back to one row per frame, then zero the padded frames.
    dec_outputs = dec_outputs.contiguous().view(batch_size, -1, mel_decoder.d_mel)
    dec_outputs = dec_outputs.masked_fill(output_masks.unsqueeze(-1), 0)

    # --- Post-net (residual refinement) --------------------------------------
    postnet_outputs = mel_postnet(dec_outputs, output_masks)

    _export_onnx(
        mel_postnet,
        (dec_outputs, output_masks),
        os.path.join(onnx_dir, "mel_postnet.onnx"),
        input_names=['dec_outputs', 'output_masks'],
        output_names=['postnet_outputs'],
        dynamic_axes={
            'dec_outputs': {1: 'mel_postnet_input_dim1'},
            'output_masks': {1: 'mel_postnet_input_dim1'},
            'postnet_outputs': {1: 'mel_postnet_output_dim1'},
        },
    )

    postnet_outputs = (postnet_outputs + dec_outputs).masked_fill(output_masks.unsqueeze(-1), 0)

    # --- Convert batch item 0 to numpy ---------------------------------------
    valid_length = int(lr_length_rounded[0].item())
    dec_np = dec_outputs[0, :valid_length, :].cpu().numpy()
    postnet_np = postnet_outputs[0, :valid_length, :].cpu().numpy()
    duration_np = (torch.exp(log_duration_predictions) - 1 + 0.5).long().squeeze().cpu().numpy()
    pitch_np = pitch_predictions.squeeze().cpu().numpy()
    energy_np = energy_predictions.squeeze().cpu().numpy()

    logging.info("x_band_width:{}, h_band_width: {}".format(x_band_width, h_band_width))

    return dec_np, postnet_np, duration_np, pitch_np, energy_np



def _load_partial(module, weights, name):
    """Load ``weights`` into ``module`` non-strictly and log any key mismatch instead of hiding it."""
    result = module.load_state_dict(weights, strict=False)
    if result.missing_keys or result.unexpected_keys:
        logging.warning(
            "%s: %d missing and %d unexpected state_dict keys ignored",
            name,
            len(result.missing_keys),
            len(result.unexpected_keys),
        )


def am_infer_divide(sentence, ckpt, output_dir, se_file=None, config=None):
    """Load a SAMBERT checkpoint as four sub-models and synthesise one sentence.

    Only the first line of ``sentence`` is processed. One traced pass is enough
    for the PyTorch -> ONNX conversion performed during synthesis.

    Args:
        sentence: UTF-8 text file of ``<id>\\t<symbol sequence>`` lines.
        ckpt: checkpoint path. Its ``["model"]`` entry holds the weights.
        output_dir: features are written to ``<output_dir>/feat/<id>_mel.npy``,
            ``_dur.txt``, ``_f0.txt`` and ``_energy.txt``.
        se_file: ``.npy`` speaker embedding. Required when the config enables SE.
        config: YAML config path. Defaults to ``config.yaml`` two directories
            above ``ckpt``.

    Side effects:
        Writes ``sambert_onnx/variance_adaptor_dict.pt`` and
        ``sambert_onnx/mel_decoder_dict.pt``, creating the directory if needed.

    Raises:
        ValueError: if SE is enabled without ``se_file``, or if the sentence
            line has no tab-separated symbol sequence.
    """
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    ckpt_root = os.path.dirname(os.path.dirname(ckpt))
    config_path = config if config is not None else os.path.join(ckpt_root, "config.yaml")
    with open(config_path, "r") as f:
        # NOTE: yaml.Loader can construct arbitrary Python objects; only load trusted configs.
        config = yaml.load(f, Loader=yaml.Loader)

    ling_unit = KanTtsLinguisticUnit(config)
    model_params = config["Model"]["KanTtsSAMBERT"]["params"]
    model_params.update(ling_unit.get_unit_size())

    se = None
    if model_params.get("SE", False):
        if se_file is None:
            raise ValueError("config enables SE, but no se_file was provided")
        se = np.load(se_file)

    # NSF: F0 de-normalisation parameters applied to the post-net output.
    nsf_enable = model_params.get("NSF", False)
    nsf_norm_type = f0_feature = None
    if nsf_enable:
        nsf_norm_type = model_params.get("nsf_norm_type", "mean_std")
        if nsf_norm_type == "mean_std":
            f0_feature = np.load(os.path.join(ckpt_root, "mvn.npy"))
        else:  # global: [max, min]
            f0_feature = [
                model_params.get("nsf_f0_global_maximum", 730.0),
                model_params.get("nsf_f0_global_minimum", 30.0),
            ]

    # The KanTtsSAMBERT model is split into four independently runnable sub-models.
    text_encoder = TextEncoder(model_params).to(device)
    variance_adaptor = VarianceAdaptor(model_params).to(device)
    mel_decoder = MelPNCADecoder(model_params).to(device)
    mel_postnet = PostNet(model_params).to(device)

    logging.info("Loading checkpoint: {}".format(ckpt))
    # Load onto the CPU so GPU-saved checkpoints also open on CPU-only hosts.
    # load_state_dict copies each tensor onto its module's device.
    state_dict = torch.load(ckpt, map_location="cpu")

    # Route each weight to its sub-model. Text-encoder and emotion keys keep
    # their full name. The other sub-models have their "<name>." prefix removed.
    text_encoder_dict = OrderedDict()
    variance_adaptor_dict = OrderedDict()
    mel_decoder_dict = OrderedDict()
    mel_postnet_dict = OrderedDict()
    for key, value in state_dict["model"].items():
        if key.startswith(("text_encoder", "emo")):
            text_encoder_dict[key] = value
        elif key.startswith("variance_adaptor"):
            variance_adaptor_dict[key[len("variance_adaptor."):]] = value
        elif key.startswith("mel_decoder"):
            mel_decoder_dict[key[len("mel_decoder."):]] = value
        else:
            # Everything else is treated as post-net weights ("mel_postnet." prefix).
            mel_postnet_dict[key[len("mel_postnet."):]] = value

    onnx_dir = "sambert_onnx"
    os.makedirs(onnx_dir, exist_ok=True)

    _load_partial(text_encoder, text_encoder_dict, "text_encoder")
    _load_partial(variance_adaptor, variance_adaptor_dict, "variance_adaptor")
    torch.save(variance_adaptor_dict, os.path.join(onnx_dir, "variance_adaptor_dict.pt"))
    _load_partial(mel_decoder, mel_decoder_dict, "mel_decoder")
    torch.save(mel_decoder_dict, os.path.join(onnx_dir, "mel_decoder_dict.pt"))
    _load_partial(mel_postnet, mel_postnet_dict, "mel_postnet")

    results_dir = os.path.join(output_dir, "feat")
    os.makedirs(results_dir, exist_ok=True)
    for module in (text_encoder, variance_adaptor, mel_decoder, mel_postnet):
        module.eval()

    # The model only needs to run once for the ONNX conversion, so read just the first line.
    with open(sentence, encoding="utf-8") as f:
        first_line = next(f, None)
    if first_line is None:
        logging.warning("No sentence found in {}".format(sentence))
        return

    start = time.perf_counter()
    fields = first_line.strip().split("\t")
    if len(fields) < 2:
        raise ValueError(
            "Expected '<id>\\t<symbol sequence>' in {}, got: {!r}".format(sentence, first_line)
        )
    utt_id, symbol_seq = fields[0], fields[1]
    logging.info("Inference sentence: {}".format(utt_id))
    mel_path = os.path.join(results_dir, "{}_mel.npy".format(utt_id))
    dur_path = os.path.join(results_dir, "{}_dur.txt".format(utt_id))
    f0_path = os.path.join(results_dir, "{}_f0.txt".format(utt_id))
    energy_path = os.path.join(results_dir, "{}_energy.txt".format(utt_id))

    with torch.no_grad():
        _, mel_post, dur, f0, energy = am_synthesis(
            symbol_seq, text_encoder, variance_adaptor, mel_decoder, mel_postnet, ling_unit, device, se=se
        )

    if nsf_enable:
        mel_post = denorm_f0(mel_post, norm_type=nsf_norm_type, f0_feature=f0_feature)

    np.save(mel_path, mel_post)
    np.savetxt(dur_path, dur)
    np.savetxt(f0_path, f0)
    np.savetxt(energy_path, energy)
    total_elapsed = time.perf_counter() - start
    print(f'AM infer single time: {total_elapsed} seconds')


if __name__ == "__main__":
    # CLI entry point: synthesise the first sentence and export the ONNX sub-models.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sentence", type=str, required=True,
        help="UTF-8 file of tab-separated '<id>\\t<symbol sequence>' lines",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="directory under which feat/ outputs are written",
    )
    parser.add_argument("--ckpt", type=str, required=True, help="acoustic model checkpoint")
    parser.add_argument(
        "--se_file", type=str, required=False,
        help="speaker-embedding .npy (used only when the config enables SE)",
    )
    parser.add_argument(
        "--config", type=str, required=False, default=None,
        help="config.yaml path; defaults to config.yaml two directories above --ckpt",
    )

    args = parser.parse_args()

    am_infer_divide(args.sentence, args.ckpt, args.output_dir, args.se_file, config=args.config)