import torch

from typing import Any, Callable, Dict, Optional, Tuple, Union
import lightop # rmsnorm_forward,rmsnorm_backward

from functools import partial
from megatron.core.utils import is_torch_min_version
# Pick the autocast-aware custom_fwd/custom_bwd decorators for the installed
# torch: releases older than 2.4.0a0 use the torch.cuda.amp variants directly,
# newer ones use torch.amp variants bound to device_type="cuda".
if not is_torch_min_version("2.4.0a0"):
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd
else:
    custom_fwd = partial(torch.amp.custom_fwd, device_type="cuda")
    custom_bwd = partial(torch.amp.custom_bwd, device_type="cuda")


class _LightopRMSNorm(torch.autograd.Function):
    """RMSNorm autograd function implemented with the lightop kernels.

    ``forward`` may also be called directly with ``ctx=None`` (outside of
    autograd). In that case ``is_grad_enabled`` must be False so that
    nothing is saved on the context.
    """

    @staticmethod
    # @custom_fwd
    def forward(ctx,
                inp: torch.Tensor,
                weight: torch.Tensor,
                ln_out: torch.Tensor,
                eps: float,
                is_grad_enabled: bool) -> torch.Tensor:
        """Normalize ``inp`` with ``lightop.rmsnorm_forward``.

        Args:
            ctx: autograd context, or None when called outside autograd.
            inp: input tensor; made contiguous before being handed to lightop.
            weight: learnable scale passed to the kernel.
            ln_out: buffer passed through to ``lightop.rmsnorm_forward``
                (presumably the destination for the normalized output -
                confirm against lightop).
            eps: epsilon forwarded to the kernel.
            is_grad_enabled: when True, save ``inp``, ``weight`` and the
                kernel's auxiliary output on ``ctx`` for ``backward``.

        Returns:
            The first element of the pair returned by
            ``lightop.rmsnorm_forward``.
        """
        # No-op when already contiguous. Guards the native kernel against
        # strided views (NOTE(review): assumes lightop requires a dense
        # layout, as ``ln_out`` is allocated contiguous by the caller).
        inp = inp.contiguous()
        # lightop returns a pair: (normalized output, rsigma). rsigma is
        # presumably the per-row reciprocal RMS; it is reused by backward.
        output = lightop.rmsnorm_forward(inp, weight, ln_out, eps, training=True)
        rsigma = output[1]
        if is_grad_enabled:
            ctx.save_for_backward(inp, weight, rsigma)
        return output[0]

    @staticmethod
    # @custom_bwd
    def backward(ctx, grad_output):
        """Compute input and weight gradients with ``lightop.rmsnorm_backward``.

        Returns:
            One gradient per ``forward`` input, in order: ``inp`` and
            ``weight`` get real gradients; ``ln_out``, ``eps`` and
            ``is_grad_enabled`` get None.
        """
        inp, weight, rsigma = ctx.saved_tensors
        # Incoming gradients can be strided views (e.g. after a transpose or
        # reshape downstream). Match the contiguous layout of the saved input
        # before calling the native kernel; no-op if already contiguous.
        grad_output = grad_output.contiguous()
        dgrad, dgamma = lightop.rmsnorm_backward(grad_output, inp, rsigma, weight)
        return dgrad, dgamma, None, None, None


class LightopRMSNorm(torch.nn.Module):
    """RMS normalization layer whose compute is delegated to ``_LightopRMSNorm``."""

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the layer.

        Args:
            dim (int): size of the normalized (last) dimension, e.g. the hidden size.
            eps (float): epsilon used by the norm; defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale, initialized to all ones.
        self.weight = torch.nn.Parameter(torch.ones(dim))

    # @no_torch_dynamo()  # i.e. torch._dynamo.disable
    def forward(self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None):
        """Apply RMS normalization to ``inp``.

        Args:
            inp: tensor to normalize.
            is_first_microbatch: accepted for interface compatibility; unused.

        Returns:
            The normalized tensor produced by ``_LightopRMSNorm``.
        """
        grad_on = torch.is_grad_enabled()
        # Contiguous buffer with the same shape and dtype as the input,
        # handed to the kernel alongside the input.
        out_buffer = torch.empty_like(inp, dtype=inp.dtype, memory_format=torch.contiguous_format)
        call_args = (inp, self.weight, out_buffer, self.eps, grad_on)
        if not grad_on:
            # Without autograd there is no context; call forward directly.
            return _LightopRMSNorm.forward(None, *call_args)
        return _LightopRMSNorm.apply(*call_args)
