import torch
import torch.nn as nn

from lightop import LayerNorm
from lightop import RMSNorm


# Reference pure-PyTorch RMSNorm, kept commented out for documentation only.
# The active ``RMSNorm`` is the one imported from ``lightop`` above.
#
# class RMSNorm(nn.Module):
#     def __init__(
#         self,
#         dim: int,
#         elementwise_affine=True,
#         eps: float = 1e-6,
#         device=None,
#         dtype=None,
#     ):
#         """
#         Initialize the RMSNorm normalization layer.
#
#         Args:
#             dim (int): The dimension of the input tensor.
#             eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
#
#         Attributes:
#             eps (float): A small value added to the denominator for numerical stability.
#             weight (nn.Parameter): Learnable scaling parameter.
#
#         """
#         factory_kwargs = {"device": device, "dtype": dtype}
#         super().__init__()
#         self.eps = eps
#         if elementwise_affine:
#             self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
#
#     def _norm(self, x):
#         """
#         Apply the RMSNorm normalization to the input tensor.
#
#         Args:
#             x (torch.Tensor): The input tensor.
#
#         Returns:
#             torch.Tensor: The normalized tensor.
#
#         """
#         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
#
#     def forward(self, x):
#         """
#         Forward pass through the RMSNorm layer.
#
#         Args:
#             x (torch.Tensor): The input tensor.
#
#         Returns:
#             torch.Tensor: The output tensor after applying RMSNorm.
#
#         """
#         output = self._norm(x.float()).type_as(x)
#         if hasattr(self, "weight"):
#             output = output * self.weight
#         return output


def get_norm_layer(norm_layer: str):
    """
    Look up a normalization layer by its short name.

    The matching object is returned as-is (it is not called/instantiated
    here); the caller is responsible for constructing the layer.

    Args:
        norm_layer (str): Name of the normalization layer. Supported values
            are ``"layer"`` and ``"rms"`` (exact, case-sensitive match).

    Returns:
        The ``LayerNorm`` object imported from ``lightop`` for ``"layer"``,
        or the ``RMSNorm`` object imported from ``lightop`` for ``"rms"``.

    Raises:
        NotImplementedError: If ``norm_layer`` is any other value.
    """
    if norm_layer == "layer":
        return LayerNorm
    if norm_layer == "rms":
        return RMSNorm
    # Anything else is unsupported; fail loudly instead of returning None.
    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")