Commit bf220124 authored by dongcl

bug fix: attach mtp_spec to config instead of passing it to the GPTModel constructor; import gpt_model_init_wrapper in place of gpt_model_init

parent 9800dec4
@@ -99,7 +99,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         )
         from ..core.models.gpt.gpt_model import (
             gpt_model_forward,
-            gpt_model_init,
+            gpt_model_init_wrapper,
             shared_embedding_or_mtp_embedding_weight
         )
         from ..training.utils import get_batch_on_this_tp_rank
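The import hunk above swaps the direct patch function gpt_model_init for gpt_model_init_wrapper. A minimal sketch of the wrapper pattern adaptation layers like this typically use, assuming the wrapper decorates the upstream __init__ (the body below is illustrative, not this repo's actual code):

import functools

def gpt_model_init_wrapper(init_fn):
    # Decorate the upstream GPTModel.__init__ so the original still runs first.
    @functools.wraps(init_fn)
    def wrapper(self, *args, **kwargs):
        init_fn(self, *args, **kwargs)
        # Hypothetical post-init hook: read mtp_spec from the config, matching
        # the call-site change below where mtp_spec is attached to config
        # instead of being passed as a constructor argument.
        self.mtp_spec = getattr(self.config, 'mtp_spec', None)
    return wrapper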
@@ -136,8 +136,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
     else:
         mtp_transformer_layer_spec = transformer_layer_spec
-    mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
     with build_model_context(**build_model_context_args):
+        config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
         model = GPTModel(
             config=config,
             transformer_layer_spec=transformer_layer_spec,
@@ -151,8 +151,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
             position_embedding_type=args.position_embedding_type,
             rotary_percent=args.rotary_percent,
             rotary_base=args.rotary_base,
-            rope_scaling=args.use_rope_scaling,
-            mtp_spec=mtp_spec
+            rope_scaling=args.use_rope_scaling
         )
         # model = torch.compile(model,mode='max-autotune-no-cudagraphs')
         print_rank_0(model)
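Taken together, the two call-site hunks stop passing mtp_spec as a GPTModel keyword and attach it to config before construction, inside the existing build_model_context block. This suggests the bug was an unexpected-keyword error against a constructor that accepts no mtp_spec argument, though the commit message does not say so explicitly. A runnable sketch of the corrected pattern, with stand-in types (Config, GPTModel, and the spec dict below are placeholders, not Megatron's classes):

from contextlib import nullcontext
from dataclasses import dataclass

@dataclass
class Config:
    hidden_size: int = 1024
    mtp_spec: dict = None  # attached by model_provider before construction

class GPTModel:
    # The constructor accepts no mtp_spec keyword; patched code (e.g. the
    # hypothetical gpt_model_init_wrapper above) reads it from config instead.
    def __init__(self, config, transformer_layer_spec=None):
        self.config = config
        self.mtp_spec = getattr(config, 'mtp_spec', None)

build_model_context, build_model_context_args = nullcontext, {}

config = Config()
with build_model_context(**build_model_context_args):
    config.mtp_spec = {'num_layers': 1}  # stands in for get_mtp_spec(...)
    model = GPTModel(config=config)
assert model.mtp_spec == {'num_layers': 1}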