Commit bf220124 authored by dongcl
Browse files

bug fix

parent 9800dec4
...@@ -99,7 +99,7 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -99,7 +99,7 @@ class CoreAdaptation(MegatronAdaptationABC):
) )
from ..core.models.gpt.gpt_model import ( from ..core.models.gpt.gpt_model import (
gpt_model_forward, gpt_model_forward,
gpt_model_init, gpt_model_init_wrapper,
shared_embedding_or_mtp_embedding_weight shared_embedding_or_mtp_embedding_weight
) )
from ..training.utils import get_batch_on_this_tp_rank from ..training.utils import get_batch_on_this_tp_rank
......
...@@ -136,8 +136,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat ...@@ -136,8 +136,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
else: else:
mtp_transformer_layer_spec = transformer_layer_spec mtp_transformer_layer_spec = transformer_layer_spec
mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
with build_model_context(**build_model_context_args): with build_model_context(**build_model_context_args):
config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
model = GPTModel( model = GPTModel(
config=config, config=config,
transformer_layer_spec=transformer_layer_spec, transformer_layer_spec=transformer_layer_spec,
...@@ -151,8 +151,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat ...@@ -151,8 +151,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
position_embedding_type=args.position_embedding_type, position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent, rotary_percent=args.rotary_percent,
rotary_base=args.rotary_base, rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling, rope_scaling=args.use_rope_scaling
mtp_spec=mtp_spec
) )
# model = torch.compile(model,mode='max-autotune-no-cudagraphs') # model = torch.compile(model,mode='max-autotune-no-cudagraphs')
print_rank_0(model) print_rank_0(model)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment