Commit 722e38bf authored by dongcl
parents 2b8d28d0 6fc0ec45
@@ -91,6 +91,7 @@ class CoreAdaptation(MegatronAdaptationABC):
             gpt_model_init,
             shared_embedding_or_mtp_embedding_weight
         )
+        from ..training.utils import get_batch_on_this_tp_rank
         # Embedding
         MegatronAdaptation.register(
@@ -100,6 +101,8 @@ class CoreAdaptation(MegatronAdaptationABC):
             'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
             language_model_embedding_forward)
+        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
         # GPT Model
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
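The MegatronAdaptation.register calls above appear to map a dotted Megatron symbol path to the replacement that should be monkey-patched over it once all adaptations are collected. A minimal sketch of that register-then-patch idea, assuming a simple dictionary registry (register and apply_patches below are hypothetical names, not this repository's actual API):

import importlib

# Hypothetical sketch of the register/patch mechanism suggested by the calls
# above; the real MegatronAdaptation class may be implemented differently.
_patch_registry = {}


def register(dotted_path, replacement):
    """Record that the object at dotted_path should be replaced."""
    _patch_registry[dotted_path] = replacement


def apply_patches():
    """Resolve each dotted path and overwrite the final attribute in place."""
    for dotted_path, replacement in _patch_registry.items():
        parts = dotted_path.split('.')
        owner, depth = None, 0
        # Find the longest importable module prefix, e.g.
        # 'megatron.core.models.gpt.gpt_model' for '...GPTModel.forward'.
        for i in range(len(parts) - 1, 0, -1):
            try:
                owner, depth = importlib.import_module('.'.join(parts[:i])), i
                break
            except ImportError:
                continue
        # Walk any remaining attributes (e.g. the GPTModel class), then patch.
        for name in parts[depth:-1]:
            owner = getattr(owner, name)
        setattr(owner, parts[-1], replacement)

Resolving the longest importable module prefix is what would let a single registry handle both module-level functions such as megatron.training.utils.get_batch_on_this_tp_rank and class methods such as GPTModel.forward.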
@@ -151,6 +154,7 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
+        from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
@@ -158,6 +162,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
                                     vocab_parallel_embedding_init)
+        # VocabParallelCrossEntropy
+        MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',
+                                    VocabParallelCrossEntropy.calculate_predicted_logits)
         # _VocabParallelCrossEntropy
         MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
......
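In the final registration above, torch.compile is invoked with only the mode keyword; called that way it returns a decorator rather than a compiled function, so the adaptation layer presumably applies it as a wrapper around the existing _VocabParallelCrossEntropy.forward rather than substituting a new implementation. A standalone illustration of that decorator-factory behaviour, with a hypothetical stand-in function (not the Megatron kernel):

import torch

# torch.compile called with keyword arguments only returns a decorator,
# which can later be applied to any function.
compile_wrapper = torch.compile(mode='max-autotune-no-cudagraphs')


def toy_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in used only to show how the wrapper is applied:
    # numerically stable log-softmax followed by a negative log-likelihood gather.
    logits_max = logits.max(dim=-1, keepdim=True).values
    log_probs = logits - logits_max - (logits - logits_max).exp().sum(dim=-1, keepdim=True).log()
    return -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)


compiled_cross_entropy = compile_wrapper(toy_cross_entropy)
# Example call; the first invocation triggers compilation.
loss = compiled_cross_entropy(torch.randn(4, 32000), torch.randint(0, 32000, (4,)))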
@@ -39,8 +39,11 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
             torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
             setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
-        if hasattr(model_module,
-                   "share_mtp_embedding_and_output_weight") and model_module.share_mtp_embedding_and_output_weight:
+        if (
+            hasattr(model_module, "share_mtp_embedding_and_output_weight")
+            and model_module.share_mtp_embedding_and_output_weight
+            and config.num_nextn_predict_layers > 0
+        ):
             weight = model_module.shared_embedding_or_mtp_embedding_weight()
             grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
             orig_grad = getattr(weight, grad_attr)
......
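The added config.num_nextn_predict_layers > 0 check means the shared MTP embedding gradient is only all-reduced across the embedding group when multi-token-prediction layers are actually configured, rather than whenever the sharing flag happens to be set. Stripped of the surrounding Megatron plumbing, the guarded all-reduce reads roughly as follows (a sketch; the function name and the embedding_group parameter are stand-ins for the calls shown in the hunk above):

import torch.distributed as dist


def allreduce_shared_mtp_embedding_grad(model_module, config, embedding_group):
    # Sketch of the guarded all-reduce added in the hunk above; not the full
    # Megatron _allreduce_word_embedding_grads implementation.
    if not (
        getattr(model_module, "share_mtp_embedding_and_output_weight", False)
        and config.num_nextn_predict_layers > 0
    ):
        # Nothing to synchronize: the MTP embedding is not tied to the output
        # weight, or no next-N prediction layers are configured.
        return
    weight = model_module.shared_embedding_or_mtp_embedding_weight()
    # Prefer the fused main_grad buffer when the optimizer provides one.
    grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
    grad = getattr(weight, grad_attr)
    dist.all_reduce(grad, group=embedding_group)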