import torch
from torch import nn
import torch._dynamo
# NOTE(review): process-wide side effect of importing this module. With
# suppress_errors enabled, TorchDynamo is meant to fall back to eager execution
# instead of raising when compilation fails (per the PyTorch docs), so the
# @torch.compile-decorated code in this file still runs where compilation is
# unsupported. This also applies to every other torch.compile user in the
# process — confirm that global scope is intended.
torch._dynamo.config.suppress_errors = True

class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6, config=None):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
    @torch.compile(mode="max-autotune-no-cudagraphs")
    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


    # def forward(self, hidden_states):
    #     input_dtype = hidden_states.dtype
    #     hidden_states = hidden_states.to(torch.float32)
    #     variance = hidden_states.pow(2).mean(-1, keepdim=True)
    #     hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    #     return self.weight * hidden_states.to(input_dtype)