Commit 89d29a02 authored by silencealiang

Bug fix: register gpt_model_init as a direct replacement for GPTModel.__init__ (instead of an init wrapper), disable the torch.compile patch on switch_load_balancing_loss_func, and fix the embedding call in MultiTokenPredictor's forward.

parent 81e19772
@@ -99,7 +99,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         )
         from ..core.models.gpt.gpt_model import (
             gpt_model_forward,
-            gpt_model_init_wrapper,
+            gpt_model_init,
             shared_embedding_or_output_weight,
         )
         from ..core.models.common.language_module.language_module import (
@@ -130,9 +130,7 @@ class CoreAdaptation(MegatronAdaptationABC):
             'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
             shared_embedding_or_output_weight)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
-                                    gpt_model_init_wrapper,
-                                    apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
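Note on the hunk above: the previous code registered gpt_model_init_wrapper with apply_wrapper=True, i.e. a decorator that receives the original GPTModel.__init__, calls it, and then adds extra attributes; the new code registers gpt_model_init as a drop-in replacement for __init__. The toy registry below is only a sketch of that distinction, written against the semantics implied by how MegatronAdaptation.register is used in this file; the real MegatronAdaptation class is not part of this diff and may differ.

# Illustrative sketch only: a minimal patch registry distinguishing the two
# registration styles used above. `register` and `apply_wrapper` mirror the
# call sites in the diff; the actual MegatronAdaptation is not shown here.
from typing import Callable, Dict, Tuple

_PATCHES: Dict[str, Tuple[Callable, bool]] = {}

def register(target: str, patch: Callable, apply_wrapper: bool = False) -> None:
    """Record a patch for a dotted attribute path such as 'pkg.mod.Cls.method'."""
    _PATCHES[target] = (patch, apply_wrapper)

def resolve(target: str, original: Callable) -> Callable:
    """Return the callable that should replace `target` once patches are applied."""
    patch, apply_wrapper = _PATCHES[target]
    if apply_wrapper:
        # Wrapper style: the registered object is a decorator; it receives the
        # original callable and returns a new one (gpt_model_init_wrapper and
        # the torch.compile(...) patches below are used this way).
        return patch(original)
    # Replacement style: the registered object is used as-is and the original
    # is discarded (gpt_model_init replacing GPTModel.__init__ in this commit).
    return patch

# Example with stand-in functions (hypothetical names, for illustration only):
def original_init(self): ...
def patched_init(self): ...

register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', patched_init)
assert resolve('megatron.core.models.gpt.gpt_model.GPTModel.__init__', original_init) is patched_init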
@@ -152,9 +150,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
-                                    apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
+        #                             apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
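Two things are worth spelling out in the hunk above. First, the commit comments out the torch.compile patch for switch_load_balancing_loss_func rather than deleting it, so only topk_softmax_with_capacity and permute remain compiled. Second, torch.compile is registered as a wrapper: called with only mode= or options= it returns a decorator, which the registry then applies to the target function. A minimal standalone example of that decorator form follows (PyTorch 2.x; the inductor option keys in the diff, such as "triton.cudagraphs", only matter on a GPU backend).

# Minimal sketch: torch.compile used as a decorator, the same shape as the
# apply_wrapper=True registrations above. Also runs on CPU via inductor.
import torch

@torch.compile(mode="max-autotune-no-cudagraphs")
def permute_rows(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for moe_utils.permute: reorder rows by an index tensor.
    return x[idx]

x = torch.randn(8, 4)
idx = torch.randperm(8)
assert torch.equal(permute_rows(x, idx), x[idx])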
@@ -25,29 +25,118 @@ from dcu_megatron.core.transformer.transformer_config import TransformerConfig
 from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear


-def gpt_model_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
-        fn(self, *args, **kwargs)
+def gpt_model_init(
+    self,
+    config: TransformerConfig,
+    transformer_layer_spec: ModuleSpec,
+    vocab_size: int,
+    max_sequence_length: int,
+    pre_process: bool = True,
+    post_process: bool = True,
+    fp16_lm_cross_entropy: bool = False,
+    parallel_output: bool = True,
+    share_embeddings_and_output_weights: bool = False,
+    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
+    rotary_percent: float = 1.0,
+    rotary_base: int = 10000,
+    rope_scaling: bool = False,
+    rope_scaling_factor: float = 8.0,
+    scatter_embedding_sequence_parallel: bool = True,
+    seq_len_interpolation_factor: Optional[float] = None,
+) -> None:
+    super(GPTModel, self).__init__(config=config)
+
+    if has_config_logger_enabled(config):
+        log_config_to_disk(config, locals(), prefix=type(self).__name__)
+
+    self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
+    self.vocab_size = vocab_size
+    self.max_sequence_length = max_sequence_length
+    self.pre_process = pre_process
+    self.post_process = post_process
+    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
+    self.parallel_output = parallel_output
+    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+    self.position_embedding_type = position_embedding_type
+
+    # megatron core pipelining currently depends on model type
+    # TODO: remove this dependency ?
+    self.model_type = ModelType.encoder_or_decoder
+
+    # These 4 attributes are needed for TensorRT-LLM export.
+    self.max_position_embeddings = max_sequence_length
+    self.rotary_percent = rotary_percent
+    self.rotary_base = rotary_base
+    self.rotary_scaling = rope_scaling
+
+    self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
+
+    if self.pre_process:
+        self.embedding = LanguageModelEmbedding(
+            config=self.config,
+            vocab_size=self.vocab_size,
+            max_sequence_length=self.max_sequence_length,
+            position_embedding_type=position_embedding_type,
+            scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
+        )
+
+    if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
+        self.rotary_pos_emb = RotaryEmbedding(
+            kv_channels=self.config.kv_channels,
+            rotary_percent=rotary_percent,
+            rotary_interleaved=self.config.rotary_interleaved,
+            seq_len_interpolation_factor=seq_len_interpolation_factor,
+            rotary_base=rotary_base,
+            rope_scaling=rope_scaling,
+            rope_scaling_factor=rope_scaling_factor,
+            use_cpu_initialization=self.config.use_cpu_initialization,
+        )
+
+    # Cache for RoPE tensors which do not change between iterations.
+    self.rotary_pos_emb_cache = {}
+
+    # Transformer.
+    self.decoder = TransformerBlock(
+        config=self.config,
+        spec=transformer_layer_spec,
+        pre_process=self.pre_process,
+        post_process=self.post_process,
+    )
+
     if self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0):
         self.embedding = LanguageModelEmbedding(
             config=self.config,
             vocab_size=self.vocab_size,
             max_sequence_length=self.max_sequence_length,
-            position_embedding_type=kwargs.get("position_embedding_type"),
-            scatter_to_sequence_parallel=kwargs.get("scatter_embedding_sequence_parallel"),
+            position_embedding_type=position_embedding_type,
+            scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
         )

-    if (
-        self.post_process
-        and int(os.getenv("USE_FLUX_OVERLAP", "0"))
-    ):
-        self.output_layer = FluxColumnParallelLinear(
-            self.config.hidden_size,
+    # Output
+    if post_process:
+        if self.config.defer_embedding_wgrad_compute:
+            # The embedding activation buffer preserves a reference to the input activations
+            # of the final embedding projection layer GEMM. It will hold the activations for
+            # all the micro-batches of a global batch for the last pipeline stage. Once we are
+            # done with all the back props for all the microbatches for the last pipeline stage,
+            # it will be in the pipeline flush stage. During this pipeline flush we use the
+            # input activations stored in embedding activation buffer and gradient outputs
+            # stored in gradient buffer to calculate the weight gradients for the embedding
+            # final linear layer.
+            self.embedding_activation_buffer = []
+            self.grad_output_buffer = []
+        else:
+            self.embedding_activation_buffer = None
+            self.grad_output_buffer = None
+
+        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
+            column_parallel_linear_impl = FluxColumnParallelLinear
+        else:
+            column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
+
+        self.output_layer = column_parallel_linear_impl(
+            config.hidden_size,
             self.vocab_size,
-            config=self.config,
-            init_method=self.config.init_method,
+            config=config,
+            init_method=config.init_method,
             bias=False,
             skip_bias_add=False,
             gather_output=not self.parallel_output,
@@ -57,13 +146,10 @@ def gpt_model_init_wrapper(fn):
             grad_output_buffer=self.grad_output_buffer,
         )

-        self.setup_embeddings_and_output_layer()
-
     # add mtp
-    self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
     if self.num_nextn_predict_layers:
         assert hasattr(self.config, "mtp_spec")
-        self.mtp_spec: ModuleSpec = self.config.mtp_spec
+        self.mtp_spec = self.config.mtp_spec
         self.recompute_mtp_norm = self.config.recompute_mtp_norm
         self.recompute_mtp_layer = self.config.recompute_mtp_layer
         self.mtp_loss_scale = self.config.mtp_loss_scale
@@ -81,7 +167,7 @@ def gpt_model_init_wrapper(fn):
                 parallel_output=self.parallel_output,
                 position_embedding_type=self.position_embedding_type,
                 rotary_percent=self.rotary_percent,
-                seq_len_interpolation_factor=kwargs.get("seq_len_interpolation_factor", None),
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
                 recompute_mtp_norm=self.recompute_mtp_norm,
                 recompute_mtp_layer=self.recompute_mtp_layer,
                 add_output_layer_bias=False
@@ -90,7 +176,13 @@ def gpt_model_init_wrapper(fn):
             ]
         )

-    return wrapper
+    if self.pre_process or self.post_process:
+        self.setup_embeddings_and_output_layer()
+
+    if has_config_logger_enabled(self.config):
+        log_config_to_disk(
+            self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
+        )


 def shared_embedding_or_output_weight(self) -> Tensor:
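The comment block kept in gpt_model_init above describes deferred weight-gradient computation for the output projection: when config.defer_embedding_wgrad_compute is set, the inputs and output gradients of the final GEMM are buffered per micro-batch, and the weight gradient is computed once, during the pipeline flush. The snippet below only illustrates that bookkeeping idea with a plain linear layer; Megatron's actual implementation (buffers wired into the parallel linear layer and the pipeline schedule) is more involved and not shown in this diff.

# Illustrative sketch of deferred weight-gradient computation: stash GEMM
# inputs and output grads per micro-batch, then do one accumulation at
# "pipeline flush" time. Not Megatron's real code path.
import torch

embedding_activation_buffer = []   # inputs to the final projection GEMM
grad_output_buffer = []            # gradients w.r.t. its outputs

hidden, vocab = 16, 32
weight = torch.randn(vocab, hidden)

def forward_and_capture(x: torch.Tensor) -> torch.Tensor:
    # Keep a reference to the GEMM input for later wgrad computation.
    embedding_activation_buffer.append(x)
    return x @ weight.t()

def capture_grad_output(grad_out: torch.Tensor) -> None:
    grad_output_buffer.append(grad_out)

def flush_wgrad() -> torch.Tensor:
    # dW = sum over micro-batches of grad_out^T @ input, computed once at flush.
    wgrad = torch.zeros_like(weight)
    for x, g in zip(embedding_activation_buffer, grad_output_buffer):
        wgrad += g.t() @ x
    embedding_activation_buffer.clear()
    grad_output_buffer.clear()
    return wgrad

for _ in range(4):                          # four micro-batches
    x = torch.randn(8, hidden)
    out = forward_and_capture(x)            # shape (8, vocab)
    capture_grad_output(torch.randn_like(out))

dW = flush_wgrad()
assert dW.shape == weight.shape             # (vocab, hidden)

The other new branch in the same hunk simply selects the output-layer implementation from the USE_FLUX_OVERLAP environment variable: FluxColumnParallelLinear when it is set to a non-zero value, tensor_parallel.ColumnParallelLinear otherwise, instead of constructing FluxColumnParallelLinear unconditionally under that flag.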
@@ -144,7 +144,7 @@ class MultiTokenPredictor(MegatronModule):
         """Forward function of the MTP module"""

         # Decoder embedding.
-        decoder_input = embedding(
+        decoder_input = embedding_layer(
             input_ids=embed_input_ids,
             position_ids=position_ids,
         )