from typing import Optional
from functools import wraps

import torch
from torch import Tensor

from megatron.training import get_args
from megatron.core import mpu, InferenceParams, parallel_state
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.multi_token_prediction import (
    roll_tensor,
    MTPLossAutoScaler,
    MTPLossLoggingHelper,
)

from dcu_megatron.core.pipeline_parallel import (
    fine_grained_offloading_set_last_layer,
)


def tie_word_embeddings_state_dict_wrapper(fn):
    """Decorate ``fn`` so that it is skipped under the "dualpipev" schedule.

    Args:
        fn: The callable to guard. Its signature is not inspected; every
            positional and keyword argument is forwarded untouched.

    Returns:
        A wrapper with ``fn``'s metadata (via ``functools.wraps``) that:

        * returns ``None`` without calling ``fn`` when
          ``get_args().schedule_method == "dualpipev"``;
        * otherwise calls ``fn`` and returns whatever it returns.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # The schedule is read on every call rather than at decoration time.
        if get_args().schedule_method == "dualpipev":
            return None

        # Propagate the wrapped callable's result instead of discarding it,
        # so decorating a function never silently changes its return value.
        return fn(*args, **kwargs)

    return wrapper


def get_mtp_num_layers_to_build(config: TransformerConfig, vp_stage: Optional[int] = None) -> int:
    """Return how many MTP layers the current pipeline rank should build.

    Args:
        config: Transformer configuration; only ``config.mtp_num_layers`` is read.
        vp_stage: Virtual-pipeline stage forwarded to
            ``mpu.is_pipeline_last_stage``. It is not consulted when the
            schedule method is "dualpipev".

    Returns:
        ``config.mtp_num_layers`` (or 0 when that value is falsy, e.g. ``None``)
        on the rank selected below, and 0 on every other rank.
    """
    runtime_args = get_args()

    if runtime_args.schedule_method == "dualpipev":
        # Under dualpipev the layers are built only where the first-stage check
        # (virtual stages ignored) passes and ``dualpipev_first_chunk`` is falsy.
        builds_mtp = (
            mpu.is_pipeline_first_stage(ignore_virtual=True)
            and not runtime_args.dualpipev_first_chunk
        )
    else:
        # Every other schedule builds the layers on the last pipeline stage,
        # taking the virtual-pipeline stage into account.
        builds_mtp = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage)

    if not builds_mtp:
        return 0
    return config.mtp_num_layers or 0


class MultiTokenPredictionLayer:
    """Weight-gradient hook for a single MTP layer.

    Only ``backward_dw`` is defined here; ``self.eh_proj`` and
    ``self.transformer_layer`` are read but not assigned in this class body.
    """

    def backward_dw(self):
        """Call ``backward_dw()`` on ``eh_proj`` and then on ``transformer_layer``."""
        # Attributes are resolved one at a time so the second lookup happens
        # only after the first sub-module's call has returned.
        for submodule_name in ("eh_proj", "transformer_layer"):
            getattr(self, submodule_name).backward_dw()


class MultiTokenPredictionBlock:
    """Forward pass and weight-gradient hook for a stack of MTP layers.

    This class body defines methods only. They read ``self.layers``,
    ``self.config``, ``self.training`` and ``self.mtp_loss_scaling_factor``,
    none of which are assigned here.
    """

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        hidden_states: Tensor,
        attention_mask: Tensor,
        labels: Tensor = None,
        context: Tensor = None,
        context_mask: Tensor = None,
        rotary_pos_emb: Tensor = None,
        rotary_pos_cos: Tensor = None,
        rotary_pos_sin: Tensor = None,
        attention_bias: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        sequence_len_offset: Tensor = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        loss_mask: Optional[Tensor] = None,
        embedding=None,
        output_layer=None,
        output_weight: Optional[torch.Tensor] = None,
        compute_language_model_loss=None,
    ) -> Tensor:
        """Run every MTP layer in order and attach each layer's loss.

        On each iteration ``input_ids``, ``position_ids``, ``labels`` and
        ``loss_mask`` are rolled by -1 along their last dimension (the shift
        accumulates from layer to layer). The rolled ids are embedded and fed
        through the layer together with the previous layer's hidden states, and
        the result is projected to logits from which a masked loss is computed.
        That loss is scaled and handed to ``MTPLossAutoScaler.apply``.

        Args:
            input_ids (Tensor): Token ids passed to ``embedding`` after rolling.
            position_ids (Tensor): Position ids passed to ``embedding`` after rolling.
            hidden_states (Tensor): Hidden states for the input tokens; per the
                original documentation the shape is [s, b, h].
            attention_mask (Tensor): Forwarded to every layer; per the original
                documentation a boolean tensor of shape [1, 1, s, s].
            labels (Tensor): Required; an assertion fails when it is None.
            context, context_mask, attention_bias: Accepted but not used by
                this method.
            rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, inference_params,
                packed_seq_params, sequence_len_offset: Forwarded unchanged to
                every layer.
            extra_block_kwargs (dict): Extra keyword arguments for every layer;
                ``None`` is treated as an empty dict.
            runtime_gather_output (Optional[bool]): Forwarded to ``output_layer``.
            loss_mask (Optional[Tensor]): Multiplied into the per-layer loss.
                Defaults to ``torch.ones_like(labels)``.
            embedding: Callable invoked as
                ``embedding(input_ids=..., position_ids=...)``.
            output_layer: Callable returning a 2-tuple whose first item is the logits.
            output_weight (Optional[torch.Tensor]): Passed to ``output_layer``
                as ``weight``.
            compute_language_model_loss: Callable invoked as
                ``compute_language_model_loss(labels, logits)``.

        Returns:
            Tensor: The result of chaining ``MTPLossAutoScaler.apply`` over the
            incoming ``hidden_states`` once per layer (the incoming tensor
            itself when there are no layers). The MTP loss is not returned.
            NOTE(review): presumably the forward value equals the incoming
            hidden states and only the backward pass is affected; confirm
            against ``MTPLossAutoScaler``.
        """
        assert (
            labels is not None
        ), "labels should not be None for calculating multi token prediction loss."
        if loss_mask is None:
            # No mask supplied: weight every label position equally.
            loss_mask = torch.ones_like(labels)

        # Kept separate from ``hidden_states``, which is overwritten by each layer.
        main_hidden_states = hidden_states
        shifted_ids = input_ids
        shifted_positions = position_ids
        shifted_labels = labels
        shifted_mask = loss_mask
        final_layer_number = len(self.layers) - 1

        for layer_number, mtp_layer in enumerate(self.layers):
            if self.config.fine_grained_activation_offloading:
                # Tell the offloading helper whether this is the last MTP layer.
                fine_grained_offloading_set_last_layer(layer_number == final_layer_number)

            # Shift the token and position ids one more step for this layer.
            shifted_ids, _ = roll_tensor(shifted_ids, shifts=-1, dims=-1)
            shifted_positions, _ = roll_tensor(shifted_positions, shifts=-1, dims=-1)

            # Embed the shifted ids.
            decoder_input = embedding(input_ids=shifted_ids, position_ids=shifted_positions)

            # Norm, linear projection and transformer, all done inside the layer.
            layer_kwargs = extra_block_kwargs or {}
            hidden_states = mtp_layer(
                decoder_input=decoder_input,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                **layer_kwargs,
            )

            # Project the layer output to logits.
            layer_logits, _ = output_layer(
                hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
            )

            # Shift labels and mask the same way, then compute the masked loss.
            shifted_labels, _ = roll_tensor(shifted_labels, shifts=-1, dims=-1)
            shifted_mask, num_tokens = roll_tensor(shifted_mask, shifts=-1, dims=-1)
            layer_loss = compute_language_model_loss(shifted_labels, layer_logits)
            layer_loss = shifted_mask * layer_loss

            if self.training:
                # Log the per-token average of this layer's loss.
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(layer_loss) / num_tokens,
                    layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_tensor_and_context_parallel_group(),
                )

            # Evaluated inside the loop so nothing is divided when there are no layers.
            loss_scale = self.mtp_loss_scaling_factor / self.config.mtp_num_layers
            scaled_loss = loss_scale * layer_loss
            if not self.config.calculate_per_token_loss:
                # Without per-token loss the scaled loss is averaged over num_tokens.
                scaled_loss = scaled_loss / num_tokens
            main_hidden_states = MTPLossAutoScaler.apply(main_hidden_states, scaled_loss)

        return main_hidden_states

    def backward_dw(self):
        """Call ``backward_dw()`` on every entry of ``self.layers``, in order."""
        for mtp_layer in self.layers:
            mtp_layer.backward_dw()
