import torch
import torch.nn as nn
import torch.nn.functional as F

from kantts.models.sambert import FFTBlock, PNCABlock, Prenet
from kantts.models.sambert.positions import (
    SinusoidalPositionEncoder,
    DurSinusoidalPositionEncoder,
)
from kantts.models.sambert.adaptors import (
    LengthRegulator,
    VarFsmnRnnNARPredictor,
    VarRnnARPredictor,
)
from kantts.models.sambert.fsmn import FsmnEncoderV2
from kantts.models.sambert.alignment import b_mas
from kantts.models.sambert.attention import ConvAttention

from kantts.models.utils import get_mask_from_lengths


class SelfAttentionEncoder(nn.Module):
    """Stack of ``FFTBlock`` self-attention layers followed by a final LayerNorm.

    Args:
        n_layer: Number of FFT blocks. The per-block input widths are built as
            ``[d_in] + [d_model] * (n_layer - 1)``, so at least one block is always
            created.
        d_in: Feature size of the input to the first block.
        d_model: Hidden size of every block and of the output.
        n_head, d_head, d_inner: Attention / feed-forward sizes passed to ``FFTBlock``.
        dropout: Dropout applied to the position-encoded input (also passed to blocks).
        dropout_att, dropout_relu: Additional dropout rates passed to ``FFTBlock``.
        position_encoder: Must be a ``SinusoidalPositionEncoder``; any other type makes
            ``forward`` raise ``NotImplementedError``.
    """

    def __init__(
        self,
        n_layer,
        d_in,
        d_model,
        n_head,
        d_head,
        d_inner,
        dropout,
        dropout_att,
        dropout_relu,
        position_encoder,
    ):
        super(SelfAttentionEncoder, self).__init__()

        self.d_in = d_in
        self.d_model = d_model
        self.dropout = dropout

        # Only the first block consumes d_in-wide features; later ones stay at d_model.
        block_input_dims = [d_in] + [d_model] * (n_layer - 1)
        blocks = []
        for block_in in block_input_dims:
            blocks.append(
                FFTBlock(
                    block_in,
                    d_model,
                    n_head,
                    d_head,
                    d_inner,
                    (3, 1),
                    dropout,
                    dropout_att,
                    dropout_relu,
                )
            )
        self.fft = nn.ModuleList(blocks)
        self.ln = nn.LayerNorm(d_model, eps=1e-6)
        self.position_enc = position_encoder

    def forward(self, input, mask=None, return_attns=False):
        """Encode ``input`` and optionally collect per-block attention maps.

        Args:
            input: Float tensor whose dim 1 is the sequence axis
                (assumed (batch, seq, d_in) — TODO confirm).
            mask: Optional padding mask. It is expanded along a new dim 1 to form the
                self-attention mask and is also passed to every block.
            return_attns: When True, each block's attention output is collected.

        Returns:
            ``(output, attns)``: the LayerNorm-ed output of the last block and a list
            of attention maps (empty unless ``return_attns``).
        """
        # NOTE(review): this scales the caller's tensor in place. TextFftEncoder returns
        # that very tensor as ``ling_embedding``, so the scaling leaks to its callers;
        # confirm nothing depends on it before switching to an out-of-place multiply.
        input *= self.d_model ** 0.5
        if not isinstance(self.position_enc, SinusoidalPositionEncoder):
            raise NotImplementedError
        input = self.position_enc(input)
        input = F.dropout(input, p=self.dropout, training=self.training)

        seq_len = input.size(1)
        slf_attn_mask = (
            None if mask is None else mask.unsqueeze(1).expand(-1, seq_len, -1)
        )

        attn_maps = []
        hidden = input
        for block in self.fft:
            hidden, attn = block(hidden, mask=mask, slf_attn_mask=slf_attn_mask)
            if return_attns:
                attn_maps.append(attn)

        return self.ln(hidden), attn_maps


class HybridAttentionDecoder(nn.Module):
    """Mel decoder built from a prenet and a stack of ``PNCABlock`` layers.

    Each decoder input frame is passed through the prenet, concatenated with the
    matching ``memory`` frame on the feature axis, projected to ``d_model``, run
    through ``n_layer`` PNCA blocks, then LayerNorm-ed and projected to ``d_out``.

    Args:
        d_in: Feature size of the decoder input frames (fed to the prenet).
        prenet_units: Forwarded to ``Prenet``.
        n_layer: Number of PNCA blocks.
        d_model: Hidden width.
        d_mem: Last-dimension size of ``memory``.
        n_head, d_head, d_inner: Attention / feed-forward sizes for each PNCA block.
        dropout, dropout_att, dropout_relu: Dropout rates.
        d_out: Output feature size.
    """

    def __init__(
        self,
        d_in,
        prenet_units,
        n_layer,
        d_model,
        d_mem,
        n_head,
        d_head,
        d_inner,
        dropout,
        dropout_att,
        dropout_relu,
        d_out,
    ):
        super(HybridAttentionDecoder, self).__init__()

        self.d_model = d_model
        self.dropout = dropout
        self.prenet = Prenet(d_in, prenet_units, d_model)
        # Input to the blocks is [memory, prenet(frame)] concatenated on the last axis.
        self.dec_in_proj = nn.Linear(d_model + d_mem, d_model)
        self.pnca = nn.ModuleList(
            [
                PNCABlock(
                    d_model,
                    d_mem,
                    n_head,
                    d_head,
                    d_inner,
                    (1, 1),
                    dropout,
                    dropout_att,
                    dropout_relu,
                )
                for _ in range(n_layer)
            ]
        )
        self.ln = nn.LayerNorm(d_model, eps=1e-6)
        self.dec_out_proj = nn.Linear(d_model, d_out)

    def reset_state(self):
        """Reset the per-layer state of every PNCA block."""
        for layer in self.pnca:
            layer.reset_state()

    def get_pnca_attn_mask(
        self, device, max_len, x_band_width, h_band_width, masks=None
    ):
        """Build boolean band masks for the PNCA layers (True = excluded position).

        * x-mask: query ``i`` may see decoder positions
          ``max(i - x_band_width, 0) .. i``.
        * h-mask: query ``i`` may see memory positions ``i .. i + h_band_width``.

        Without ``masks`` both band masks have shape (1, max_len, max_len). With
        ``masks`` (True marks padding), padded keys are also excluded. Rows of padded
        queries are fully re-opened, presumably so that no softmax row is entirely
        masked.

        Returns:
            ``(pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask)``. The first is
            ``masks`` expanded to (B, max_len, max_len), or None.
        """
        if masks is not None:
            pnca_attn_mask = masks.unsqueeze(1).expand(-1, max_len, -1)
        else:
            pnca_attn_mask = None

        range_ = torch.arange(max_len, device=device)
        x_start = torch.clamp_min(range_ - x_band_width, 0)[None, None, :]
        x_end = (range_ + 1)[None, None, :]
        h_start = range_[None, None, :]
        h_end = torch.clamp_max(range_ + h_band_width + 1, max_len + 1)[None, None, :]

        pnca_x_attn_mask = ~(
            (x_start <= range_[None, :, None]) & (x_end > range_[None, :, None])
        ).transpose(1, 2)
        pnca_h_attn_mask = ~(
            (h_start <= range_[None, :, None]) & (h_end > range_[None, :, None])
        ).transpose(1, 2)

        if pnca_attn_mask is not None:
            pnca_x_attn_mask = pnca_x_attn_mask | pnca_attn_mask
            pnca_h_attn_mask = pnca_h_attn_mask | pnca_attn_mask
            pnca_x_attn_mask = pnca_x_attn_mask.masked_fill(
                pnca_attn_mask.transpose(1, 2), False
            )
            pnca_h_attn_mask = pnca_h_attn_mask.masked_fill(
                pnca_attn_mask.transpose(1, 2), False
            )

        return pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask

    def forward(
        self, input, memory, x_band_width, h_band_width, masks=None, return_attns=False
    ):
        """Teacher-forced decoding over the whole sequence.

        ``reset_state`` must be called before this.

        Returns:
            ``(dec_output, attn_x_list, attn_h_list)``. The attention lists are empty
            unless ``return_attns``.
        """
        input = self.prenet(input)
        input = torch.cat([memory, input], dim=-1)
        input = self.dec_in_proj(input)

        if masks is not None:
            input = input.masked_fill(masks.unsqueeze(-1), 0)

        input *= self.d_model ** 0.5
        input = F.dropout(input, p=self.dropout, training=self.training)

        max_len = input.size(1)
        _, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask(
            input.device, max_len, x_band_width, h_band_width, masks
        )

        dec_pnca_attn_x_list = []
        dec_pnca_attn_h_list = []
        dec_output = input
        for layer in self.pnca:
            # The padding mask is passed positionally (the slot after memory) so this
            # call and the step-wise call in ``infer`` use the same PNCABlock argument.
            dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer(
                dec_output,
                memory,
                masks,
                pnca_x_attn_mask=pnca_x_attn_mask,
                pnca_h_attn_mask=pnca_h_attn_mask,
            )
            if return_attns:
                dec_pnca_attn_x_list.append(dec_pnca_attn_x)
                dec_pnca_attn_h_list.append(dec_pnca_attn_h)

        dec_output = self.ln(dec_output)
        dec_output = self.dec_out_proj(dec_output)

        return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list

    def infer(
        self,
        step,
        input,
        memory,
        x_band_width,
        h_band_width,
        masks=None,
        return_attns=False,
    ):
        """Decode a single step ``step`` given the previous frame ``input``.

        ``reset_state`` must be called before the call with ``step == 0``.

        Returns:
            ``(dec_output, attn_x_list, attn_h_list)`` for this step. The attention
            lists are empty unless ``return_attns``.
        """
        max_len = memory.size(1)

        input = self.prenet(input)
        input = torch.cat([memory[:, step : step + 1, :], input], dim=-1)
        input = self.dec_in_proj(input)

        input *= self.d_model ** 0.5
        input = F.dropout(input, p=self.dropout, training=self.training)

        _, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask(
            input.device, max_len, x_band_width, h_band_width, masks
        )

        # Slices for this step are the same for every layer, so compute them once.
        # The x-mask keeps only the step + 1 decoder positions produced so far
        # (presumably matching the per-layer cached state).
        mask_step = masks[:, step : step + 1] if masks is not None else None
        x_mask_step = pnca_x_attn_mask[:, step : step + 1, : (step + 1)]
        h_mask_step = pnca_h_attn_mask[:, step : step + 1, :]

        dec_pnca_attn_x_list = []
        dec_pnca_attn_h_list = []
        dec_output = input
        for layer in self.pnca:
            dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer(
                dec_output,
                memory,
                mask_step,
                pnca_x_attn_mask=x_mask_step,
                pnca_h_attn_mask=h_mask_step,
            )
            if return_attns:
                dec_pnca_attn_x_list.append(dec_pnca_attn_x)
                dec_pnca_attn_h_list.append(dec_pnca_attn_h)

        dec_output = self.ln(dec_output)
        dec_output = self.dec_out_proj(dec_output)

        return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list


class TextFftEncoder(nn.Module):
    """Linguistic-feature encoder: embedding lookup -> SelfAttentionEncoder -> projection.

    The input layout is selected by ``config["using_byte"]``:

    * byte mode: ``inputs_ling[:, :, 0]`` holds byte indices looked up in a single
      table of ``config["byte_index"]`` entries;
    * default mode: ``inputs_ling[:, :, 0..3]`` hold symbol, tone, syllable-flag and
      word-segment indices, whose four embeddings are summed.
    """

    def __init__(self, config):
        super(TextFftEncoder, self).__init__()

        emb_dim = config["embedding_dim"]
        self.using_byte = False
        if config.get("using_byte", False):
            self.using_byte = True
            self.byte_index_emb = nn.Embedding(config["byte_index"], emb_dim)
        else:
            # Read every vocabulary size before building any lookup table.
            n_sy, n_tone, n_flag, n_ws = (
                config[key] for key in ("sy", "tone", "syllable_flag", "word_segment")
            )
            self.sy_emb = nn.Embedding(n_sy, emb_dim)
            self.tone_emb = nn.Embedding(n_tone, emb_dim)
            self.syllable_flag_emb = nn.Embedding(n_flag, emb_dim)
            self.ws_emb = nn.Embedding(n_ws, emb_dim)

        max_len = config["max_len"]
        n_layers = config["encoder_num_layers"]
        n_heads = config["encoder_num_heads"]
        d_model = config["encoder_num_units"]
        d_head = d_model // n_heads  # per-head width
        d_inner = config["encoder_ffn_inner_dim"]
        dropout = config["encoder_dropout"]
        dropout_attn = config["encoder_attention_dropout"]
        dropout_relu = config["encoder_relu_dropout"]
        d_proj = config["encoder_projection_units"]

        self.d_model = d_model

        self.ling_enc = SelfAttentionEncoder(
            n_layers,
            emb_dim,
            d_model,
            n_heads,
            d_head,
            d_inner,
            dropout,
            dropout_attn,
            dropout_relu,
            SinusoidalPositionEncoder(max_len, emb_dim),
        )
        self.ling_proj = nn.Linear(d_model, d_proj, bias=False)

    def forward(self, inputs_ling, masks=None, return_attns=False):
        """Embed and encode ``inputs_ling``.

        Returns:
            ``(enc_output, attns, ling_embedding)``: the projected encoder output, the
            attention list from ``SelfAttentionEncoder`` and the input embedding.
        """
        if self.using_byte:
            ling_embedding = self.byte_index_emb(inputs_ling[:, :, 0])
        else:
            sy_ids, tone_ids, flag_ids, ws_ids = (inputs_ling[:, :, i] for i in range(4))
            ling_embedding = (
                self.sy_emb(sy_ids)
                + self.tone_emb(tone_ids)
                + self.syllable_flag_emb(flag_ids)
                + self.ws_emb(ws_ids)
            )

        # NOTE(review): SelfAttentionEncoder.forward scales its input in place
        # (``input *= ...``), so the ling_embedding returned below carries that scaling.
        enc_output, attns = self.ling_enc(ling_embedding, masks, return_attns)

        if hasattr(self, "ling_proj"):
            enc_output = self.ling_proj(enc_output)

        return enc_output, attns, ling_embedding


class TextEncoder(nn.Module):
    """Encode linguistic inputs and look up emotion / speaker conditioning.

    Optional features, selected by config flags:

    * ``"SE"``: speaker input is used as-is (no speaker lookup table is created);
    * ``"MAS"``: builds a ``ConvAttention`` aligner as ``align_attention``;
    * ``"FP"``: builds an ``FP_Predictor`` as ``FP_predictor``.
    """

    def __init__(self, config):
        super(TextEncoder, self).__init__()
        self.text_encoder = TextFftEncoder(config)

        self.se_enable = config.get("SE", False)
        if not self.se_enable:
            self.spk_tokenizer = nn.Embedding(config["speaker"], config["speaker_units"])
        self.emo_tokenizer = nn.Embedding(config["emotion"], config["emotion_units"])

        self.MAS = bool(config.get("MAS", False))
        if self.MAS:
            self.align_attention = ConvAttention(
                n_mel_channels=config["num_mels"],
                n_text_channels=config["embedding_dim"],
                n_att_channels=config["num_mels"],
            )

        self.fp_enable = config.get("FP", False)
        if self.fp_enable:
            self.FP_predictor = FP_Predictor(config)

    def forward(self, inputs_ling, inputs_emotion, inputs_speaker, inputs_ling_masks=None, return_attns=False):
        """Run the text encoder and the conditioning lookups.

        Returns:
            ``(text_hid, ling_embedding, emo_hid, spk_hid)``, or with ``return_attns``
            ``(text_hid, attns, ling_embedding, emo_hid, spk_hid)``.
        """
        text_hid, attns, ling_embedding = self.text_encoder(
            inputs_ling, inputs_ling_masks, return_attns
        )
        emo_hid = self.emo_tokenizer(inputs_emotion)
        if self.se_enable:
            # Speaker embedding was supplied directly by the caller.
            spk_hid = inputs_speaker
        else:
            spk_hid = self.spk_tokenizer(inputs_speaker)

        if not return_attns:
            return text_hid, ling_embedding, emo_hid, spk_hid
        return text_hid, attns, ling_embedding, emo_hid, spk_hid


class VarianceAdaptor(nn.Module):
    """Predict pitch, energy and duration, then length-regulate the conditioning.

    Pitch and energy are predicted per token and embedded with a 1-D convolution. The
    embeddings are added to the text embedding; ground-truth contours are used instead
    of predictions when provided. Durations are predicted in ``log(1 + d)`` space,
    conditioned on the previous token's duration (teacher-forced during training).
    The text/emotion/speaker streams are then expanded to frame rate by the
    ``LengthRegulator``, using ground-truth durations when given and predictions
    otherwise.
    """

    def __init__(self, config):
        super(VarianceAdaptor, self).__init__()

        input_dim = (
            config["encoder_projection_units"]
            + config["emotion_units"]
            + config["speaker_units"]
        )
        filter_size = config["predictor_filter_size"]
        fsmn_num_layers = config["predictor_fsmn_num_layers"]
        num_memory_units = config["predictor_num_memory_units"]
        ffn_inner_dim = config["predictor_ffn_inner_dim"]
        dropout = config["predictor_dropout"]
        shift = config["predictor_shift"]
        lstm_units = config["predictor_lstm_units"]

        dur_pred_prenet_units = config["dur_pred_prenet_units"]
        dur_pred_lstm_units = config["dur_pred_lstm_units"]

        self.pitch_predictor = VarFsmnRnnNARPredictor(
            input_dim,
            filter_size,
            fsmn_num_layers,
            num_memory_units,
            ffn_inner_dim,
            dropout,
            shift,
            lstm_units,
        )
        self.energy_predictor = VarFsmnRnnNARPredictor(
            input_dim,
            filter_size,
            fsmn_num_layers,
            num_memory_units,
            ffn_inner_dim,
            dropout,
            shift,
            lstm_units,
        )
        self.duration_predictor = VarRnnARPredictor(
            input_dim, dur_pred_prenet_units, dur_pred_lstm_units
        )

        self.length_regulator = LengthRegulator(config["outputs_per_step"])
        self.dur_position_encoder = DurSinusoidalPositionEncoder(
            config["encoder_projection_units"], config["outputs_per_step"]
        )

        self.pitch_emb = nn.Conv1d(
            1, config["encoder_projection_units"], kernel_size=9, padding=4
        )
        self.energy_emb = nn.Conv1d(
            1, config["encoder_projection_units"], kernel_size=9, padding=4
        )

    @staticmethod
    def _embed_contour(conv, values):
        """Embed a per-token scalar contour (B, T) with ``conv`` into (B, T, C)."""
        return conv(values.unsqueeze(1)).transpose(1, 2)

    def forward(
        self,
        inputs_text_embedding,
        inputs_emo_embedding,
        inputs_spk_embedding,  # [1,20,192]
        masks=None,
        output_masks=None,
        duration_targets=None,
        pitch_targets=None,
        energy_targets=None,
    ):
        """Run the variance predictors and expand the inputs to frame rate.

        The three embeddings are concatenated on the last axis, so they must share
        their leading (batch, token) dims.

        Returns:
            ``(LR_text_outputs, LR_emo_outputs, LR_spk_outputs, LR_length_rounded,
            log_duration_predictions, pitch_predictions, energy_predictions)``.
        """
        batch_size = inputs_text_embedding.size(0)

        variance_predictor_inputs = torch.cat(
            [inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding], dim=-1
        )
        pitch_predictions = self.pitch_predictor(variance_predictor_inputs, masks)
        energy_predictions = self.energy_predictor(variance_predictor_inputs, masks)

        # Use ground-truth contours for the embeddings whenever they are provided.
        pitch_embeddings = self._embed_contour(
            self.pitch_emb,
            pitch_targets if pitch_targets is not None else pitch_predictions,
        )
        energy_embeddings = self._embed_contour(
            self.energy_emb,
            energy_targets if energy_targets is not None else energy_predictions,
        )

        inputs_text_embedding_aug = (
            inputs_text_embedding + pitch_embeddings + energy_embeddings
        )
        duration_predictor_cond = torch.cat(
            [inputs_text_embedding_aug, inputs_spk_embedding, inputs_emo_embedding],
            dim=-1,
        )

        if duration_targets is not None:
            # Teacher forcing: token t sees log(1 + d_{t-1}), with a zero go frame.
            duration_predictor_go_frame = torch.zeros(
                batch_size, 1, device=inputs_text_embedding.device
            )
            duration_predictor_input = torch.cat(
                [duration_predictor_go_frame, duration_targets[:, :-1].float()], dim=-1
            )
            duration_predictor_input = torch.log(duration_predictor_input + 1)
            log_duration_predictions, _ = self.duration_predictor(
                duration_predictor_input.unsqueeze(-1),
                duration_predictor_cond,
                masks=masks,
            )
            # Expansion uses the ground truth; the predictions are only returned.
            durations = duration_targets
        else:
            log_duration_predictions = self.duration_predictor.infer(
                duration_predictor_cond, masks=masks
            )
            durations = torch.exp(log_duration_predictions) - 1

        LR_text_outputs, LR_length_rounded = self.length_regulator(
            inputs_text_embedding_aug, durations, masks=output_masks
        )
        LR_position_embeddings = self.dur_position_encoder(
            durations, masks=output_masks
        )
        LR_emo_outputs, _ = self.length_regulator(
            inputs_emo_embedding, durations, masks=output_masks
        )
        LR_spk_outputs, _ = self.length_regulator(
            inputs_spk_embedding, durations, masks=output_masks
        )

        LR_text_outputs = LR_text_outputs + LR_position_embeddings

        return (
            LR_text_outputs,
            LR_emo_outputs,
            LR_spk_outputs,  # [1,153,192]
            LR_length_rounded,
            log_duration_predictions,
            pitch_predictions,
            energy_predictions,
        )


class VarianceAdaptor2(nn.Module):
    """Variance adaptor with a global duration ``scale`` (speaking-rate control).

    Pitch and energy are predicted per token and embedded with a 1-D convolution. The
    embeddings are added to the text embedding; ground-truth contours are used instead
    of predictions when provided. Durations are predicted in ``log(1 + d)`` space. The
    durations used for expansion (ground truth if given, else predicted) are
    multiplied by ``scale`` before length regulation and positional encoding.
    """

    def __init__(self, config):
        super(VarianceAdaptor2, self).__init__()

        input_dim = (
            config["encoder_projection_units"]
            + config["emotion_units"]
            + config["speaker_units"]
        )
        filter_size = config["predictor_filter_size"]
        fsmn_num_layers = config["predictor_fsmn_num_layers"]
        num_memory_units = config["predictor_num_memory_units"]
        ffn_inner_dim = config["predictor_ffn_inner_dim"]
        dropout = config["predictor_dropout"]
        shift = config["predictor_shift"]
        lstm_units = config["predictor_lstm_units"]

        dur_pred_prenet_units = config["dur_pred_prenet_units"]
        dur_pred_lstm_units = config["dur_pred_lstm_units"]

        self.pitch_predictor = VarFsmnRnnNARPredictor(
            input_dim,
            filter_size,
            fsmn_num_layers,
            num_memory_units,
            ffn_inner_dim,
            dropout,
            shift,
            lstm_units,
        )
        self.energy_predictor = VarFsmnRnnNARPredictor(
            input_dim,
            filter_size,
            fsmn_num_layers,
            num_memory_units,
            ffn_inner_dim,
            dropout,
            shift,
            lstm_units,
        )
        self.duration_predictor = VarRnnARPredictor(
            input_dim, dur_pred_prenet_units, dur_pred_lstm_units
        )

        self.length_regulator = LengthRegulator(config["outputs_per_step"])
        self.dur_position_encoder = DurSinusoidalPositionEncoder(
            config["encoder_projection_units"], config["outputs_per_step"]
        )

        self.pitch_emb = nn.Conv1d(
            1, config["encoder_projection_units"], kernel_size=9, padding=4
        )
        self.energy_emb = nn.Conv1d(
            1, config["encoder_projection_units"], kernel_size=9, padding=4
        )

    @staticmethod
    def _embed_contour(conv, values):
        """Embed a per-token scalar contour (B, T) with ``conv`` into (B, T, C)."""
        return conv(values.unsqueeze(1)).transpose(1, 2)

    def forward(
        self,
        inputs_text_embedding,
        inputs_emo_embedding,
        inputs_spk_embedding,  # [1,20,192]
        scale=1.0,
        masks=None,
        output_masks=None,
        duration_targets=None,
        pitch_targets=None,
        energy_targets=None,
    ):
        """Run the variance predictors and expand the inputs to frame rate.

        Args:
            scale: Multiplier applied to the durations used for expansion. Values > 1
                produce more output frames per token. It does not affect the returned
                ``log_duration_predictions``.

        Returns:
            ``(LR_text_outputs, LR_emo_outputs, LR_spk_outputs, LR_length_rounded,
            log_duration_predictions, pitch_predictions, energy_predictions)``.
        """
        batch_size = inputs_text_embedding.size(0)

        variance_predictor_inputs = torch.cat(
            [inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding], dim=-1
        )
        pitch_predictions = self.pitch_predictor(variance_predictor_inputs, masks)
        energy_predictions = self.energy_predictor(variance_predictor_inputs, masks)

        # Use ground-truth contours for the embeddings whenever they are provided.
        pitch_embeddings = self._embed_contour(
            self.pitch_emb,
            pitch_targets if pitch_targets is not None else pitch_predictions,
        )
        energy_embeddings = self._embed_contour(
            self.energy_emb,
            energy_targets if energy_targets is not None else energy_predictions,
        )

        inputs_text_embedding_aug = (
            inputs_text_embedding + pitch_embeddings + energy_embeddings
        )
        duration_predictor_cond = torch.cat(
            [inputs_text_embedding_aug, inputs_spk_embedding, inputs_emo_embedding],
            dim=-1,
        )

        if duration_targets is not None:
            # Teacher forcing: token t sees log(1 + d_{t-1}), with a zero go frame.
            duration_predictor_go_frame = torch.zeros(
                batch_size, 1, device=inputs_text_embedding.device
            )
            duration_predictor_input = torch.cat(
                [duration_predictor_go_frame, duration_targets[:, :-1].float()], dim=-1
            )
            duration_predictor_input = torch.log(duration_predictor_input + 1)
            log_duration_predictions, _ = self.duration_predictor(
                duration_predictor_input.unsqueeze(-1),
                duration_predictor_cond,
                masks=masks,
            )
            durations = duration_targets * scale
        else:
            log_duration_predictions = self.duration_predictor.infer(
                duration_predictor_cond, masks=masks
            )
            durations = (torch.exp(log_duration_predictions) - 1) * scale

        # The positional encodings use the same scaled durations as the length
        # regulator, so both outputs have the same number of frames. Previously the
        # teacher-forced path encoded unscaled targets, which misaligned for scale != 1.
        LR_text_outputs, LR_length_rounded = self.length_regulator(
            inputs_text_embedding_aug, durations, masks=output_masks
        )
        LR_position_embeddings = self.dur_position_encoder(
            durations, masks=output_masks
        )
        LR_emo_outputs, _ = self.length_regulator(
            inputs_emo_embedding, durations, masks=output_masks
        )
        LR_spk_outputs, _ = self.length_regulator(
            inputs_spk_embedding, durations, masks=output_masks
        )

        LR_text_outputs = LR_text_outputs + LR_position_embeddings

        return (
            LR_text_outputs,
            LR_emo_outputs,
            LR_spk_outputs,  # [1,153,192]
            LR_length_rounded,
            log_duration_predictions,
            pitch_predictions,
            energy_predictions,
        )


class MelPNCADecoder(nn.Module):
    """Mel decoder wrapping ``HybridAttentionDecoder``.

    Each decoder step emits ``outputs_per_step`` (r) mel frames at once
    (``d_mel * r`` features). During free-running decoding, the last ``d_mel``
    features of a step's output are fed back as the next step's input.
    """

    def __init__(self, config):
        super(MelPNCADecoder, self).__init__()

        prenet_units = config["decoder_prenet_units"]
        n_layers = config["decoder_num_layers"]
        n_heads = config["decoder_num_heads"]
        d_model = config["decoder_num_units"]
        d_head = d_model // n_heads
        d_inner = config["decoder_ffn_inner_dim"]
        dropout = config["decoder_dropout"]
        dropout_attn = config["decoder_attention_dropout"]
        dropout_relu = config["decoder_relu_dropout"]
        r = config["outputs_per_step"]

        # Width of the ``memory`` tensor consumed at every step.
        d_mem = (
            config["encoder_projection_units"] * r
            + config["emotion_units"]
            + config["speaker_units"]
        )
        d_mel = config["num_mels"]

        self.d_mel = d_mel
        self.r = r
        self.nb_layers = n_layers

        self.mel_dec = HybridAttentionDecoder(
            d_mel,
            prenet_units,
            n_layers,
            d_model,
            d_mem,
            n_heads,
            d_head,
            d_inner,
            dropout,
            dropout_attn,
            dropout_relu,
            d_mel * r,
        )

    def _decode_free_running(
        self, go_frame, memory, x_band_width, h_band_width, masks, return_attns
    ):
        """Decode one step at a time, feeding back each step's last ``d_mel`` features."""
        n_steps = memory.size(1)
        outputs = []
        attn_x = [[] for _ in range(self.nb_layers)]
        attn_h = [[] for _ in range(self.nb_layers)]
        self.mel_dec.reset_state()

        frame = go_frame
        for step in range(n_steps):
            step_out, step_attn_x, step_attn_h = self.mel_dec.infer(
                step,
                frame,
                memory,
                x_band_width,
                h_band_width,
                masks=masks,
                return_attns=return_attns,
            )
            frame = step_out[:, :, -self.d_mel :]
            outputs.append(step_out)

            for layer_id, (x_att, h_att) in enumerate(zip(step_attn_x, step_attn_h)):
                # Step t only attends over t + 1 decoder positions; right-pad with zeros
                # to the memory length so every step's map has the same width.
                missing = n_steps - x_att.size(-1)
                if missing > 0:
                    zeros = torch.zeros((x_att.size(0), 1, missing)).to(x_att)
                    x_att = torch.cat([x_att, zeros], dim=-1)
                attn_x[layer_id].append(x_att)
                attn_h[layer_id].append(h_att)

        dec_output = torch.cat(outputs, dim=1)
        if return_attns:
            attn_x = [torch.cat(per_layer, dim=1) for per_layer in attn_x]
            attn_h = [torch.cat(per_layer, dim=1) for per_layer in attn_h]
        return dec_output, attn_x, attn_h

    def forward(
        self,
        memory,
        x_band_width,
        h_band_width,
        target=None,
        masks=None,
        return_attns=False,
    ):
        """Decode mel frames from ``memory``.

        With ``target``, runs teacher-forced over all steps in a single call. Without
        it, decodes step by step.

        Returns:
            ``dec_output``, or ``(dec_output, attn_x, attn_h)`` when ``return_attns``
            (per-layer attention maps).
        """
        n_batch = memory.size(0)
        go_frame = torch.zeros((n_batch, 1, self.d_mel)).to(memory.device)

        if target is None:
            dec_output, attn_x, attn_h = self._decode_free_running(
                go_frame, memory, x_band_width, h_band_width, masks, return_attns
            )
        else:
            self.mel_dec.reset_state()
            # Take the last frame of every group of r targets, prepend the go frame and
            # drop the final one, so step t is conditioned on step t-1's frame.
            last_frames = target[:, self.r - 1 :: self.r, :]
            dec_input = torch.cat([go_frame, last_frames], dim=1)[:, :-1, :]
            dec_output, attn_x, attn_h = self.mel_dec(
                dec_input,
                memory,
                x_band_width,
                h_band_width,
                masks=masks,
                return_attns=return_attns,
            )

        if return_attns:
            return dec_output, attn_x, attn_h
        return dec_output


class PostNet(nn.Module):
    """Mel post-net: FSMN encoder -> single-layer unidirectional LSTM -> linear.

    The output has ``num_mels`` features per frame; the name of the local in the
    original code (``mel_residual_output``) suggests it is added to the decoder
    output by the caller — TODO confirm.
    """

    def __init__(self, config):
        super(PostNet, self).__init__()

        # (attribute name, config key) pairs, stored as plain attributes.
        hparams = (
            ("filter_size", "postnet_filter_size"),
            ("fsmn_num_layers", "postnet_fsmn_num_layers"),
            ("num_memory_units", "postnet_num_memory_units"),
            ("ffn_inner_dim", "postnet_ffn_inner_dim"),
            ("dropout", "postnet_dropout"),
            ("shift", "postnet_shift"),
            ("lstm_units", "postnet_lstm_units"),
            ("num_mels", "num_mels"),
        )
        for attr, key in hparams:
            setattr(self, attr, config[key])

        self.fsmn = FsmnEncoderV2(
            self.filter_size,
            self.fsmn_num_layers,
            self.num_mels,
            self.num_memory_units,
            self.ffn_inner_dim,
            self.dropout,
            self.shift,
        )
        self.lstm = nn.LSTM(
            self.num_memory_units, self.lstm_units, num_layers=1, batch_first=True
        )
        self.fc = nn.Linear(self.lstm_units, self.num_mels)

    def forward(self, x, mask=None):
        """Map mel frames ``x`` (with optional padding ``mask``) to ``num_mels`` outputs."""
        hidden = self.fsmn(x, mask)
        # Padded frames are not packed out: the LSTM is unidirectional, so padding at
        # the tail cannot influence outputs at earlier (valid) positions.
        hidden, _ = self.lstm(hidden)
        return self.fc(hidden)


class FP_Predictor(nn.Module):
    """Per-frame 4-way classifier over encoder features.

    Architecture: Conv1d(k=3) -> ReLU -> LayerNorm -> Dropout ->
    Conv1d(k=1) -> ReLU -> LayerNorm -> Dropout -> Linear(4) -> softmax.
    The input is expected channels-last, (batch, time, encoder_projection_units),
    given the transposes around each Conv1d.
    """

    def __init__(self, config):
        super(FP_Predictor, self).__init__()

        d_proj = config["encoder_projection_units"]
        d_hidden = config["embedding_dim"] // 2

        self.w_1 = nn.Conv1d(d_proj, d_hidden, kernel_size=3, padding=1)
        self.w_2 = nn.Conv1d(d_hidden, d_proj, kernel_size=1, padding=0)
        self.layer_norm1 = nn.LayerNorm(d_hidden, eps=1e-6)
        self.layer_norm2 = nn.LayerNorm(d_proj, eps=1e-6)
        self.dropout_inner = nn.Dropout(0.1)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(d_proj, 4)

    def forward(self, x):
        """Return class probabilities of shape (batch, time, 4)."""
        # Conv1d wants (batch, channels, time); LayerNorm/Linear want channels last.
        hidden = F.relu(self.w_1(x.transpose(1, 2))).transpose(1, 2)
        hidden = self.dropout_inner(self.layer_norm1(hidden))
        hidden = F.relu(self.w_2(hidden.transpose(1, 2))).transpose(1, 2)
        hidden = self.dropout(self.layer_norm2(hidden))
        return F.softmax(self.fc(hidden), dim=2)