# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Tests for saving and loading N-D flattened ShardedTensors across different TPxPP partitionings."""

from contextlib import nullcontext

import pytest
import torch

from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor, load, save
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import diff
from megatron.core.dist_checkpointing.serialization import load_tensors_metadata
from megatron.core.dist_checkpointing.strategies.resharding import (
    apply_nd_flattened_tensors_reformulation,
    restore_nd_flattened_tensors_formulation,
)
from megatron.core.dist_checkpointing.strategies.torch import get_reformulation_metadata
from megatron.core.dist_checkpointing.validation import (
    determine_global_metadata,
    validate_sharding_integrity,
)

from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils


class TestFlattenedResharding:
    def setup_method(self, method):
        pass

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    @pytest.mark.parametrize(
        ('src_tp_pp', 'dest_tp_pp'),
        [((2, 4), (2, 4)), ((2, 4), (2, 2)), ((2, 4), (4, 2)), ((8, 1), (1, 2))],
    )
    def test_partition_change_save_load(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp):
        Utils.initialize_model_parallel(*src_tp_pp)
        with TempNamedDir(
            tmp_path_dist_ckpt / 'test_flattened_partition_change_save_load'
        ) as ckpt_dir:
            state_dict = self._build_state_dict()

            save(state_dict, ckpt_dir)

            # change TPxPP
            Utils.destroy_model_parallel()
            Utils.initialize_model_parallel(*dest_tp_pp)

            loaded_state_dict = load(self._build_state_dict(random=True), ckpt_dir)
            expected_state_dict = {k: v.data for k, v in self._build_state_dict().items()}
            diffs = diff(expected_state_dict, loaded_state_dict)
            assert not any(diffs), diffs

        Utils.destroy_model_parallel()

    @pytest.mark.parametrize(
        ('src_tp_pp', 'dest_tp_pp', 'expected_ckpt_offsets_by_rank'),
        [
            (
                (2, 4),
                (2, 2),
                {
                    0: [(0, 0, 0), (0, 0, 10)],  # TP 0, DP 0, PP 0
                    1: [(4, 0, 0), (4, 0, 10)],  # TP 1, DP 0, PP 0
                    2: [(0, 0, 0), (0, 0, 10)],  # TP 0, DP 1, PP 0
                    3: [(4, 0, 0), (4, 0, 10)],  # TP 1, DP 1, PP 0
                    4: [(0, 0, 20), (0, 0, 30)],  # TP 0, DP 0, PP 1
                    5: [(4, 0, 20), (4, 0, 30)],  # TP 1, DP 0, PP 1
                    6: [(0, 0, 20), (0, 0, 30)],  # TP 0, DP 1, PP 1
                    7: [(4, 0, 20), (4, 0, 30)],  # TP 1, DP 1, PP 1
                },
            ),
            ((8, 1), (1, 2), {rank: [(tp, 0, 0) for tp in range(8)] for rank in range(8)}),
        ],
    )
    def test_reformulate_nd_flattened_tensors(
        self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp, expected_ckpt_offsets_by_rank
    ):
        Utils.initialize_model_parallel(*src_tp_pp, order='tp-dp-pp')
        with TempNamedDir(
            tmp_path_dist_ckpt / 'test_reformulate_nd_flattened_tensors'
        ) as ckpt_dir:
            state_dict = self._build_state_dict()
            ckpt_local_shape = state_dict['sd_key_flat'].local_shape

            save(state_dict, ckpt_dir)

            # change TPxPP
            Utils.destroy_model_parallel()
            Utils.initialize_model_parallel(*dest_tp_pp, order='tp-dp-pp')
            load_state_dict = self._build_state_dict(random=True)

            reformulation_metadata = get_reformulation_metadata(load_state_dict, ckpt_dir)
            reformulated_state_dict, formulation_restore_data = (
                apply_nd_flattened_tensors_reformulation(load_state_dict, reformulation_metadata)
            )
            assert isinstance(reformulated_state_dict['sd_key_unflat'], ShardedTensor)
            assert isinstance(reformulated_state_dict['sd_key_flat'], dict)
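            # After reformulation, the flattened tensor is represented as a dict mapping
            # (checkpoint offset, checkpoint local shape) pairs to ShardedTensors that
            # match the checkpoint formulation, so the keys enumerate exactly the
            # checkpoint chunks this rank has to read.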
            assert reformulated_state_dict['sd_key_flat'].keys() == set(
                (offset, ckpt_local_shape) for offset in expected_ckpt_offsets_by_rank[Utils.rank]
            ), (
                reformulated_state_dict['sd_key_flat'].keys(),
                ckpt_local_shape,
                expected_ckpt_offsets_by_rank[Utils.rank],
            )

            # We can even load the reformulated state dict with a high-level API
            loaded_state_dict = load(
                reformulated_state_dict, ckpt_dir, validate_access_integrity=False
            )
            loaded_state_dict = restore_nd_flattened_tensors_formulation(
                loaded_state_dict, formulation_restore_data
            )
            expected_state_dict = {k: v.data for k, v in self._build_state_dict().items()}
            diffs = diff(expected_state_dict, loaded_state_dict)
            assert not any(diffs), diffs

        Utils.destroy_model_parallel()

    @pytest.mark.parametrize(('src_tp_pp',), [((2, 4),), ((8, 1),), ((1, 1),), ((1, 4),)])
    def test_load_tensor_metadata(self, tmp_path_dist_ckpt, src_tp_pp):
        Utils.initialize_model_parallel(*src_tp_pp, order='tp-dp-pp')
        with TempNamedDir(tmp_path_dist_ckpt / 'test_load_tensor_metadata') as ckpt_dir:
            state_dict = self._build_state_dict()

            save(state_dict, ckpt_dir)

            # change TPxPP
            Utils.destroy_model_parallel()
            Utils.initialize_model_parallel(1, 1)

            sharded_metadata = load_tensors_metadata(ckpt_dir)

            # Metadata of a flattened tensor must describe the same (unflattened)
            # shapes as its non-flattened counterpart.
            for attr_name in ('local_shape', 'global_shape'):
                flat_val = getattr(sharded_metadata['flat'], attr_name)
                unflat_val = getattr(sharded_metadata['unflat'], attr_name)
                assert flat_val == unflat_val, (attr_name, flat_val, unflat_val)

            for sh_ten in sharded_metadata.values():
                sh_ten.replica_id = Utils.rank

            loaded_state_dict = load(sharded_metadata, ckpt_dir)
            assert torch.all(
                loaded_state_dict['unflat'] == torch.arange(8 * 5 * 40).reshape(8, 5, 40)
            )
            assert torch.all(loaded_state_dict['flat'] == torch.arange(8 * 5 * 40))

        Utils.destroy_model_parallel()

    def _build_state_dict(self, random=False):
        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
        pp_size = parallel_state.get_pipeline_model_parallel_world_size()
        dp_rank = parallel_state.get_data_parallel_rank()
        dp_size = parallel_state.get_data_parallel_world_size()

        init_fn = torch.rand if random else torch.arange
        global_ten = init_fn(8 * 5 * 40).reshape(8, 5, 40)
        local_ten = global_ten
        local_ten = local_ten.chunk(tp_size, dim=0)[tp_rank]
        local_ten = local_ten.chunk(pp_size, dim=2)[pp_rank]
        assert local_ten.shape == (8 // tp_size, 5, 40 // pp_size)

        local_ten_size_by_dp = local_ten.numel()
        assert local_ten_size_by_dp % dp_size == 0, (local_ten_size_by_dp, dp_size)
        local_ten_size_by_dp = local_ten_size_by_dp // dp_size
        # Shift the DP slices a bit so that they are not all equal
        start_jitter = dp_rank
        end_jitter = dp_rank + 1 if dp_rank + 1 < dp_size else 0
        local_dp_slice = slice(
            local_ten_size_by_dp * dp_rank + start_jitter,
            local_ten_size_by_dp * (dp_rank + 1) + end_jitter,
        )
        local_flat_ten = local_ten.flatten()[local_dp_slice]
        if dp_rank == dp_size - 1:
            assert local_flat_ten.numel() == local_ten_size_by_dp - dp_rank
        else:
            assert local_flat_ten.numel() == local_ten_size_by_dp + 1

        state_dict = {
            'sd_key_unflat': ShardedTensor.from_rank_offsets(
                'unflat',
                local_ten,
                (0, tp_rank, tp_size),
                (2, pp_rank, pp_size),
                replica_id=dp_rank,
            ),
            'sd_key_flat': ShardedTensor.from_rank_offsets_flat(
                'flat',
                local_flat_ten,
                local_ten.shape,
                (0, tp_rank, tp_size),
                (2, pp_rank, pp_size),
                flattened_range=local_dp_slice,
            ),
        }
        return state_dict

    def test_flattened_tensors_are_properly_validated(self, tmp_path_dist_ckpt):
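        """Check that sharding validation catches flattened ranges that don't cover
        their shard, as well as shards missing from the access pattern altogether."""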
        Utils.initialize_model_parallel()
        # The global tensor of shape (6, 6) is built from:
        # - ranks 0, 1, 2 holding tensors of length 1, 2 and 3 respectively,
        # - ranks 3, ..., 7 each holding a tensor of length 6.
        local_flat_ten = torch.ones(Utils.rank + 1 if Utils.rank <= 2 else 6) * Utils.rank
        global_flattened_len = 6 + (Utils.world_size - 3) * 6
        if Utils.world_size == 8:
            assert global_flattened_len == 1 + 2 + 3 + 5 * 6
            local_ten_shape = (1, 6)
        else:
            local_ten_shape = (global_flattened_len,)

        # Ranks 0, 1, 2 share the first (1, 6) shard, so their flattened
        # slices start at consecutive offsets within it.
        if Utils.rank == 0:
            local_dp_slice_start = 0
        elif Utils.rank == 1:
            local_dp_slice_start = 1
        elif Utils.rank == 2:
            local_dp_slice_start = 3
        else:
            local_dp_slice_start = 0
        local_dp_slice = slice(local_dp_slice_start, local_dp_slice_start + len(local_flat_ten))

        state_dict = {
            'sd_key_flat': ShardedTensor.from_rank_offsets_flat(
                'flat',
                local_flat_ten,
                local_ten_shape,
                *((0, max(0, Utils.rank - 2), 6),) if Utils.world_size == 8 else (),
                flattened_range=local_dp_slice,
                replica_id=0,
            )
        }
        validate_sharding_integrity(determine_global_metadata(state_dict)[1])

        # Simulate rank 1 dropping its slice: the flattened ranges of the first
        # shard no longer cover it entirely.
        if Utils.rank == 1:
            old_state_dict = state_dict
            state_dict = {}

        with (
            pytest.raises(CheckpointingException) if Utils.rank == 0 else nullcontext()
        ) as exc_info:
            validate_sharding_integrity(determine_global_metadata(state_dict)[1])
        if Utils.rank == 0:
            assert 'Flattened ranges dont cover the whole shard ShardedTensor' in str(
                exc_info.value
            )
        if Utils.rank == 1:
            state_dict = old_state_dict

        # Simulate rank 4 dropping its whole shard: the global tensor is no
        # longer fully covered.
        if Utils.rank == 4:
            state_dict = {}

        with (
            pytest.raises(CheckpointingException) if Utils.rank == 0 else nullcontext()
        ) as exc_info:
            validate_sharding_integrity(determine_global_metadata(state_dict)[1])
        if Utils.rank == 0:
            assert 'Invalid access pattern' in str(exc_info.value)

        Utils.destroy_model_parallel()