Commit 1e498ef0 authored by wangxj's avatar wangxj

Modify lightop's RMSNorm

parent c1200c81
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
from torch import nn
class RMSNorm(torch.nn.Module):

    def __init__(self,
                 dim: int,
                 eps: float = 1e-6,
                 sequence_parallel: bool = False,
                 config: dict = None):
        """RMS Normalization module

        Args:
            dim (int): The width of the input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
            sequence_parallel (bool): Set to true if sequence parallelism is being used;
                this marks the weights as needing to be allreduced.
            config (dict): accepted for interface compatibility; not used by this module.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        setattr(self.weight, 'sequence_parallel', sequence_parallel)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
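For reference, a minimal usage sketch of the module above (hidden size and batch shape are illustrative; because forward is decorated with torch.compile, the first call triggers compilation and requires a PyTorch build with compile support):

import torch

norm = RMSNorm(dim=1024, eps=1e-6)
x = torch.randn(2, 8, 1024)
y = norm(x)  # normalized to unit RMS over the last dimension, then scaled by weight
assert y.shape == x.shape and y.dtype == x.dtype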
 import torch
 from typing import Any, Callable, Dict, Optional, Tuple, Union
......
 from megatron.training import get_args
 from megatron.legacy.model import LayerNorm
-from .rms_norm import LightopRMSNorm
+from .rms_norm import RMSNorm, LightopRMSNorm

 def get_norm(config):
@@ -16,6 +16,10 @@ def get_norm(config):
         if args.apply_layernorm_1p:
             raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.')
+        return RMSNorm(dim=config.hidden_size,
+                       eps=config.layernorm_epsilon,
+                       sequence_parallel=config.sequence_parallel)
+    elif args.normalization == "LightopRMSNorm":
         return LightopRMSNorm(dim=config.hidden_size,
                               eps=config.layernorm_epsilon)
     else:
......
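As a quick sanity check of the RMSNorm branch, the compiled module can be compared against the reference formula y = x / sqrt(mean(x^2) + eps) * weight (a hedged sketch reusing the RMSNorm class from the new file above; the tolerances are illustrative, and the same check could be pointed at LightopRMSNorm if its constructor matches the diff):

import torch

norm = RMSNorm(dim=512, eps=1e-6)
x = torch.randn(4, 16, 512)
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
torch.testing.assert_close(norm(x), ref, rtol=1e-4, atol=1e-5)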
@@ -51,6 +51,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     # Standard arguments.
     parser = _add_network_size_args(parser)
+    parser = _add_extra_network_size_args(parser)
     parser = _add_regularization_args(parser)
     parser = _add_training_args(parser)
     parser = _add_extra_training_args(parser)
@@ -106,6 +107,18 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     return args
+def _add_extra_network_size_args(parser):
+    # Remove the original argument definition
+    remove_original_params(parser, ["normalization"])
+    # Re-register the argument with the extra choice
+    group = parser.add_argument_group(title='extra network size args')
+    group.add_argument('--normalization', default='LayerNorm',
+                       choices=['LayerNorm', 'RMSNorm', 'LightopRMSNorm'],
+                       help='Which normalization technique to use.')
+    return parser
 def _add_extra_distributed_args(parser):
     group = parser.add_argument_group(title='extra distributed args')
     group.add_argument('--rank', default=-1, type=int,
......
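The helper remove_original_params is called above but its body is not part of this diff. A hypothetical sketch of such a helper follows; argparse has no public API for dropping a registered option, so the usual workaround touches the parser's private tables (everything here is an assumption, not the repository's actual implementation):

def remove_original_params(parser, param_names):
    # Hypothetical sketch; the repository's real helper is not shown in this diff.
    for name in param_names:
        option = '--' + name.replace('_', '-')
        action = parser._option_string_actions.get(option)
        if action is None:
            continue
        # Unregister every alias of the option, then drop the action itself
        # so the flag can be re-added with new choices/defaults.
        for opt_string in action.option_strings:
            parser._option_string_actions.pop(opt_string, None)
        if action in parser._actions:
            parser._actions.remove(action)
        for group in parser._action_groups:
            if action in group._group_actions:
                group._group_actions.remove(action)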