# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import math

import torch

from megatron.core.dist_checkpointing import save, load, load_plain_tensors
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.dict_utils import diff
from megatron.core.dist_checkpointing.serialization import \
    get_default_save_sharded_strategy, get_default_load_sharded_strategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import \
    FullyParallelSaveStrategyWrapper, FullyParallelLoadStrategyWrapper
from megatron.core.dist_checkpointing.validation import StrictHandling

from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils


def common_test_simple_sharded_state_dict_save_load(initialize_model_fn, tmp_path_dist_ckpt,
                                                    src_layer_spec_fn, dst_layer_spec_fn):
    """ Simple save and load sanity check, without any equality tests. """
    tp = 2
    pp = 4
    Utils.initialize_model_parallel(tp, pp)
    gpt_model = initialize_model_fn(1, src_layer_spec_fn, tensor_model_parallel_size=tp,
                                    pipeline_model_parallel_size=pp)
    with TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model') as ckpt_dir:
        # Save
        sharded_state_dict = gpt_model.sharded_state_dict()
        save(sharded_state_dict, ckpt_dir)

        # Load
        gpt_model = initialize_model_fn(2, dst_layer_spec_fn, tensor_model_parallel_size=tp,
                                        pipeline_model_parallel_size=pp)
        sharded_state_dict = gpt_model.sharded_state_dict()
        state_dict, missing_keys, unexpected_keys = load(sharded_state_dict, ckpt_dir,
                                                         strict=StrictHandling.RETURN_ALL)
        # Potential mismatches are due to extra states, which is ok
        assert all('_extra_state' in k for k in missing_keys)
        assert all('_extra_state' in k for k in unexpected_keys)
        gpt_model.load_state_dict(state_dict)

    Utils.destroy_model_parallel()


def common_test_parallel_reconfiguration_e2e(initialize_model_fn, tmp_path_dist_ckpt, src_tp_pp,
                                             dest_tp_pp, src_layer_spec_fn, dst_layer_spec_fn,
                                             use_fpsl, load_order="tp-dp-pp", store_order="tp-dp-pp"):
    """ Test model saving and loading with different TP/PP """
    with TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_A') as ckpt_dir_A, \
         TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_B') as ckpt_dir_B:
        # Save checkpoint A
        Utils.initialize_model_parallel(*src_tp_pp, order=load_order)
        gpt_model_A = initialize_model_fn(1, src_layer_spec_fn,
                                          tensor_model_parallel_size=src_tp_pp[0],
                                          pipeline_model_parallel_size=src_tp_pp[1])
        save_strategy = get_default_save_sharded_strategy()
        if use_fpsl:
            save_strategy = FullyParallelSaveStrategyWrapper(
                save_strategy,
                parallel_state.get_data_parallel_group(with_context_parallel=True),
                True
            )
        save(gpt_model_A.sharded_state_dict(), ckpt_dir_A, save_strategy)
        regular_state_dict_A = gpt_model_A.state_dict()
        Utils.destroy_model_parallel()

        # Load checkpoint A with different TP/PP and save as checkpoint B
        # No fully parallel save (FPS) this time, only fully parallel load (FPL)
        Utils.initialize_model_parallel(*dest_tp_pp, order=store_order)
        gpt_model_B = initialize_model_fn(2, dst_layer_spec_fn,
                                          tensor_model_parallel_size=dest_tp_pp[0],
                                          pipeline_model_parallel_size=dest_tp_pp[1])
        if use_fpsl:
            load_strategy = get_default_load_sharded_strategy(ckpt_dir_A)
            load_strategy = FullyParallelLoadStrategyWrapper(load_strategy)
        else:
            load_strategy = None
        state_dict, missing_keys, unexpected_keys = load(gpt_model_B.sharded_state_dict(),
                                                         ckpt_dir_A, load_strategy,
                                                         strict=StrictHandling.RETURN_ALL)
        # Potential mismatches are due to extra states, which is ok
        assert all('_extra_state' in k for k in missing_keys)
        assert all('_extra_state' in k for k in unexpected_keys)
        gpt_model_B.load_state_dict(state_dict)
        save(gpt_model_B.sharded_state_dict(), ckpt_dir_B)
        regular_state_dict_B = gpt_model_A.state_dict()
        Utils.destroy_model_parallel()

        # Test both checkpoints are equal
        Utils.initialize_model_parallel(1, 1)
        plain_state_dict_A = load_plain_tensors(ckpt_dir_A)
        plain_state_dict_B = load_plain_tensors(ckpt_dir_B)
        diffs = diff(plain_state_dict_A, plain_state_dict_B)
        assert not any(map(bool, diffs)), diffs

        # Test both regular state dicts are equal, dropping the FP8 extra states first
        regular_state_dict_A = {k: v for k, v in regular_state_dict_A.items()
                                if not k.endswith('_extra_state')}
        regular_state_dict_B = {k: v for k, v in regular_state_dict_B.items()
                                if not k.endswith('_extra_state')}
        diffs = diff(regular_state_dict_A, regular_state_dict_B)
        assert not any(map(bool, diffs)), diffs

    Utils.destroy_model_parallel()


def common_test_state_dict_comparison(initialize_model_fn, tmp_path_dist_ckpt):
    tp = 2
    pp = 4
    Utils.initialize_model_parallel(tp, pp)
    with TempNamedDir(tmp_path_dist_ckpt / 'test_state_dict_comparison_A') as ckpt_dir_A, \
         TempNamedDir(tmp_path_dist_ckpt / 'test_state_dict_comparison_B') as ckpt_dir_B:
        gpt_model_A = initialize_model_fn(1, tensor_model_parallel_size=tp,
                                          pipeline_model_parallel_size=pp)
        save(gpt_model_A.sharded_state_dict(), ckpt_dir_A)
        gpt_model_B = initialize_model_fn(2, tensor_model_parallel_size=tp,
                                          pipeline_model_parallel_size=pp)
        save(gpt_model_B.sharded_state_dict(), ckpt_dir_B)

        state_dict_A = load_plain_tensors(ckpt_dir_A)
        state_dict_A_dup = load_plain_tensors(ckpt_dir_A)
        state_dict_B = load_plain_tensors(ckpt_dir_B)

        # Test that A matches A
        diffs = diff(state_dict_A, state_dict_A_dup)
        assert not any(map(bool, diffs)), diffs

        # Test that A *keys* match B *keys*, but the tensors content is different
        only_left, only_right, mismatch = diff(state_dict_A, state_dict_B)
        assert not only_left and not only_right, (only_left, only_right)
        assert len(mismatch) == len(state_dict_A), (len(mismatch), len(state_dict_A))

    Utils.destroy_model_parallel()


def common_test_vocab_size_padding_change(initialize_model_fn, tmp_path_dist_ckpt, vocab_size_base,
                                          src_tp_pp, dest_tp_pp):
    """ Test model loading with different vocab size (caused by TP padding). """
""" def get_test_vocab_size(make_divisible_by=128): divisor = make_divisible_by * parallel_state.get_tensor_model_parallel_world_size() return int(math.ceil(vocab_size_base / divisor)) * divisor vocab_size_dependent_keys = { 'output_layer.weight', 'output_layer.bias', 'embedding.word_embeddings.weight', } with TempNamedDir(tmp_path_dist_ckpt / 'test_vocab_size_padding_change_A') as ckpt_dir_A, \ TempNamedDir(tmp_path_dist_ckpt / 'test_vocab_size_padding_change_B') as ckpt_dir_B: # Save checkpoint A Utils.initialize_model_parallel(*src_tp_pp) gpt_model_A = initialize_model_fn(1, tensor_model_parallel_size=src_tp_pp[0], pipeline_model_parallel_size=src_tp_pp[1], vocab_size=get_test_vocab_size()) save(gpt_model_A.sharded_state_dict(), ckpt_dir_A) Utils.destroy_model_parallel() # Load checkpoint A with different TP/PP and save as checkpoint B Utils.initialize_model_parallel(*dest_tp_pp) gpt_model_B = initialize_model_fn(2, tensor_model_parallel_size=dest_tp_pp[0], pipeline_model_parallel_size=dest_tp_pp[1], vocab_size=get_test_vocab_size()) state_dict = load(gpt_model_B.sharded_state_dict(), ckpt_dir_A) gpt_model_B.load_state_dict(state_dict) save(gpt_model_B.sharded_state_dict(), ckpt_dir_B) Utils.destroy_model_parallel() # Test equality Utils.initialize_model_parallel(1, 1) plain_state_dict_A = load_plain_tensors(ckpt_dir_A) plain_state_dict_B = load_plain_tensors(ckpt_dir_B) # Test vocab size dependent keys are equal up to `vocab_size_base` for vocab_layer_key in vocab_size_dependent_keys: if vocab_layer_key in plain_state_dict_A: ten_A = plain_state_dict_A.pop(vocab_layer_key) ten_B = plain_state_dict_B.pop(vocab_layer_key) assert torch.all(ten_A[:vocab_size_base] == ten_B[:vocab_size_base]), vocab_layer_key # Test other tensors are equal diffs = diff(plain_state_dict_A, plain_state_dict_B) assert not any(map(bool, diffs)), diffs Utils.destroy_model_parallel()