Commit 722e38bf authored by dongcl
parents 2b8d28d0 6fc0ec45
@@ -91,6 +91,7 @@ class CoreAdaptation(MegatronAdaptationABC):
             gpt_model_init,
             shared_embedding_or_mtp_embedding_weight
         )
+        from ..training.utils import get_batch_on_this_tp_rank

         # Embedding
         MegatronAdaptation.register(
@@ -100,6 +101,8 @@ class CoreAdaptation(MegatronAdaptationABC):
             'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
             language_model_embedding_forward)
+        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)

         # GPT Model
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
@@ -151,6 +154,7 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
+        from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy

         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
@@ -158,6 +162,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
                                     vocab_parallel_embedding_init)
+        # VocabParallelCrossEntropy
+        MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',
+                                    VocabParallelCrossEntropy.calculate_predicted_logits)

         # _VocabParallelCrossEntropy
         MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
...
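Each `MegatronAdaptation.register` call above pairs a dotted path into Megatron-LM with the object that should replace it (an `__init__`, a `forward`, or a helper such as `get_batch_on_this_tp_rank`). The `MegatronAdaptation` implementation itself is not part of this diff, so the following is only a minimal sketch under the assumption that it works as a dotted-path monkey-patcher; `_patches`, `register`, and `apply_patches` are hypothetical names used for illustration.

```python
# Illustrative sketch only; not the MegatronAdaptation implementation from this repository.
import importlib

_patches = {}  # hypothetical registry: dotted path -> replacement object


def register(dotted_path, replacement):
    """Record that the attribute at `dotted_path` should be replaced later."""
    _patches[dotted_path] = replacement


def apply_patches():
    """Resolve each dotted path and monkey-patch the owning module or class."""
    for dotted_path, replacement in _patches.items():
        owner_path, attr_name = dotted_path.rsplit('.', 1)
        parts = owner_path.split('.')
        # Import the longest importable module prefix, then walk the remaining
        # attributes (e.g. a class name) to reach the owner of `attr_name`.
        for i in range(len(parts), 0, -1):
            try:
                owner = importlib.import_module('.'.join(parts[:i]))
            except ModuleNotFoundError:
                continue
            for name in parts[i:]:
                owner = getattr(owner, name)
            setattr(owner, attr_name, replacement)
            break
```

Under that assumption, `register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)` would rebind `GPTModel.forward` once `apply_patches()` runs.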
@@ -39,8 +39,11 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
         torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
         setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))

-    if hasattr(model_module,
-               "share_mtp_embedding_and_output_weight") and model_module.share_mtp_embedding_and_output_weight:
+    if (
+        hasattr(model_module, "share_mtp_embedding_and_output_weight")
+        and model_module.share_mtp_embedding_and_output_weight
+        and config.num_nextn_predict_layers > 0
+    ):
         weight = model_module.shared_embedding_or_mtp_embedding_weight()
         grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
         orig_grad = getattr(weight, grad_attr)
...
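The behavioral change in this second hunk is the added `config.num_nextn_predict_layers > 0` condition: the shared MTP embedding/output weight gradient is now only all-reduced across the embedding group when at least one next-n prediction layer is actually configured. As a standalone predicate it reads roughly as the sketch below; `_should_sync_mtp_embedding_grad` is a hypothetical helper name, not part of this diff.

```python
def _should_sync_mtp_embedding_grad(model_module, config) -> bool:
    # Mirrors the new guard above: the module must opt into sharing the MTP
    # embedding/output weight, and at least one next-n predict layer must exist.
    return (
        getattr(model_module, "share_mtp_embedding_and_output_weight", False)
        and config.num_nextn_predict_layers > 0
    )
```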