Commit 26940c4c authored by dongcl

modify rms_norm.py

parent 763941b5
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
from torch import nn
from typing import Optional

import lightop


class RMSNorm(torch.nn.Module):
    def __init__(self,
                 dim: int,
                 eps: float = 1e-6,
                 sequence_parallel: bool = False,
                 config: dict = None):
        """RMS Normalization module

        Args:
            dim (int): The width of the input, i.e. the hidden size.
            eps (float): Epsilon to use for the norm, defaults to 1e-6.
            sequence_parallel (bool): Set to True if sequence parallelism is being used;
                this marks the weights as needing to be all-reduced.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        setattr(self.weight, 'sequence_parallel', sequence_parallel)

    def _norm(self, x):
        # RMS normalization: x / sqrt(mean(x^2) + eps), taken over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, x):
        # Normalize in fp32 for numerical stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
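
For context, a minimal usage sketch of the module above. The hidden size and tensor shapes are illustrative assumptions, not part of the commit; note that the first call pays a one-time compilation cost because forward is wrapped in torch.compile(mode="max-autotune-no-cudagraphs").

# Illustrative usage only; the hidden size and shapes below are assumptions.
import torch

hidden_size = 4096                       # assumed model width, for illustration
norm = RMSNorm(dim=hidden_size, eps=1e-6)
x = torch.randn(2, 8, hidden_size)       # (batch, sequence, hidden)
y = norm(x)                              # normalized over the last dimension, same shape and dtype as x
assert y.shape == x.shape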
import torch
from typing import Any, Callable, Dict, Optional, Tuple, Union
import lightop  # rmsnorm_forward, rmsnorm_backward
from functools import partial
from megatron.core.utils import is_torch_min_version
......
from megatron.training import get_args
-from megatron.legacy.model import LayerNorm
+from megatron.legacy.model import LayerNorm, RMSNorm
-from .rms_norm import RMSNorm, LightopRMSNorm
+from .rms_norm import LightopRMSNorm

def get_norm(config):
......
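
The body of get_norm is collapsed in the diff ("......"). As a rough, hypothetical sketch of the selector pattern such a helper typically follows, using the imports shown above (the names config.normalization, config.layernorm_epsilon, config.sequence_parallel and the use_lightop toggle are assumptions, not taken from this commit):

# Hypothetical sketch only; the real get_norm body is collapsed above.
# config.normalization, config.layernorm_epsilon, config.sequence_parallel and
# the use_lightop flag are assumed names, not taken from this commit.
def get_norm_sketch(config):
    if config.normalization == "LayerNorm":
        return partial(LayerNorm,
                       eps=config.layernorm_epsilon,
                       sequence_parallel=config.sequence_parallel)
    elif config.normalization == "RMSNorm":
        # Prefer the lightop-backed module when requested, otherwise the torch.compile path.
        norm_cls = LightopRMSNorm if getattr(config, "use_lightop", False) else RMSNorm
        return partial(norm_cls,
                       eps=config.layernorm_epsilon,
                       sequence_parallel=config.sequence_parallel)
    else:
        raise Exception(f"unsupported norm type '{config.normalization}'.")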