# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import torch
from torch import nn

import torch._dynamo

torch._dynamo.config.suppress_errors = True


class RMSNorm(torch.nn.Module):

    def __init__(self,
                 dim: int,
                 eps: float = 1e-6,
                 sequence_parallel: bool = False,
                 config: dict = None):
        """RMS Normalization module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
            sequence_parallel (bool): Set to true if sequence parallelism is being used,
              this marks the weights as needing to be allreduced.
            config (dict): Accepted for interface compatibility; not used by this module.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

        setattr(self.weight, 'sequence_parallel', sequence_parallel)

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def _norm(self, x):
        # Scale by the reciprocal root-mean-square over the last (hidden) dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, x):
        # Compute the norm in fp32 for numerical stability, then cast back to the
        # input dtype before applying the learned scale.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
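

# Minimal usage sketch (illustrative, not part of the original module). The
# batch size, sequence length, and hidden size below are assumptions chosen
# for the example; the module only requires that the last dimension equal `dim`.
if __name__ == "__main__":
    hidden_size = 4096
    norm = RMSNorm(dim=hidden_size)
    x = torch.randn(2, 8, hidden_size)  # (batch, sequence, hidden)
    y = norm(x)
    assert y.shape == x.shape
    assert y.dtype == x.dtype  # output is cast back to the input dtype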