# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import torch
from torch import nn


class RMSNorm(torch.nn.Module):
    def __init__(self,
                 dim: int,
                 eps: float = 1e-6,
                 sequence_parallel: bool = False,
                 config: dict = None):
        """RMS Normalization module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
            sequence_parallel (bool): Set to true if sequence parallelism is being used,
                this marks the weights as needing to be allreduced.
            config (dict): Unused; kept for interface compatibility.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        setattr(self.weight, 'sequence_parallel', sequence_parallel)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


from typing import Any, Callable, Dict, Optional, Tuple, Union
from functools import partial

import lightop  # provides rmsnorm_forward / rmsnorm_backward

from megatron.core.utils import is_torch_min_version

if is_torch_min_version("2.4.0a0"):
    custom_fwd = partial(torch.amp.custom_fwd, device_type="cuda")
    custom_bwd = partial(torch.amp.custom_bwd, device_type="cuda")
else:
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd


def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)


class _LightopRMSNorm(torch.autograd.Function):
    """RMSNorm implemented with lightop kernels."""

    @staticmethod
    # @custom_fwd
    def forward(ctx,
                inp: torch.Tensor,
                weight: torch.Tensor,
                ln_out: torch.Tensor,
                eps: float,
                is_grad_enabled):
        # lightop.rmsnorm_forward returns a tuple: output[0] is the normalized
        # tensor, output[1] is the reciprocal RMS (rsigma) needed for backward.
        output = lightop.rmsnorm_forward(inp, weight, ln_out, eps, training=True)
        # print_rank_0(f"_LightopRMSNorm: output({output[0].shape, output[1].shape}) = lightop.rmsfwd(inp{inp.shape}, weight{weight.shape}, ...)")
        rsigma = output[1]
        if is_grad_enabled:
            ctx.save_for_backward(inp, weight, rsigma)
        return output[0]

    @staticmethod
    # @custom_bwd
    def backward(ctx, grad_output):
        inp, weight, rsigma = ctx.saved_tensors
        dgrad, dgamma = lightop.rmsnorm_backward(grad_output, inp, rsigma, weight)
        # print_rank_0(f"_LightopRMSNorm: dgrad{dgrad.shape}, dgamma{dgamma.shape} = lightop.rmsbwd(grad_output{grad_output.shape}, inp{inp.shape}, rsigma{rsigma.shape}, weight{weight.shape})")
        # One gradient per forward input: inp, weight, ln_out, eps, is_grad_enabled.
        return dgrad, dgamma, None, None, None


class LightopRMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """RMS Normalization module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
        """
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    # @no_torch_dynamo()  # optionally disable torch._dynamo for this forward
    def forward(self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None):
        if torch.is_grad_enabled():
            fwd_fn = _LightopRMSNorm.apply
            args = []
        else:
            # Call forward() directly in no-grad mode; pass None for ctx.
            fwd_fn = _LightopRMSNorm.forward
            args = [None]
        ln_out = torch.empty_like(inp, dtype=inp.dtype, memory_format=torch.contiguous_format)
        args += (inp, self.weight, ln_out, self.eps, torch.is_grad_enabled())
        out = fwd_fn(*args)
        return out
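

# Minimal usage sketch (an illustration, not part of the module above): it assumes a
# CUDA build of `lightop` is installed and compares LightopRMSNorm against the
# torch.compile-based RMSNorm on a small bf16 tensor. The shapes, dtype, and seed
# below are arbitrary choices for the example.
if __name__ == "__main__":
    torch.manual_seed(0)
    hidden = 128
    x = torch.randn(4, 16, hidden, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    ref_norm = RMSNorm(hidden).to(device="cuda", dtype=torch.bfloat16)
    fast_norm = LightopRMSNorm(hidden).to(device="cuda", dtype=torch.bfloat16)

    ref_out = ref_norm(x)
    fast_out = fast_norm(x)

    # The two implementations should agree up to bf16 rounding error.
    print_rank_0(f"max abs diff: {(ref_out - fast_out).abs().max().item():.3e}")

    # Backward through the custom autograd Function populates both gradients.
    fast_out.sum().backward()
    print_rank_0(f"grad shapes: x={x.grad.shape}, weight={fast_norm.weight.grad.shape}")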