import torch

from functools import wraps

from dcu_megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler

def forward_step_wrapper(fn):
    """Wrap a forward-step function so it also sets the MTP loss scale.

    The wrapped callable is invoked with every argument it was given. When
    ``config.mtp_num_layers`` is set (not None), the wrapper then updates the
    loss scale held by ``MTPLossAutoScaler`` before returning.

    Args:
        fn: The forward-step function to wrap. It is called with the same
            positional and keyword arguments as the wrapper and must return a
            ``(output, num_tokens)`` pair.

    Returns:
        A wrapper with the same call convention and return value as ``fn``.
    """

    @wraps(fn)
    def wrapper(
        forward_step_func,
        data_iterator,
        model,
        num_microbatches,
        input_tensor,
        forward_data_store,
        config,
        *args,
        **kwargs,
    ):
        """Call ``fn`` and then set the MTP loss scale if MTP is enabled.

        Extra positional arguments after ``config`` and all keyword arguments
        are forwarded to ``fn`` unchanged.

        Returns:
            The ``(output, num_tokens)`` pair returned by ``fn``, unmodified.
        """
        # ``*args`` forwards optional arguments that callers pass positionally
        # after ``config``; without it such calls fail with TypeError before
        # ``fn`` is ever reached.
        output, num_tokens = fn(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            *args,
            **kwargs,
        )

        # Nothing to do unless Multi-Token Prediction (MTP) is configured.
        if getattr(config, "mtp_num_layers", None) is None:
            return output, num_tokens

        # With a non-list ``input_tensor`` the output is the bare tensor;
        # otherwise it is assumed to be wrapped in a list (first element).
        output_tensor = output if not isinstance(input_tensor, list) else output[0]

        # Base loss scale of 1 on the output's device, passed through
        # ``grad_scale_func`` when the config provides one.
        ones = torch.ones(1, device=output_tensor.device)
        loss_scale = (
            config.grad_scale_func(ones)
            if config.grad_scale_func is not None
            else ones
        )

        # Unless loss is computed per token, spread the scale evenly across
        # the microbatches of this step.
        if config.calculate_per_token_loss:
            MTPLossAutoScaler.set_loss_scale(loss_scale)
        else:
            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
        return output, num_tokens

    return wrapper
    