Commit 56400eb5 authored by wangxj

Modify the norm used by the legacy model

parent 4d19cbac
Pipeline #2634 passed with stage
@@ -31,3 +31,84 @@ class RMSNorm(torch.nn.Module):
    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
from typing import Any, Callable, Dict, Optional, Tuple, Union
import lightop  # provides the fused rmsnorm_forward / rmsnorm_backward kernels
from functools import partial
from megatron.core.utils import is_torch_min_version

# torch.amp.custom_fwd/custom_bwd take a device_type argument from PyTorch 2.4 on;
# fall back to the torch.cuda.amp variants on older versions.
if is_torch_min_version("2.4.0a0"):
    custom_fwd = partial(torch.amp.custom_fwd, device_type="cuda")
    custom_bwd = partial(torch.amp.custom_bwd, device_type="cuda")
else:
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd
def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)
class _LightopRMSNorm(torch.autograd.Function):
    """RMSNorm implemented with the fused lightop kernels."""

    @staticmethod
    # @custom_fwd
    def forward(ctx,
                inp: torch.Tensor,
                weight: torch.Tensor,
                ln_out: torch.Tensor,
                eps: float,
                is_grad_enabled):
        # lightop.rmsnorm_forward returns (normalized output, rsigma); rsigma holds the
        # normalization statistics needed by the backward kernel.
        output = lightop.rmsnorm_forward(inp, weight, ln_out, eps, training=True)
        rsigma = output[1]
        if is_grad_enabled:
            ctx.save_for_backward(inp, weight, rsigma)
        return output[0]

    @staticmethod
    # @custom_bwd
    def backward(ctx, grad_output):
        inp, weight, rsigma = ctx.saved_tensors
        dgrad, dgamma = lightop.rmsnorm_backward(grad_output, inp, rsigma, weight)
        # One gradient slot per forward input: inp, weight, ln_out, eps, is_grad_enabled.
        return dgrad, dgamma, None, None, None
class LightopRMSNorm(torch.nn.Module):
    def __init__(self,
                 dim: int,
                 eps: float = 1e-6):
        """RMS Normalization module.

        Args:
            dim (int): The width of the input, i.e. the hidden size
            eps (float): epsilon to use for the norm, defaults to 1e-6
        """
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    # @no_torch_dynamo()  # would dynamically apply torch._dynamo.disable
    def forward(self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None):
        # With autograd enabled, dispatch through the autograd.Function so the fused
        # backward is registered; otherwise call forward() directly with a dummy ctx.
        if torch.is_grad_enabled():
            fwd_fn = _LightopRMSNorm.apply
            args = []
        else:
            fwd_fn = _LightopRMSNorm.forward
            args = [None]
        ln_out = torch.empty_like(inp, dtype=inp.dtype, memory_format=torch.contiguous_format)
        args += (inp, self.weight, ln_out, self.eps, torch.is_grad_enabled())
        out = fwd_fn(*args)
        return out
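
A minimal sanity-check sketch (not part of the commit) that compares the fused module against the pure-PyTorch RMSNorm math kept above; it assumes the lightop extension is installed, a CUDA device is available, and the tolerances shown are merely loose illustrative values for bfloat16.

```python
import torch

from megatron.legacy.model import LightopRMSNorm  # exported by this commit


def reference_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same math as the existing RMSNorm above: x * rsqrt(mean(x^2) + eps) * weight.
    rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return (x.float() * rms).type_as(x) * weight


if torch.cuda.is_available():
    hidden = 4096
    norm = LightopRMSNorm(dim=hidden, eps=1e-6).cuda().to(torch.bfloat16)
    x = torch.randn(2, 8, hidden, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    out = norm(x)                                      # forward via lightop.rmsnorm_forward
    ref = reference_rmsnorm(x, norm.weight, norm.eps)
    torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)  # illustrative tolerances

    out.sum().backward()                               # exercises _LightopRMSNorm.backward
    assert x.grad is not None and norm.weight.grad is not None
```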
@@ -7,7 +7,7 @@ import math
import torch
from megatron.training import get_args
from megatron.legacy.model import LayerNorm, RMSNorm
from megatron.legacy.model import LayerNorm, RMSNorm, LightopRMSNorm
from megatron.core.jit import jit_fuser
def init_method_normal(sigma):
@@ -75,5 +75,8 @@ def get_norm(config):
        return RMSNorm(dim=config.hidden_size,
                       eps=config.layernorm_epsilon,
                       sequence_parallel=config.sequence_parallel)
    elif args.normalization == "LightopRMSNorm":
        return LightopRMSNorm(dim=config.hidden_size,
                              eps=config.layernorm_epsilon)
    else:
        raise Exception(f"unsupported norm type '{args.normalization}'.")
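
For illustration only (not part of the commit): a hedged sketch of what the new get_norm() branch constructs. The config object here is a hypothetical SimpleNamespace stand-in carrying only the two fields the branch reads; note that, unlike the RMSNorm branch, no sequence_parallel argument is passed.

```python
from types import SimpleNamespace

import torch
from megatron.legacy.model import LightopRMSNorm

# Hypothetical stand-in for the config consumed by get_norm(); only the fields
# read by the new branch are populated.
config = SimpleNamespace(hidden_size=4096, layernorm_epsilon=1e-6)

norm = LightopRMSNorm(dim=config.hidden_size, eps=config.layernorm_epsilon)
assert isinstance(norm, torch.nn.Module)
assert norm.weight.shape == (config.hidden_size,)
```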
@@ -1112,7 +1112,7 @@ def _add_network_size_args(parser):
                       help='Pad the vocab size to be divisible by this value.'
                       'This is added for computational efficiency reasons.')
    group.add_argument('--normalization', default='LayerNorm',
                       choices=['LayerNorm', 'RMSNorm'],
                       choices=['LayerNorm', 'RMSNorm', 'LightopRMSNorm'],
                       help='Which normalization technique to use.')
    group.add_argument('--norm-epsilon', type=float, default=1e-5,
                       help='Epsilon for layer norm and RMS norm.')
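
A standalone argparse mirror of the option above (hypothetical, for illustration only); in Megatron the argument is registered inside _add_network_size_args() and read back through get_args().

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--normalization', default='LayerNorm',
                    choices=['LayerNorm', 'RMSNorm', 'LightopRMSNorm'],
                    help='Which normalization technique to use.')

# Selecting the new fused norm from the command line:
args = parser.parse_args(['--normalization', 'LightopRMSNorm'])
assert args.normalization == 'LightopRMSNorm'
```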