Commit 89d29a02 authored by silencealiang's avatar silencealiang
Browse files

bug fix

parent 81e19772
......@@ -99,7 +99,7 @@ class CoreAdaptation(MegatronAdaptationABC):
)
from ..core.models.gpt.gpt_model import (
gpt_model_forward,
gpt_model_init_wrapper,
gpt_model_init,
shared_embedding_or_output_weight,
)
from ..core.models.common.language_module.language_module import (
......@@ -130,9 +130,7 @@ class CoreAdaptation(MegatronAdaptationABC):
'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
shared_embedding_or_output_weight)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
......@@ -152,9 +150,9 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
apply_wrapper=True)
# MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
# torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
# apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
......
......@@ -25,45 +25,131 @@ from dcu_megatron.core.transformer.transformer_config import TransformerConfig
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
def gpt_model_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
if self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0):
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=kwargs.get("position_embedding_type"),
scatter_to_sequence_parallel=kwargs.get("scatter_embedding_sequence_parallel"),
)
def gpt_model_init(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
super(GPTModel, self).__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 4 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling
self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
if (
self.post_process
and int(os.getenv("USE_FLUX_OVERLAP", "0"))
):
self.output_layer = FluxColumnParallelLinear(
self.config.hidden_size,
self.vocab_size,
config=self.config,
init_method=self.config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
rope_scaling_factor=rope_scaling_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Cache for RoPE tensors which do not change between iterations.
self.rotary_pos_emb_cache = {}
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
self.setup_embeddings_and_output_layer()
if self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0):
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
# Output
if post_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
column_parallel_linear_impl = FluxColumnParallelLinear
else:
column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
self.output_layer = column_parallel_linear_impl(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
# add mtp
self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
if self.num_nextn_predict_layers:
assert hasattr(self.config, "mtp_spec")
self.mtp_spec: ModuleSpec = self.config.mtp_spec
self.mtp_spec = self.config.mtp_spec
self.recompute_mtp_norm = self.config.recompute_mtp_norm
self.recompute_mtp_layer = self.config.recompute_mtp_layer
self.mtp_loss_scale = self.config.mtp_loss_scale
......@@ -81,7 +167,7 @@ def gpt_model_init_wrapper(fn):
parallel_output=self.parallel_output,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
seq_len_interpolation_factor=kwargs.get("seq_len_interpolation_factor", None),
seq_len_interpolation_factor=seq_len_interpolation_factor,
recompute_mtp_norm=self.recompute_mtp_norm,
recompute_mtp_layer=self.recompute_mtp_layer,
add_output_layer_bias=False
......@@ -90,7 +176,13 @@ def gpt_model_init_wrapper(fn):
]
)
return wrapper
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
if has_config_logger_enabled(self.config):
log_config_to_disk(
self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
def shared_embedding_or_output_weight(self) -> Tensor:
......
......@@ -144,7 +144,7 @@ class MultiTokenPredictor(MegatronModule):
"""Forward function of the MTP module"""
# Decoder embedding.
decoder_input = embedding(
decoder_input = embedding_layer(
input_ids=embed_input_ids,
position_ids=position_ids,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment