# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_block_spec,
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
    get_gpt_mtp_block_spec,
)
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
    get_gpt_heterogeneous_layer_spec,
)
from megatron.core.transformer.spec_utils import import_module
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml

import megatron.legacy.model  # isort: skip
# NOTE: Loading `megatron.legacy.model` earlier fails due to a circular import.


def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None):
    """Build and return a GPT model (legacy or Megatron Core) from training arguments."""
    print_rank_0('building GPT model ...')

    if config is None:
        if args.yaml_cfg is not None:
            config = core_transformer_config_from_yaml(args, "language_model")
        else:
            config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:  # using core models
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            use_te = args.transformer_impl == "transformer_engine"
            if args.num_experts:
                # Define the decoder block spec
                transformer_layer_spec = get_gpt_decoder_block_spec(
                    config,
                    use_transformer_engine=use_te,
                    normalization=args.normalization,
                    qk_l2_norm=args.qk_l2_norm,
                    vp_stage=vp_stage,
                )
            elif args.heterogeneous_layers_config_path is not None:
                transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
            else:
                # Define the decoder layer spec
                transformer_layer_spec = _get_transformer_layer_spec(use_te, config)

        mtp_block_spec = None
        if args.mtp_num_layers is not None:
            if (
                hasattr(transformer_layer_spec, 'layer_specs')
                and len(transformer_layer_spec.layer_specs) == 0
            ):
                # Get the decoder layer spec explicitly when the last stage has no decoder
                # layers; this only happens with a block spec (TransformerBlockSubmodules)
                # when using MoE.
                transformer_layer_spec_for_mtp = _get_transformer_layer_spec(use_te, config)
            else:
                transformer_layer_spec_for_mtp = transformer_layer_spec
            mtp_block_spec = get_gpt_mtp_block_spec(
                config,
                transformer_layer_spec_for_mtp,
                use_transformer_engine=use_te,
                vp_stage=vp_stage,
            )

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base,
            rope_scaling=args.use_rope_scaling,
            mtp_block_spec=mtp_block_spec,
            vp_stage=vp_stage,
        )

    return model


def _get_transformer_layer_spec(use_te, config):
    """Get transformer layer specification based on configuration.

    Args:
        use_te (bool): Whether to use Transformer Engine.
        config: Model configuration.

    Returns:
        transformer_layer_spec: The transformer layer specification.

    Note:
        Training arguments are read from the global `get_args()`.
    """
    args = get_args()
    if use_te:
        return get_gpt_layer_with_transformer_engine_spec(
            args.num_experts,
            args.moe_grouped_gemm,
            args.qk_layernorm,
            args.multi_latent_attention,
            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
            qk_l2_norm=args.qk_l2_norm,
            use_kitchen=config.use_kitchen,
        )
    else:
        return get_gpt_layer_local_spec(
            args.num_experts,
            args.moe_grouped_gemm,
            args.qk_layernorm,
            args.multi_latent_attention,
            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
            normalization=args.normalization,
            use_kitchen=config.use_kitchen,
        )
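

# Illustrative usage sketch: assuming Megatron has already been initialized (e.g. via
# `initialize_megatron()` in a training script) so that `get_args()` returns populated
# training arguments, the builder can be invoked as:
#
#     args = get_args()
#     model = gpt_builder(args, pre_process=True, post_process=True)
#
# In a pipeline-parallel run, `pre_process`/`post_process` would typically be derived from
# whether the current rank holds the first/last pipeline stage rather than hard-coded.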