model_provider.py 13.2 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""ModelOpt GPT model provider."""

import json
import os
from argparse import Namespace
from typing import Any, Dict

import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto
import yaml

from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
    get_gpt_heterogeneous_layer_spec,
)
from megatron.core.models.mamba import MambaModel as MCoreMambaModel
from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec
from megatron.core.post_training.modelopt.gpt.state_dict_hooks import (
    mcore_gpt_load_te_state_dict_pre_hook,
)
from megatron.post_training.algos import distillation
from megatron.post_training.checkpointing import load_modelopt_checkpoint, load_modelopt_state
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args


def count_parameters_in_layer(model, layer_name):
    """Sum ``numel()`` over every parameter whose name contains ``layer_name``.

    The name and element count of each matching parameter is also printed via
    ``print_rank_0``.

    Args:
        model: Object exposing ``named_parameters()`` yielding ``(name, param)`` pairs.
        layer_name: Substring matched against each parameter name.

    Returns:
        The total element count of all matching parameters (0 if none match).
    """
    total = 0
    for param_name, tensor in model.named_parameters():
        if layer_name not in param_name:
            continue
        count = tensor.numel()
        total += count
        print_rank_0(f" - {param_name}: {count}")
    return total

def _add_load_convert_hooks(model: MCoreGPTModel):
    """Attach ``load_state_dict`` pre-hooks that handle known state_dict key mismatches.

    Only one hook is currently registered. When ``args.export_te_mcore_model`` is set,
    ``mcore_gpt_load_te_state_dict_pre_hook`` is installed on ``model``. Otherwise the
    model is left untouched.
    """
    if not get_args().export_te_mcore_model:
        return
    model._register_load_state_dict_pre_hook(mcore_gpt_load_te_state_dict_pre_hook)


def _load_teacher_model_config(checkpoint_path: str) -> Namespace:
    """Read the teacher model config and overlay it on a copy of the student's args.

    The config is read from ``args.teacher_model_config`` (``--teacher-model-config``) if
    set. Otherwise it is read from ``<checkpoint_path>/model_config.yaml``.

    The config should specify, in NEMO format, any model architecture settings that differ
    from the main student model's. This function translates NEMO field names to MCore
    argument names as needed.

    Args:
        checkpoint_path: Teacher checkpoint directory. Used to locate ``model_config.yaml``
            when no explicit config path is given.

    Returns:
        A ``Namespace`` with every student arg, updated with the teacher config.
        ``kv_channels`` is dropped from the student args first, so it is only present
        if the teacher config itself sets it.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If the config is not a YAML mapping or lacks a required field.
    """
    required_teacher_fields = (
        "num_layers",
        "hidden_size",
        "ffn_hidden_size",
        "num_attention_heads",
    )

    args = get_args()
    if args.teacher_model_config is None:
        config_path = os.path.join(checkpoint_path, "model_config.yaml")
    else:
        config_path = args.teacher_model_config
    if not os.path.exists(config_path):
        # Report the actual path so an explicit --teacher-model-config typo is obvious.
        raise FileNotFoundError(
            f"Teacher model config not found at '{config_path}'. Provide it via "
            "--teacher-model-config or place a NEMO-format yaml config named "
            "'model_config.yaml' in the teacher checkpoint dir."
        )
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # An empty YAML document loads as None. Treat it as an empty mapping so the
    # missing-field check below reports a clear error instead of a TypeError.
    if config is None:
        config = {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Teacher model config '{config_path}' must be a YAML mapping, "
            f"got {type(config).__name__}"
        )

    missing_keys = [k for k in required_teacher_fields if k not in config]
    if missing_keys:
        raise ValueError(
            f"Teacher model config '{config_path}' missing the following fields: {missing_keys}"
        )

    # Translate NEMO field names/values into their MCore argument equivalents.
    if "encoder_seq_length" in config:
        config["seq_length"] = config["encoder_seq_length"]
    if "bias" in config:
        config["disable_bias_linear"] = not config["bias"]
    if config.get("activation") == "swiglu":
        config["swiglu"] = True
    if config.get("position_embedding_type", False) is None:
        config["use_rotary_position_embeddings"] = config["no_position_embedding"] = True
    if "share_embeddings_and_output_weights" in config:
        config["untie_embeddings_and_output_weights"] = not config[
            "share_embeddings_and_output_weights"
        ]
    if "tokenizer" in config:
        config["tokenizer_type"] = config["tokenizer"]["type"]
        config["tokenizer_model"] = config["tokenizer"]["model"]
    if "masked_softmax_fusion" in config:
        config["no_masked_softmax_fusion"] = not config["masked_softmax_fusion"]
    if config.get("normalization") == "layernorm1p":
        config["apply_layernorm_1p"] = True
    if "precision" in config:
        # NOTE(review): assumes precision is an MCore flag name such as "bf16"/"fp16";
        # NEMO values like 16 or "bf16-mixed" would not map onto an MCore arg — confirm.
        config[config["precision"]] = True
    if "mcore_gpt" in config:
        config["use_mcore_models"] = config["mcore_gpt"]

    args_dict = vars(args).copy()
    args_dict.pop("kv_channels", None)  # not recalculated if present
    args_dict.update(config)

    return Namespace(**args_dict)


def _teacher_provider(config: Namespace, model_kwargs: Dict[str, Any]) -> MCoreGPTModel:
    """Teacher model factory (must be a non-local function to pickle).

    Builds the teacher model and loads its checkpoint via
    ``load_modelopt_checkpoint(..., load_arg='export_kd_teacher_load')``.

    Args:
        config: Namespace of teacher args. It is converted to a ``TransformerConfig``
            here rather than by the caller.
        model_kwargs: Keyword arguments forwarded to the model constructor.

    Returns:
        An ``MCoreMambaModel`` if the converted config is a hybrid model, otherwise an
        ``MCoreGPTModel``.
    """
    args = get_args()

    # Convert to `TransformerConfig` here to avoid ModelOpt pickling issues (contains local functions)
    transformer_config = core_transformer_config_from_args(config)

    if transformer_config.is_hybrid_model:
        teacher = MCoreMambaModel(config=transformer_config, **model_kwargs)
        model_name = "MCoreMambaModel"
    else:
        teacher = MCoreGPTModel(config=transformer_config, **model_kwargs)
        model_name = "MCoreGPTModel"
    _add_load_convert_hooks(teacher)

    print_rank_0("Loading teacher {} checkpoint...".format(model_name))
    # [WAR]: load checkpoint will check checkpoint's saved args and rng state if not finetune.
    # To avoid error out on loading teacher's checkpoint, we temporarily set args.finetune to
    # True while loading the teacher checkpoint.
    original_args_finetune, original_ckpt_format = args.finetune, args.ckpt_format
    args.finetune = True
    if args.export_kd_teacher_ckpt_format is not None:
        args.ckpt_format = args.export_kd_teacher_ckpt_format
    try:
        load_modelopt_checkpoint([teacher], load_arg='export_kd_teacher_load')
    finally:
        # Always restore the global args, even if loading fails, so the student's
        # settings are not left modified.
        args.finetune, args.ckpt_format = original_args_finetune, original_ckpt_format
    print_rank_0("successfully loaded teacher...")

    return teacher


def model_provider(pre_process=True, post_process=True, parallel_output=True) -> MCoreGPTModel:
    """Builds the model.

    The model class is selected by ``args.export_model_type``:

    - ``"GPTModel"`` builds an ``MCoreGPTModel``.
    - ``"MambaModel"``, or any other type when ``args.is_hybrid_model`` is set, builds an
      ``MCoreMambaModel``.

    Legacy (non-MCore) models are not supported. If ``args.load`` is set, the saved
    ModelOpt state is restored onto the model before it is returned. If
    ``args.export_kd_teacher_load`` is set, the model is converted with
    ``mtd.convert(..., mode=[("kd_loss", ...)])`` so that it carries a teacher built
    by ``_teacher_provider``.

    Args:
        pre_process (bool, optional): Set to True if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to True if you want to compute output logits/loss. Defaults to True.
        parallel_output (bool): Forwarded to the GPT model constructor. This must be
            True if `model_provider` is called in text_generation_server. It is not
            forwarded for Mamba models, which always receive True.

    Returns:
        The built model. Despite the annotation, this may be an ``MCoreMambaModel``.
        When distillation is enabled, it is the object returned by ``mtd.convert``.

    Raises:
        ValueError: If legacy models or a custom ``args.spec`` are requested, or if
            ``args.export_model_type`` is not supported.
    """
    args = get_args()

    print_rank_0("building GPT model ...")

    # ModelOpt by default assumes non-homogeneous layers. This affects the storage format
    # of the sharded checkpoint.
    config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        raise ValueError(
            "ModelOpt integration only supports MCore models. Do not pass --use-legacy-models."
        )
    if args.spec is not None:
        raise ValueError("ModelOpt integration does not support custom args.spec.")

    # Llama-4 Scout/Maverick support
    config.qk_l2_norm = args.export_qk_l2_norm
    config.moe_apply_probs_on_input = args.export_moe_apply_probs_on_input

    if args.export_model_type == "GPTModel":
        if args.export_offline_model:
            # Record the original num_layers. This is needed for _set_default_aux_hidden_state_layers
            config.original_num_layers = config.num_layers
            # Set num_layers to 0 for base model in offline mode
            config.num_layers = 0
            # SP is not used for offline
            # TODO: DSR1 MTP may require SP
            config.sequence_parallel = False
        if config.heterogeneous_block_specs:
            transformer_layer_spec = get_gpt_heterogeneous_layer_spec(
                config=config,
                use_te=args.transformer_impl == "transformer_engine",
            )
        else:
            local_core_attention = args.export_force_local_attention
            if config.context_parallel_size > 1:
                print_rank_0("context_parallel_size > 1! Force using TEDotProductAttention!")
                local_core_attention = False
                print_rank_0("context_parallel_size > 1! Force attention_mask_type to Causal. This can be wrong for EAGLE training!")
                use_arbitrary_attention_mask = False
            else:
                use_arbitrary_attention_mask = True

            transformer_layer_spec = get_gpt_modelopt_spec(
                config=config,
                local_core_attention=local_core_attention,
                remap_te_layernorm=args.export_te_mcore_model,
                real_quant_cfg=args.export_real_quant_cfg,
                use_arbitrary_attention_mask=use_arbitrary_attention_mask,
            )

        model_kwargs = {
            "transformer_layer_spec": transformer_layer_spec,
            "vocab_size": args.padded_vocab_size,
            "max_sequence_length": args.max_position_embeddings,
            "pre_process": pre_process,
            "post_process": post_process,
            "fp16_lm_cross_entropy": args.fp16_lm_cross_entropy,
            "parallel_output": parallel_output,
            "share_embeddings_and_output_weights": not args.untie_embeddings_and_output_weights,
            "position_embedding_type": args.position_embedding_type,
            "rotary_percent": args.rotary_percent,
            "rotary_base": args.rotary_base,
            "rope_scaling": args.use_rope_scaling,
        }
        model = MCoreGPTModel(config=config, **model_kwargs)
    elif args.export_model_type == "MambaModel" or args.is_hybrid_model:
        from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec

        mamba_stack_spec = get_mamba_stack_modelopt_spec(
            remap_te_layernorm=args.export_te_mcore_model
        )
        model_kwargs = {
            "mamba_stack_spec": mamba_stack_spec,
            "vocab_size": args.padded_vocab_size,
            "max_sequence_length": args.max_position_embeddings,
            "pre_process": pre_process,
            "hybrid_attention_ratio": args.hybrid_attention_ratio,
            "hybrid_mlp_ratio": args.hybrid_mlp_ratio,
            "hybrid_override_pattern": args.hybrid_override_pattern,
            "post_process": post_process,
            "fp16_lm_cross_entropy": args.fp16_lm_cross_entropy,
            # NOTE(review): the `parallel_output` argument is not forwarded here; confirm
            # that always using True for Mamba models is intentional.
            "parallel_output": True,
            "share_embeddings_and_output_weights": not args.untie_embeddings_and_output_weights,
            "position_embedding_type": args.position_embedding_type,
            "rotary_percent": args.rotary_percent,
            "rotary_base": args.rotary_base,
        }

        model = MCoreMambaModel(config=config, **model_kwargs)

        for layer_idx in range(model.decoder.num_layers_per_pipeline_rank):
            layer_params = count_parameters_in_layer(model, f'decoder.layers.{layer_idx}.')
            print_rank_0(f" == params layer {layer_idx}: {layer_params}")

    else:
        raise ValueError("ModelOpt does not support model type {}".format(args.export_model_type))

    # [IMPORTANT] Load modelopt_state immediately before returning the model back to `get_model()`.
    #
    # ModelOpt can create additional trainable parameters (e.g. for online speculative
    # decoding training or PEFT). Resuming modelopt_state during checkpoint loading is
    # therefore too late, because Megatron creates the optimizer right after calling
    # model_provider and before loading the checkpoint. To ensure all trainable parameters
    # are registered, we resume the modelopt_state (which transforms the model to have
    # additional parameters) before returning.
    if args.load is not None:
        load_modelopt_state(model=model)

    _add_load_convert_hooks(model)

    # Distillation mode.
    if args.export_kd_teacher_load:
        print_rank_0("Distillation: Enabled.")

        # NOTE: Unknown memory leak occurring per fwd-bwd pass if model
        # is converted to a `modelopt.torch.opt.DynamicModule`.
        # Argument `--manual-gc` can result in an eventual OOM.
        assert (
            not args.manual_gc
        ), "ModelOpt Distillation currently incompatible with `--manual-gc` option."
        assert (
            not args.tp_comm_overlap
        ), "ModelOpt Distillation currently incompatible with `--tp-comm-overlap` option."
        if args.pipeline_model_parallel_size > 1:
            assert (
                args.virtual_pipeline_model_parallel_size is None
            ), "ModelOpt Distillation currently incompatible with interleaved pipeline schedule."

        teacher_config = _load_teacher_model_config(args.export_kd_teacher_load)
        distill_cfg = distillation.load_distillation_config(
            args.export_kd_cfg,
            student_cfg=config,
            teacher_cfg=core_transformer_config_from_args(teacher_config),
        )

        # Hybrid teachers may use a different layer layout than the student. Apply the
        # teacher's values to a copy so the student's kwargs are not mutated.
        teacher_model_kwargs = dict(model_kwargs)
        if args.is_hybrid_model:
            for key in ("hybrid_override_pattern", "hybrid_attention_ratio", "hybrid_mlp_ratio"):
                if key in teacher_config:
                    teacher_model_kwargs[key] = getattr(teacher_config, key)

        kd_config = {
            "teacher_model": (_teacher_provider, [teacher_config, teacher_model_kwargs], {}),
            "criterion": distill_cfg["criterion"],
            "loss_balancer": distill_cfg["loss_balancer"],
        }
        model = mtd.convert(model, mode=[("kd_loss", kd_config)])

        # Additional tweaks needed for MCore/Nemo.
        # NOTE: Distillation state manually removed in this function.
        # ModelOpt state restoration above will not return a `mtd.DistillationModel` for simplicity reasons.
        distillation.adjust_distillation_model_for_mcore(model, distill_cfg)

    return model