Commit 9800dec4 authored by dongcl

add LightopRMSNorm

parent 0604509a
@@ -116,7 +116,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         # GPT Model
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
+                                    gpt_model_init_wrapper,
+                                    apply_wrapper=True)
         from megatron.core.models.gpt.gpt_model import GPTModel
         setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)
@@ -240,6 +242,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
     def patch_legacy_models(self):
         from ..legacy.model.transformer import ParallelMLP, ParallelAttention
+        from ..legacy.model.utils import get_norm

         # ParallecMLP
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
@@ -252,6 +255,8 @@ class LegacyAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.legacy.model.rms_norm.RMSNorm.forward',
                                     torch.compile(mode="max-autotune-no-cudagraphs"),
                                     apply_wrapper=True)
+        MegatronAdaptation.register('megatron.legacy.model.utils.get_norm',
+                                    get_norm)

         MegatronAdaptation.execute()
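
With apply_wrapper=True, the second argument to register is not the replacement itself but a wrapper that receives the original attribute and returns the patched version. A rough sketch of the intended effect once MegatronAdaptation.execute() runs (illustrative only; the MegatronAdaptation registry in this repo performs the actual patching, and gpt_model_init_wrapper is defined later in this commit):

import torch
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.legacy.model.rms_norm import RMSNorm

# The registered callable receives the original attribute and returns its replacement.
GPTModel.__init__ = gpt_model_init_wrapper(GPTModel.__init__)
RMSNorm.forward = torch.compile(mode="max-autotune-no-cudagraphs")(RMSNorm.forward)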
@@ -22,154 +22,48 @@ from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
 from dcu_megatron.core.transformer.transformer_config import TransformerConfig


-def gpt_model_init(
-    self,
-    config: TransformerConfig,
-    transformer_layer_spec: ModuleSpec,
-    vocab_size: int,
-    max_sequence_length: int,
-    pre_process: bool = True,
-    post_process: bool = True,
-    fp16_lm_cross_entropy: bool = False,
-    parallel_output: bool = True,
-    share_embeddings_and_output_weights: bool = False,
-    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
-    rotary_percent: float = 1.0,
-    rotary_base: int = 10000,
-    rope_scaling: bool = False,
-    rope_scaling_factor: float = 8.0,
-    scatter_embedding_sequence_parallel: bool = True,
-    seq_len_interpolation_factor: Optional[float] = None,
-    mtp_spec: ModuleSpec = None
-) -> None:
-    super(GPTModel, self).__init__(config=config)
-
-    if has_config_logger_enabled(config):
-        log_config_to_disk(config, locals(), prefix=type(self).__name__)
-
-    self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
-    self.vocab_size = vocab_size
-    self.max_sequence_length = max_sequence_length
-    self.pre_process = pre_process
-    self.post_process = post_process
-    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
-    self.parallel_output = parallel_output
-    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
-    self.position_embedding_type = position_embedding_type
-
-    # megatron core pipelining currently depends on model type
-    # TODO: remove this dependency ?
-    self.model_type = ModelType.encoder_or_decoder
-
-    # These 4 attributes are needed for TensorRT-LLM export.
-    self.max_position_embeddings = max_sequence_length
-    self.rotary_percent = rotary_percent
-    self.rotary_base = rotary_base
-    self.rotary_scaling = rope_scaling
-
-    if self.pre_process:
-        self.embedding = LanguageModelEmbedding(
-            config=self.config,
-            vocab_size=self.vocab_size,
-            max_sequence_length=self.max_sequence_length,
-            position_embedding_type=position_embedding_type,
-            scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
-        )
-
-    if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
-        self.rotary_pos_emb = RotaryEmbedding(
-            kv_channels=self.config.kv_channels,
-            rotary_percent=rotary_percent,
-            rotary_interleaved=self.config.rotary_interleaved,
-            seq_len_interpolation_factor=seq_len_interpolation_factor,
-            rotary_base=rotary_base,
-            rope_scaling=rope_scaling,
-            rope_scaling_factor=rope_scaling_factor,
-            use_cpu_initialization=self.config.use_cpu_initialization,
-        )
-
-    # Cache for RoPE tensors which do not change between iterations.
-    self.rotary_pos_emb_cache = {}
-
-    # Transformer.
-    self.decoder = TransformerBlock(
-        config=self.config,
-        spec=transformer_layer_spec,
-        pre_process=self.pre_process,
-        post_process=self.post_process
-    )
-
-    # Output
-    if post_process:
-        if self.config.defer_embedding_wgrad_compute:
-            # The embedding activation buffer preserves a reference to the input activations
-            # of the final embedding projection layer GEMM. It will hold the activations for
-            # all the micro-batches of a global batch for the last pipeline stage. Once we are
-            # done with all the back props for all the microbatches for the last pipeline stage,
-            # it will be in the pipeline flush stage. During this pipeline flush we use the
-            # input activations stored in embedding activation buffer and gradient outputs
-            # stored in gradient buffer to calculate the weight gradients for the embedding
-            # final linear layer.
-            self.embedding_activation_buffer = []
-            self.grad_output_buffer = []
-        else:
-            self.embedding_activation_buffer = None
-            self.grad_output_buffer = None
-
-        self.output_layer = tensor_parallel.ColumnParallelLinear(
-            config.hidden_size,
-            self.vocab_size,
-            config=config,
-            init_method=config.init_method,
-            bias=False,
-            skip_bias_add=False,
-            gather_output=not self.parallel_output,
-            skip_weight_param_allocation=self.pre_process
-            and self.share_embeddings_and_output_weights,
-            embedding_activation_buffer=self.embedding_activation_buffer,
-            grad_output_buffer=self.grad_output_buffer,
-        )
-
-    # add mtp
-    self.mtp_spec: ModuleSpec = mtp_spec
-    self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
-    self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
-    self.recompute_mtp_norm = self.config.recompute_mtp_norm
-    self.recompute_mtp_layer = self.config.recompute_mtp_layer
-    self.mtp_loss_scale = self.config.mtp_loss_scale
-    if self.post_process and self.training and self.num_nextn_predict_layers:
-        self.mtp_layers = torch.nn.ModuleList(
-            [
-                MultiTokenPredictor(
-                    config,
-                    self.mtp_spec.submodules,
-                    vocab_size=self.vocab_size,
-                    max_sequence_length=self.max_sequence_length,
-                    layer_number=i,
-                    pre_process=self.pre_process,
-                    fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
-                    parallel_output=self.parallel_output,
-                    position_embedding_type=self.position_embedding_type,
-                    rotary_percent=self.rotary_percent,
-                    seq_len_interpolation_factor=seq_len_interpolation_factor,
-                    share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
-                    recompute_mtp_norm=self.recompute_mtp_norm,
-                    recompute_mtp_layer=self.recompute_mtp_layer,
-                    add_output_layer_bias=False
-                )
-                for i in range(self.num_nextn_predict_layers)
-            ]
-        )
-
-    if self.pre_process or self.post_process:
-        self.setup_embeddings_and_output_layer()
-
-    if has_config_logger_enabled(self.config):
-        log_config_to_disk(
-            self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
-        )
-
-    if self.num_nextn_predict_layers and (self.pre_process or self.post_process):
-        setup_mtp_embeddings(self)
+def gpt_model_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+
+        # add mtp
+        self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
+        if self.num_nextn_predict_layers:
+            assert hasattr(self.config, "mtp_spec")
+            self.mtp_spec: ModuleSpec = self.config.mtp_spec
+            self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
+            self.recompute_mtp_norm = self.config.recompute_mtp_norm
+            self.recompute_mtp_layer = self.config.recompute_mtp_layer
+            self.mtp_loss_scale = self.config.mtp_loss_scale
+            if self.post_process and self.training:
+                self.mtp_layers = torch.nn.ModuleList(
+                    [
+                        MultiTokenPredictor(
+                            self.config,
+                            self.mtp_spec.submodules,
+                            vocab_size=self.vocab_size,
+                            max_sequence_length=self.max_sequence_length,
+                            layer_number=i,
+                            pre_process=self.pre_process,
+                            fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
+                            parallel_output=self.parallel_output,
+                            position_embedding_type=self.position_embedding_type,
+                            rotary_percent=self.rotary_percent,
+                            seq_len_interpolation_factor=kwargs.get('seq_len_interpolation_factor'),
+                            share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
+                            recompute_mtp_norm=self.recompute_mtp_norm,
+                            recompute_mtp_layer=self.recompute_mtp_layer,
+                            add_output_layer_bias=False
+                        )
+                        for i in range(self.num_nextn_predict_layers)
+                    ]
+                )
+
+            if self.pre_process or self.post_process:
+                setup_mtp_embeddings(self)
+
+    return wrapper


 def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
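
gpt_model_init_wrapper leaves the upstream GPTModel.__init__ untouched and only appends MTP state after it returns. The same decorate-the-initializer pattern in isolation (toy class and attribute names are made up purely for illustration):

from functools import wraps

class Model:
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

def init_wrapper(fn):
    @wraps(fn)                      # preserve __name__/__doc__ of the original __init__
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)   # run the original initializer first
        self.extra_state = 2 * self.hidden_size   # then bolt on extra attributes
    return wrapper

Model.__init__ = init_wrapper(Model.__init__)
assert Model(4).extra_state == 8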
@@ -199,4 +199,3 @@ def transformer_block_forward(
     )

     return hidden_states

import torch
from typing import Any, Callable, Dict, Optional, Tuple, Union
import lightop  # provides rmsnorm_forward / rmsnorm_backward
from functools import partial
from megatron.core.utils import is_torch_min_version

if is_torch_min_version("2.4.0a0"):
    custom_fwd = partial(torch.amp.custom_fwd, device_type="cuda")
    custom_bwd = partial(torch.amp.custom_bwd, device_type="cuda")
else:
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd


class _LightopRMSNorm(torch.autograd.Function):
    """RMSNorm implemented with the lightop fused kernels."""

    @staticmethod
    # @custom_fwd
    def forward(ctx,
                inp: torch.Tensor,
                ln_out: torch.Tensor,
                weight: torch.Tensor,
                eps: float,
                is_grad_enabled):
        # lightop returns (normalized output, rsigma)
        output = lightop.rmsnorm_forward(inp, weight, ln_out, eps, training=True)
        rsigma = output[1]
        if is_grad_enabled:
            ctx.save_for_backward(inp, weight, ln_out, rsigma)
        return output[0]

    @staticmethod
    # @custom_bwd
    def backward(ctx, grad_output):
        inp, weight, ln_out, rsigma = ctx.saved_tensors
        dgrad, dgamma = lightop.rmsnorm_backward(grad_output, inp, rsigma, weight)
        # gradients w.r.t. (inp, ln_out, weight, eps, is_grad_enabled)
        return dgrad, None, dgamma, None, None


class LightopRMSNorm(torch.nn.Module):
    def __init__(self,
                 dim: int,
                 eps: float = 1e-6,
                 sequence_parallel: bool = False):
        """RMS Normalization module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
            sequence_parallel (bool): Set to True if sequence parallelism is used;
                marks the weight so its gradient is reduced correctly
        """
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))
        self.sequence_parallel = sequence_parallel
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)

    # @no_torch_dynamo()  # i.e. torch._dynamo.disable applied dynamically
    def forward(self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None):
        if torch.is_grad_enabled():
            fwd_fn = _LightopRMSNorm.apply
            args = []
        else:
            fwd_fn = _LightopRMSNorm.forward
            args = [None]  # no ctx when bypassing autograd
        ln_out = torch.empty_like(inp, dtype=inp.dtype, memory_format=torch.contiguous_format)
        args += (inp, ln_out, self.weight, self.eps, torch.is_grad_enabled())
        out = fwd_fn(*args)
        return out
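
For reference, the fused kernels above are expected to compute the standard RMSNorm, y = x * rsqrt(mean(x^2) + eps) * weight, over the last dimension. A quick parity check against a plain-PyTorch implementation (a sketch: it assumes the lightop extension and a CUDA device are available, that the kernel implements standard RMSNorm, and the tolerances are illustrative):

def reference_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight, reduced over the last dimension
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight

if __name__ == "__main__":
    torch.manual_seed(0)
    dim = 1024
    norm = LightopRMSNorm(dim).cuda()
    x = torch.randn(4, 2, dim, device="cuda", requires_grad=True)
    y_fused = norm(x)
    y_ref = reference_rmsnorm(x, norm.weight, norm.eps)
    print(torch.allclose(y_fused, y_ref, atol=1e-5, rtol=1e-5))
    y_fused.sum().backward()   # exercises lightop.rmsnorm_backward through autograd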
from megatron.training import get_args
from megatron.legacy.model import LayerNorm

from .rms_norm import LightopRMSNorm


def get_norm(config):
    args = get_args()
    if args.normalization == "LayerNorm":
        return LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            no_persist_layer_norm=not config.persist_layer_norm,
            sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)
    elif args.normalization == "RMSNorm":
        if args.apply_layernorm_1p:
            raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.')

        return LightopRMSNorm(dim=config.hidden_size,
                              eps=config.layernorm_epsilon,
                              sequence_parallel=config.sequence_parallel)
    else:
        raise Exception(f"unsupported norm type '{args.normalization}'.")