from typing import Any, Optional
from functools import partial

import torch
from torch import Tensor

from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import make_viewless_tensor
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
from megatron.core.transformer.transformer_config import TransformerConfig

from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables


def get_transformer_layer_offset(config: TransformerConfig, vp_stage: Optional[int] = None) -> int:
    """Get the index offset of current pipeline stage, given the level of pipelining.

    This extends Megatron-Core's offset computation with support for the
    ``dualpipev`` schedule (``args.schedule_method == 'dualpipev'``). In that schedule
    each pipeline rank builds two model chunks, so the layers are split over
    ``2 * pipeline_model_parallel_size`` stages. The formulas below place the first
    chunk of rank ``r`` at stage ``r`` and the second chunk at stage
    ``2 * pipeline_model_parallel_size - 1 - r``. Which chunk is being built is read
    from ``args.dualpipev_first_chunk``; when that attribute is absent it is treated
    as ``True`` (first chunk).

    Args:
        config: Transformer config providing ``num_layers``, the pipeline and
            virtual-pipeline sizes, the optional uneven first/last stage layer counts
            and the ``account_for_{embedding,loss}_in_pipeline_split`` flags.
        vp_stage: Virtual pipeline stage index. Must be given when
            ``config.virtual_pipeline_model_parallel_size`` is set.

    Returns:
        Global index of the first Transformer layer owned by the current stage.
        Returns 0 when pipeline parallelism is disabled.
    """
    args = get_args()
    # Under dualpipev each rank builds two chunks; this says which one is being built.
    is_first_dualpipev_chunk = getattr(args, 'dualpipev_first_chunk', True)

    pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
    if not parallel_state.is_inside_encoder():
        pp_decoder_start = parallel_state.get_pipeline_model_parallel_decoder_start()
        if pp_decoder_start is not None:
            # Decoder ranks are numbered relative to the first decoder stage.
            pipeline_rank = pipeline_rank - pp_decoder_start

    pp_size = config.pipeline_model_parallel_size
    if pp_size <= 1:
        return 0

    is_dualpipev = args.schedule_method == 'dualpipev'

    if (
        config.num_layers_in_first_pipeline_stage is not None
        or config.num_layers_in_last_pipeline_stage is not None
    ):
        # ---- Uneven pipeline: first and/or last stage have explicit layer counts ----
        first_stage_is_explicit = config.num_layers_in_first_pipeline_stage is not None

        # Number of stages that share the remaining ("middle") layers after deducting
        # the explicitly sized first/last stages. dualpipev has two stages per rank.
        middle_pipeline_stages = pp_size * 2 if is_dualpipev else pp_size
        middle_pipeline_stages -= sum(
            x is not None
            for x in (
                config.num_layers_in_first_pipeline_stage,
                config.num_layers_in_last_pipeline_stage,
            )
        )

        num_layers_in_first_pipeline_stage = (
            0
            if config.num_layers_in_first_pipeline_stage is None
            else config.num_layers_in_first_pipeline_stage
        )
        num_layers_in_last_pipeline_stage = (
            0
            if config.num_layers_in_last_pipeline_stage is None
            else config.num_layers_in_last_pipeline_stage
        )
        middle_num_layers = (
            config.num_layers
            - num_layers_in_first_pipeline_stage
            - num_layers_in_last_pipeline_stage
        )

        if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:
            assert (
                vp_stage is not None
            ), "vp_stage must be provided if virtual pipeline model parallel size is set"

            # Layers per virtual model chunk in the first, middle and last stages
            # (an unset first/last stage contributes 0).
            first_chunk_layers = num_layers_in_first_pipeline_stage // vp_size
            middle_chunk_layers = middle_num_layers // vp_size
            last_chunk_layers = num_layers_in_last_pipeline_stage // vp_size
            total_virtual_chunks = first_chunk_layers + middle_chunk_layers + last_chunk_layers

            if pipeline_rank == 0:
                return vp_stage * total_virtual_chunks

            # Rank 0 is itself a middle stage when no explicit first stage exists.
            middle_pipeline_rank = (
                pipeline_rank - 1 if first_stage_is_explicit else pipeline_rank
            )
            # Guard mirrors the non-VPP branch: with no middle stages there is nothing
            # to distribute (avoids ZeroDivisionError, e.g. pp=2 with first+last set).
            middle_chunk_layers_per_stage = (
                middle_chunk_layers // middle_pipeline_stages
                if middle_pipeline_stages > 0
                else 0
            )
            return (
                vp_stage * total_virtual_chunks
                + first_chunk_layers
                + middle_pipeline_rank * middle_chunk_layers_per_stage
            )

        num_layers_per_pipeline_rank = (
            middle_num_layers // middle_pipeline_stages if middle_pipeline_stages > 0 else 0
        )

        if is_first_dualpipev_chunk:
            if pipeline_rank == 0:
                return 0
            stage = pipeline_rank
        else:
            # dualpipev second chunk: the stage order folds back, rank r -> 2*pp-1-r.
            stage = 2 * pp_size - 1 - pipeline_rank

        middle_pipeline_rank = stage - 1 if first_stage_is_explicit else stage
        return (
            middle_pipeline_rank * num_layers_per_pipeline_rank
            + num_layers_in_first_pipeline_stage
        )

    # ---- Even pipeline split ----
    num_layers = config.num_layers
    # The embedding and/or the loss may each count as one extra "layer" in the split.
    if config.account_for_embedding_in_pipeline_split:
        num_layers += 1
    if config.account_for_loss_in_pipeline_split:
        num_layers += 1

    num_layers_per_pipeline_rank = num_layers // pp_size
    if is_dualpipev:
        # Each rank's share is split across its two dualpipev chunks.
        num_layers_per_pipeline_rank //= 2

    if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:
        assert (
            vp_stage is not None
        ), "vp_stage must be provided if virtual pipeline model parallel size is set"

        num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
        total_virtual_chunks = num_layers // vp_size
        offset = vp_stage * total_virtual_chunks + pipeline_rank * num_layers_per_virtual_rank
    elif is_first_dualpipev_chunk:
        offset = pipeline_rank * num_layers_per_pipeline_rank
    else:
        # dualpipev second chunk: counted back from the end of the layer stack.
        offset = num_layers - (pipeline_rank + 1) * num_layers_per_pipeline_rank

    # The embedding occupies a layer slot only on the first stage, so every later
    # stage shifts down by one.
    # NOTE(review): for the dualpipev second chunk on rank 0 this relies on
    # is_pipeline_first_stage() returning False; unpatched Megatron returns True for
    # rank 0 whenever vp_stage is None — confirm dcu_megatron patches it.
    if (
        config.account_for_embedding_in_pipeline_split
        and not parallel_state.is_pipeline_first_stage(
            ignore_virtual=False, vp_stage=vp_stage
        )
    ):
        offset -= 1
    return offset


class TransformerLayer(MegatronCoreTransformerLayer):
    """Transformer layer whose forward pass can be driven one sub-module at a time.

    ``get_submodule_callables`` exposes the layer as four stages: attention, dispatch,
    mlp and combine. Each stage is a ``_submodule_*_forward`` function paired with a
    ``_submodule_*_postprocess`` function. The postprocess step receives the schedule
    node and stashes intermediate values on ``node.common_state`` for later stages.

    For MoE layers the stages are:

    - attention: self-attention, router and dispatch preprocessing
    - dispatch: the dispatcher's all-to-all
    - mlp: the experts, plus optional shared experts
    - combine: the combine all-to-all followed by bias-dropout-add with the residual

    For dense layers only attention and mlp are real stages; dispatch and combine
    raise ``NotImplementedError``.
    """

    def _submodule_attn_router_forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
        context: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
        rotary_pos_emb: Optional[Tensor] = None,
        rotary_pos_cos: Optional[Tensor] = None,
        rotary_pos_sin: Optional[Tensor] = None,
        attention_bias: Optional[Tensor] = None,
        inference_context: Optional[Any] = None,
        packed_seq_params: Optional[PackedSeqParams] = None,
        sequence_len_offset: Optional[Tensor] = None,
        *,
        inference_params: Optional[Any] = None,
    ):
        """
        Performs a combined forward pass that includes self-attention and MLP routing logic.

        Runs ``_forward_attention``, then the MoE router on the pre-MLP layernorm
        output, then the token dispatcher's local ``dispatch_preprocess``.

        Returns:
            tuple: ``(tokens_per_expert, permutated_local_input_tokens, permuted_probs,
            pre_mlp_layernorm_output, residual, context)``.
        """
        pre_mlp_layernorm_output, residual, context = self._forward_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            context=context,
            context_mask=context_mask,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            attention_bias=attention_bias,
            inference_context=inference_context,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            inference_params=inference_params,
        )

        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
        (
            tokens_per_expert,
            permutated_local_input_tokens,
            permuted_probs,
        ) = self.mlp.token_dispatcher.dispatch_preprocess(
            pre_mlp_layernorm_output, probs, routing_map
        )

        return (
            tokens_per_expert,
            permutated_local_input_tokens,
            permuted_probs,
            pre_mlp_layernorm_output,
            residual,
            context,
        )

    def _submodule_attn_router_postprocess(
        self,
        node,
        tokens_per_expert,
        permutated_local_input_tokens,
        permuted_probs,
        pre_mlp_layernorm_output,
        residual,
        context,
    ):
        """Stash state for later MoE stages and forward the tensors to dispatch.

        ``residual`` is always stored, via ``node.detach``. ``pre_mlp_layernorm_output``
        is stored the same way only when shared experts are used, because only
        ``_submodule_moe_forward`` reads it. ``context`` is unused.
        """
        node.common_state.tokens_per_expert = tokens_per_expert
        node.common_state.residual = node.detach(residual)
        if self.mlp.use_shared_expert:
            node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)

        return permutated_local_input_tokens, permuted_probs

    def _submodule_dispatch_forward(self, permutated_local_input_tokens, permuted_probs, state=None):
        """
        Dispatches tokens to the appropriate experts based on the router output.

        ``state`` is the node's ``common_state`` and must not be None. Its
        ``tokens_per_expert`` is passed to the dispatcher's ``dispatch_all_to_all``.

        Returns:
            tuple: ``(tokens_per_expert, global_input_tokens, global_probs)``.
        """
        token_dispatcher = self.mlp.token_dispatcher
        tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
            state.tokens_per_expert, permutated_local_input_tokens, permuted_probs
        )
        return tokens_per_expert, global_input_tokens, global_probs

    def _submodule_dispatch_postprocess(self, node, tokens_per_expert, global_input_tokens, global_probs):
        """Replace the stored ``tokens_per_expert`` with the post-all-to-all value."""
        node.common_state.tokens_per_expert = tokens_per_expert
        return global_input_tokens, global_probs

    def _submodule_moe_forward(self, global_input_tokens, global_probs, state=None):
        """
        Performs a forward pass for the MLP submodule, including both expert-based
        and optional shared-expert computations.

        ``state`` is the node's ``common_state`` and must not be None. Shared experts
        run here only when they are not overlapped (``shared_expert_overlap`` is False).

        Returns:
            tuple: ``(expert_output, shared_expert_output, mlp_bias)``.
            ``shared_expert_output`` is None when shared experts are not run here;
            ``mlp_bias`` is always None.
        """
        tokens_per_expert = state.tokens_per_expert
        shared_expert_output = None
        token_dispatcher = self.mlp.token_dispatcher

        dispatched_input, tokens_per_expert, permuted_probs = token_dispatcher.dispatch_postprocess(
            tokens_per_expert, global_input_tokens, global_probs
        )

        expert_output, mlp_bias = self.mlp.experts(
            dispatched_input, tokens_per_expert, permuted_probs
        )
        assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
            shared_expert_output = self.mlp.shared_experts(state.pre_mlp_layernorm_output)

        expert_output = token_dispatcher.combine_preprocess(expert_output)
        return expert_output, shared_expert_output, mlp_bias

    def _submodule_mlp_postprocess(self, node, expert_output, shared_expert_output, mlp_bias):
        """Drop the stashed layernorm output and forward expert outputs to combine.

        Returns ``expert_output`` alone when there is no shared-expert output, otherwise
        ``(expert_output, shared_expert_output)``.
        """
        assert mlp_bias is None
        # The layernorm output is only needed by the shared experts, which have run.
        node.common_state.pre_mlp_layernorm_output = None
        if shared_expert_output is None:
            return expert_output
        return expert_output, shared_expert_output

    def _submodule_combine_forward(self, expert_output, shared_expert_output=None, state=None):
        """Combine expert outputs back to local tokens and add the residual.

        ``state`` is the node's ``common_state`` and must not be None; its ``residual``
        is used for the bias-dropout-add.
        """
        residual = state.residual
        token_dispatcher = self.mlp.token_dispatcher
        permutated_local_input_tokens = token_dispatcher.combine_all_to_all(expert_output)
        output = token_dispatcher.combine_postprocess(permutated_local_input_tokens)
        if shared_expert_output is not None:
            output = output + shared_expert_output

        # The MoE path produces no bias; bda expects an (output, bias) pair.
        mlp_output_with_bias = (output, None)
        with self.bias_dropout_add_exec_handler():
            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                mlp_output_with_bias, residual, self.hidden_dropout
            )
        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )

        return output

    def _submodule_combine_postprocess(self, node, output):
        """Release the stashed residual once combine has consumed it.

        ``record_stream`` tells the CUDA caching allocator that the residual is in use
        on the current stream. Its memory is therefore not reused before the work
        queued on that stream finishes, even though the reference is dropped here.
        """
        cur_stream = torch.cuda.current_stream()
        node.common_state.residual.record_stream(cur_stream)
        node.common_state.residual = None
        return output

    def _submodule_attn_router_dw(self):
        """Run the deferred weight-gradient computation of self-attention."""
        self.self_attention.backward_dw()

    def _submodule_mlp_dw(self):
        """Run the deferred weight-gradient computation of the MLP."""
        self.mlp.backward_dw()

    def _submodule_attn_postprocess(self, node, pre_mlp_layernorm_output, residual, context):
        """Dense-layer attention postprocess: forward layernorm output and residual to mlp."""
        return pre_mlp_layernorm_output, residual

    def _submodule_dense_postprocess(self, node, hidden_states):
        """Dense-layer mlp postprocess: pass the output through unchanged."""
        return hidden_states

    def _submodule_not_implemented(self, *args, **kwargs):
        """Placeholder for stages that do not exist on dense layers.

        Accepts ``**kwargs`` because ``callable_wrapper`` always passes ``state=``.
        Without it the caller would get a ``TypeError`` instead of this error.
        """
        raise NotImplementedError("This callable is not implemented.")

    def get_submodule_callables(self, chunk_state):
        """
        The forward callables take 2 parts of inputs:
        1. The ScheduleNode object.
        2. The input tensors.

        Each forward callable runs the stage's forward function with
        ``state=node.common_state`` and passes its outputs, unpacked if they are a
        tuple, to the stage's postprocess function together with ``node``.
        Attention-time inputs (masks, rotary embeddings, ...) are read from
        ``chunk_state``.

        Also sets ``self.is_moe`` and ``self.is_deepep`` (the latter is True when the
        dispatcher is a ``MoEFlexTokenDispatcher``).
        """
        from megatron.core.transformer.moe.moe_layer import MoELayer
        from megatron.core.transformer.moe.token_dispatcher import MoEFlexTokenDispatcher

        self.is_moe = isinstance(self.mlp, MoELayer)
        self.is_deepep = False
        if self.is_moe:
            self.is_deepep = isinstance(self.mlp.token_dispatcher, MoEFlexTokenDispatcher)

        def get_func_with_default(func, default_func):
            # MoE layers use the split implementation; dense layers use the fallback.
            if self.is_moe:
                return func
            return default_func

        def callable_wrapper(forward_func, postprocess_func, node, *args):
            state = getattr(node, 'common_state', None)
            callable_outputs = forward_func(*args, state=state)
            if isinstance(callable_outputs, tuple):
                outputs = postprocess_func(node, *callable_outputs)
            else:
                outputs = postprocess_func(node, callable_outputs)
            return outputs

        attn_func = get_func_with_default(
            self._submodule_attn_router_forward, self._forward_attention
        )

        def attn_wrapper(hidden_states, state=None):
            """Call the attention forward with inputs taken from ``chunk_state``.

            ``state`` is accepted only to match the ``callable_wrapper`` calling
            convention and is ignored.
            """
            return attn_func(
                hidden_states=hidden_states,
                attention_mask=chunk_state.attention_mask,
                context=chunk_state.context,
                context_mask=chunk_state.context_mask,
                rotary_pos_emb=chunk_state.rotary_pos_emb,
                rotary_pos_cos=chunk_state.rotary_pos_cos,
                rotary_pos_sin=chunk_state.rotary_pos_sin,
                attention_bias=chunk_state.attention_bias,
                inference_context=chunk_state.inference_context,
                packed_seq_params=chunk_state.packed_seq_params,
                sequence_len_offset=chunk_state.sequence_len_offset,
                inference_params=chunk_state.inference_params,
            )

        attn_postprocess_func = get_func_with_default(
            self._submodule_attn_router_postprocess, self._submodule_attn_postprocess
        )

        dispatch_func = get_func_with_default(
            self._submodule_dispatch_forward, self._submodule_not_implemented
        )
        dispatch_postprocess_func = get_func_with_default(
            self._submodule_dispatch_postprocess, self._submodule_not_implemented
        )

        # NOTE(review): on dense layers this calls self._forward_mlp(..., state=...);
        # assumes the base class's _forward_mlp accepts a `state` keyword — confirm.
        mlp_func = get_func_with_default(self._submodule_moe_forward, self._forward_mlp)
        mlp_postprocess_func = get_func_with_default(
            self._submodule_mlp_postprocess, self._submodule_dense_postprocess
        )

        combine_func = get_func_with_default(
            self._submodule_combine_forward, self._submodule_not_implemented
        )
        combine_postprocess_func = get_func_with_default(
            self._submodule_combine_postprocess, self._submodule_not_implemented
        )

        attn_forward = partial(callable_wrapper, attn_wrapper, attn_postprocess_func)
        dispatch_forward = partial(callable_wrapper, dispatch_func, dispatch_postprocess_func)
        mlp_forward = partial(callable_wrapper, mlp_func, mlp_postprocess_func)
        combine_forward = partial(callable_wrapper, combine_func, combine_postprocess_func)

        callables = TransformerLayerSubmoduleCallables(
            attention=SubmoduleCallables(forward=attn_forward, dw=self._submodule_attn_router_dw),
            dispatch=SubmoduleCallables(forward=dispatch_forward),
            mlp=SubmoduleCallables(forward=mlp_forward, dw=self._submodule_mlp_dw),
            combine=SubmoduleCallables(forward=combine_forward),
            is_moe=self.is_moe,
            is_deepep=self.is_deepep,
        )
        return callables
