from functools import partial
from types import SimpleNamespace
from unittest import mock

import torch

from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed
from megatron.core.transformer import TransformerConfig
from megatron.training.training import get_model
from megatron.training.utils import unwrap_model

NUM_LAYERS = 8
HIDDEN_SIZE = 16
NUM_ATTENTION_HEADS = 8


def initialize_gpt_model(
    pre_process=True, post_process=True, seed=0, use_glu=True, **config_kwargs
):
    """Build a small GPT model in bf16 and fill its parameters with random values."""
    torch.manual_seed(seed)
    model_parallel_cuda_manual_seed(seed)

    default_config_kwargs = dict(
        num_layers=NUM_LAYERS,
        hidden_size=HIDDEN_SIZE,
        num_attention_heads=NUM_ATTENTION_HEADS,
        use_cpu_initialization=True,
    )
    default_config_kwargs.update(**config_kwargs)
    transformer_config = TransformerConfig(**default_config_kwargs, gated_linear_unit=use_glu)
    model = GPTModel(
        config=transformer_config,
        transformer_layer_spec=get_gpt_layer_local_spec(),
        vocab_size=128,
        max_sequence_length=4,
        pre_process=pre_process,
        post_process=post_process,
    )

    model.bfloat16()
    with torch.no_grad():
        for p in model.parameters():
            p.random_()
    return model


def initialize_moe_model(
    pre_process=True,
    post_process=True,
    seed=0,
    use_glu=True,
    use_sp=False,
    use_te=False,
    use_grouped_mlp=False,
    **config_kwargs
):
    """Build a small MoE GPT model (8 experts) in bf16 and fill its parameters with random values."""
    torch.manual_seed(seed)
    model_parallel_cuda_manual_seed(seed)

    expert_num = 8
    default_config_kwargs = dict(
        num_layers=8,
        hidden_size=16,
        num_attention_heads=8,
        use_cpu_initialization=True,
        num_moe_experts=expert_num,
        sequence_parallel=use_sp,
        moe_grouped_gemm=use_grouped_mlp,
        add_bias_linear=False,
    )
    default_config_kwargs.update(**config_kwargs)
    transformer_config = TransformerConfig(**default_config_kwargs, gated_linear_unit=use_glu)
    if use_te:
        spec = get_gpt_layer_with_transformer_engine_spec(
            num_experts=expert_num, moe_grouped_gemm=use_grouped_mlp
        )
    else:
        spec = get_gpt_layer_local_spec(num_experts=expert_num, moe_grouped_gemm=use_grouped_mlp)
    model = GPTModel(
        config=transformer_config,
        transformer_layer_spec=spec,
        vocab_size=128,
        max_sequence_length=4,
        pre_process=pre_process,
        post_process=post_process,
    )

    model.bfloat16()
    with torch.no_grad():
        for p in model.parameters():
            p.random_()
    return model


def init_basic_mock_args(args, tp, pp, bf16=True):
    """Populate a mock args namespace with the flags that model/DDP setup reads."""
    args.data_parallel_random_init = False
    args.virtual_pipeline_model_parallel_size = None
    args.fp16 = False
    args.bf16 = bf16
    args.accumulate_allreduce_grads_in_fp32 = False
    args.overlap_grad_reduce = False
    args.overlap_param_gather_with_optimizer_step = False
    args.fp8_param_gather = False
    args.use_distributed_optimizer = True
    args.ddp_bucket_size = None
    args.check_for_nan_in_loss_and_grad = False
    args.ddp_average_in_collective = False
    args.tensor_model_parallel_size = tp
    args.pipeline_model_parallel_size = pp
    args.encoder_tensor_model_parallel_size = 0
    args.encoder_pipeline_model_parallel_size = 0
    args.enable_ft_package = False
    args.use_torch_fsdp2 = False
    return args


def init_checkpointing_mock_args(args, ckpt_dir, fully_parallel=False):
    """Populate a mock args namespace with the flags that checkpoint save/load reads."""
    args.non_persistent_global_ckpt_dir = None
    args.non_persistent_ckpt_type = None
    args.save = ckpt_dir
    args.load = ckpt_dir
    args.pretrained_checkpoint = None
    args.ckpt_fully_parallel_save = fully_parallel
    args.ckpt_fully_parallel_load = fully_parallel
    args.async_save = False
    args.use_dist_ckpt = True
    args.ckpt_format = 'torch_dist'
    args.no_save_optim = False
    args.no_save_rng = False
    args.ckpt_assume_constant_structure = False
    args.log_progress = False
    args.auto_detect_ckpt_format = False
    args.exit_on_missing_checkpoint = False
    args.finetune = False
    args.consumed_train_samples = 0
    args.skipped_train_samples = 0
    args.consumed_valid_samples = 0
    args.retro_add_retriever = False
    args.no_load_optim = False
    args.no_load_rng = False
    args.dist_ckpt_strictness = 'assume_ok_unexpected'
    args.add_position_embedding = True
    args.vocab_file = False
    args.num_layers = NUM_LAYERS
    args.hidden_size = HIDDEN_SIZE
    args.num_attention_heads = NUM_ATTENTION_HEADS


def setup_model_and_optimizer(
    seed, tp, pp, initialize_fn=initialize_gpt_model, bf16=True, dist_opt=True
):
    """Build a GPT model and Megatron optimizer, then fill the Adam state with random values."""
    mock_args = SimpleNamespace()
    with mock.patch('megatron.training.training.get_args', new=lambda: mock_args):
        init_basic_mock_args(mock_args, tp, pp, bf16=bf16)
        model = get_model(
            partial(
                initialize_fn,
                seed=seed,
                tensor_model_parallel_size=tp,
                pipeline_model_parallel_size=pp,
                pipeline_dtype=torch.bfloat16,
            )
        )

    config = OptimizerConfig(
        bf16=bf16,
        params_dtype=torch.bfloat16 if bf16 else torch.float,
        use_distributed_optimizer=dist_opt,
    )
    optimizer = get_megatron_optimizer(config, model)

    torch.manual_seed(seed + 1)
    model_parallel_cuda_manual_seed(seed + 1)

    # Initialize the Adam state (exp_avg / exp_avg_sq) with random tensors so that
    # the optimizer state is non-trivial for save/load comparisons.
    for group in optimizer.optimizer.param_groups:
        for p in group['params']:
            if len(optimizer.optimizer.state[p]) == 0:
                optimizer.optimizer.state[p]['exp_avg'] = torch.rand_like(p.data)
                optimizer.optimizer.state[p]['exp_avg_sq'] = torch.rand_like(p.data)

    optimizer.reload_model_params()

    return unwrap_model(model), optimizer


def setup_moe_model_and_optimizer(
    seed,
    tp,
    pp,
    ep,
    initialize_fn=initialize_moe_model,
    bf16=True,
    dist_opt=True,
    use_te=False,
    use_grouped_mlp=False,
    use_glu=False,
):
    """Build an MoE model and chained Megatron optimizer, then fill the Adam state with random values."""
    mock_args = SimpleNamespace()
    with mock.patch('megatron.training.training.get_args', new=lambda: mock_args):
        init_basic_mock_args(mock_args, tp, pp, bf16=bf16)
        model = get_model(
            partial(
                initialize_fn,
                seed=seed,
                tensor_model_parallel_size=tp,
                pipeline_model_parallel_size=pp,
                pipeline_dtype=torch.bfloat16,
                expert_model_parallel_size=ep,
                use_sp=(tp > 1 and ep > 1),
                use_te=use_te,
                use_grouped_mlp=use_grouped_mlp,
                use_glu=use_glu,
            )
        )

    config = OptimizerConfig(
        bf16=bf16,
        params_dtype=torch.bfloat16 if bf16 else torch.float,
        use_distributed_optimizer=dist_opt,
    )
    optimizer = get_megatron_optimizer(config, model)

    torch.manual_seed(seed + 1)
    model_parallel_cuda_manual_seed(seed + 1)

    # The MoE optimizer is chained (dense + expert params), so walk each sub-optimizer.
    for opt in optimizer.chained_optimizers:
        for group in opt.param_groups:
            for p in group['params']:
                if len(opt.state[p]) == 0:
                    opt.state[p]['exp_avg'] = torch.rand_like(p.data)
                    opt.state[p]['exp_avg_sq'] = torch.rand_like(p.data)

    optimizer.reload_model_params()

    return unwrap_model(model), optimizer
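

# Minimal usage sketch (an assumption, not part of this module): these helpers expect
# torch.distributed and Megatron's model-parallel state to already be initialized by the
# surrounding test harness (e.g. megatron.core.parallel_state.initialize_model_parallel),
# with a world size compatible with tp * pp (and ep for the MoE variant). The tp/pp/ep
# values below are illustrative only.
#
#     model, optimizer = setup_model_and_optimizer(seed=2, tp=2, pp=2)
#     moe_model, moe_optimizer = setup_moe_model_and_optimizer(
#         seed=2, tp=1, pp=1, ep=2, use_glu=True
#     )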