Commit 0d99ae1f authored by silencealiang

add

parent c271aaae
Pipeline #2498 canceled
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -15,7 +15,10 @@ from megatron.core.dist_checkpointing.strategies.fully_parallel import (
     FullyParallelLoadStrategyWrapper,
     FullyParallelSaveStrategyWrapper,
 )
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from megatron.core.models.gpt.gpt_layer_specs import (
+    get_gpt_layer_local_spec,
+    get_gpt_layer_with_transformer_engine_spec,
+)
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -43,22 +46,25 @@ def initialize_expert_layer(seed, glu=True, expert_type='sequential', fp8=False,
     )
     default_config_kwargs.update(**config_kwargs)
     transformer_config = TransformerConfig(**default_config_kwargs)
-    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-        num_experts=num_moe_experts, moe_grouped_gemm=(expert_type != 'sequential'), fp8=fp8
-    )
     if expert_type == 'grouped':
         model = GroupedMLP(num_local_experts, transformer_config)
     elif expert_type == 'te_grouped':
+        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+            num_experts=num_moe_experts, moe_grouped_gemm=True
+        )
         model = TEGroupedMLP(
             num_local_experts,
             transformer_config,
-            transformer_layer_spec.submodules.mlp.submodules.experts,
+            transformer_layer_spec.submodules.mlp.submodules.experts.submodules,
         )
     elif expert_type == 'sequential':
+        transformer_layer_spec = get_gpt_layer_local_spec(
+            num_experts=num_moe_experts, moe_grouped_gemm=False
+        )
         model = SequentialMLP(
             num_local_experts,
             transformer_config,
-            transformer_layer_spec.submodules.mlp.submodules.experts,
+            transformer_layer_spec.submodules.mlp.submodules.experts.submodules,
         )
     else:
         raise ValueError('expert_type can only be one of ["sequential", "grouped", "te_grouped"]')
@@ -86,6 +92,7 @@ class TestExpertLayerReconfiguration:
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
+    @pytest.mark.internal
     @pytest.mark.parametrize(
         "use_fpsl,src_tp_pp_ep_etp,dest_tp_pp_ep_etp,use_glu",
         [
@@ -200,6 +207,7 @@ class TestExpertLayerReconfiguration:
         diffs = diff(state_dict_A, state_dict_B)
         assert not any(map(bool, diffs)), diffs
 
+    @pytest.mark.internal
     @pytest.mark.parametrize(
         "src_tp_pp_exp,dest_tp_pp_exp,use_glu",
         [
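The refactor above makes the test helper pick a layer spec per expert type instead of building one Transformer Engine spec up front: the 'te_grouped' branch keeps get_gpt_layer_with_transformer_engine_spec(..., moe_grouped_gemm=True), the 'sequential' branch switches to get_gpt_layer_local_spec(..., moe_grouped_gemm=False), and both now pass ...experts.submodules into the expert-MLP constructors. The sketch below mirrors that dispatch pattern in isolation; the ExpertSpec dataclass and select_expert_spec helper are illustrative stand-ins, not Megatron-Core APIs.

# Illustrative sketch of the per-expert-type spec dispatch used above.
# ExpertSpec and select_expert_spec are hypothetical stand-ins, not Megatron-Core APIs.
from dataclasses import dataclass


@dataclass
class ExpertSpec:
    module_name: str    # which expert MLP implementation to build
    grouped_gemm: bool  # whether experts are fused into grouped GEMMs


def select_expert_spec(expert_type: str) -> ExpertSpec:
    """Choose a spec per expert type, mirroring the branches in initialize_expert_layer."""
    if expert_type == 'grouped':
        # GroupedMLP is built directly from the transformer config; no layer spec needed.
        return ExpertSpec('GroupedMLP', grouped_gemm=True)
    if expert_type == 'te_grouped':
        # Transformer Engine path: grouped GEMM enabled.
        return ExpertSpec('TEGroupedMLP', grouped_gemm=True)
    if expert_type == 'sequential':
        # Local spec path: one MLP per expert, no grouped GEMM.
        return ExpertSpec('SequentialMLP', grouped_gemm=False)
    raise ValueError('expert_type can only be one of ["sequential", "grouped", "te_grouped"]')


if __name__ == '__main__':
    for kind in ('sequential', 'grouped', 'te_grouped'):
        print(kind, select_expert_spec(kind))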
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 import io
+from contextlib import nullcontext
 import numpy as np
 import pytest
@@ -18,6 +19,10 @@ from megatron.core.dist_checkpointing.strategies.resharding import (
     restore_nd_flattened_tensors_formulation,
 )
 from megatron.core.dist_checkpointing.strategies.torch import get_reformulation_metadata
+from megatron.core.dist_checkpointing.validation import (
+    determine_global_metadata,
+    validate_sharding_integrity,
+)
 
 from tests.unit_tests.dist_checkpointing import TempNamedDir
 from tests.unit_tests.test_utilities import Utils
@@ -198,3 +203,66 @@ class TestFlattenedResharding:
             ),
         }
         return state_dict
+
+    def test_flattened_tensors_are_properly_validated(self, tmp_path_dist_ckpt):
+        Utils.initialize_model_parallel()
+
+        # Global tensor of shape (6, 6) is built from:
+        # ranks 0, 1, 2 tensors of length 1, 2, 3
+        # and then ranks 3, ..., 7 tensors of length 6
+        local_flat_ten = torch.ones(Utils.rank + 1 if Utils.rank <= 2 else 6) * Utils.rank
+        global_flattened_len = 6 + (Utils.world_size - 3) * 6
+        if Utils.world_size == 8:
+            assert global_flattened_len == 1 + 2 + 3 + 5 * 6
+            local_ten_shape = (1, 6)
+        else:
+            local_ten_shape = (global_flattened_len,)
+
+        if Utils.rank == 0:
+            local_dp_slice_start = 0
+        elif Utils.rank == 1:
+            local_dp_slice_start = 1
+        elif Utils.rank == 2:
+            local_dp_slice_start = 3
+        else:
+            local_dp_slice_start = 0
+        local_dp_slice = slice(local_dp_slice_start, local_dp_slice_start + len(local_flat_ten))
+
+        state_dict = {
+            'sd_key_flat': ShardedTensor.from_rank_offsets_flat(
+                'flat',
+                local_flat_ten,
+                local_ten_shape,
+                *((0, max(0, Utils.rank - 2), 6),) if Utils.world_size == 8 else (),
+                flattened_range=local_dp_slice,
+                replica_id=0
+            )
+        }
+
+        validate_sharding_integrity(determine_global_metadata(state_dict)[1])
+
+        if Utils.rank == 1:
+            old_state_dict = state_dict
+            state_dict = {}
+        with (
+            pytest.raises(CheckpointingException) if Utils.rank == 0 else nullcontext()
+        ) as exc_info:
+            validate_sharding_integrity(determine_global_metadata(state_dict)[1])
+        if Utils.rank == 0:
+            assert 'Flattened ranges dont cover the whole shard ShardedTensor' in str(
+                exc_info.value
+            )
+        if Utils.rank == 1:
+            state_dict = old_state_dict
+
+        if Utils.rank == 4:
+            state_dict = {}
+        with (
+            pytest.raises(CheckpointingException) if Utils.rank == 0 else nullcontext()
+        ) as exc_info:
+            validate_sharding_integrity(determine_global_metadata(state_dict)[1])
+        if Utils.rank == 0:
+            assert 'Invalid access pattern' in str(exc_info.value)
+
+        Utils.destroy_model_parallel()
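The new test above constructs a (6, 6) global tensor whose first (1, 6) shard is covered by flattened slices of length 1, 2 and 3 on ranks 0-2, while ranks 3-7 each own a full (1, 6) shard (1 + 2 + 3 + 5 * 6 = 36 elements in total). It then checks that validate_sharding_integrity rejects two failure modes: dropping rank 1 leaves a gap in the flattened coverage of a shard, and dropping rank 4 leaves a whole shard with no access at all. Below is a conceptual sketch of the first invariant, assuming coverage means half-open ranges that tile the shard without gaps or overlaps; the helper is an illustration, not Megatron-Core's validate_sharding_integrity.

# Conceptual sketch of the invariant exercised by the new test above.
# Not Megatron-Core code: flattened_ranges_cover_shard is a hypothetical helper.
from typing import List, Tuple


def flattened_ranges_cover_shard(shard_numel: int, ranges: List[Tuple[int, int]]) -> bool:
    """True if the half-open (start, stop) ranges tile [0, shard_numel) with no gaps or overlaps."""
    expected_start = 0
    for start, stop in sorted(ranges):
        if start != expected_start or stop <= start:
            return False  # gap, overlap, or empty range
        expected_start = stop
    return expected_start == shard_numel


# Mirrors the test setup: ranks 0, 1 and 2 split the first (1, 6) shard into
# flattened pieces of length 1, 2 and 3; dropping rank 1's piece leaves a gap.
assert flattened_ranges_cover_shard(6, [(0, 1), (1, 3), (3, 6)])
assert not flattened_ranges_cover_shard(6, [(0, 1), (3, 6)])  # rank 1 missing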
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644