Commit 0d99ae1f authored by silencealiang's avatar silencealiang
Browse files

add

parent c271aaae
Pipeline #2498 canceled with stages
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -568,6 +568,39 @@ class TestSerialization:
Utils.destroy_model_parallel()
def test_empty_load(self, tmp_path_dist_ckpt):
    """Ranks with no sharded tensors in their state dict still load the common state.

    Ranks 1 and 2 pass only plain (non-sharded) objects, which are not saved at
    all (the 'common' part is saved by rank 0 only), so after loading they must
    see only the common state. Ranks >= 3 additionally recover the sharded
    tensor 'x'.
    """
    Utils.initialize_model_parallel(2, 4)
    if Utils.rank == 0:
        state_dict = {'common': 'common-value'}
    elif Utils.rank == 1:
        state_dict = {'a': 3}  # this is not saved at all (common saved by rank 0 only)
    elif Utils.rank == 2:
        state_dict = {'b': 3}  # this is not saved at all (common saved by rank 0 only)
    else:
        state_dict = {
            'a': ShardedTensor.from_rank_offsets(
                'x', torch.ones((2,)) * Utils.rank, replica_id=Utils.rank - 3
            )
        }
    with TempNamedDir(tmp_path_dist_ckpt / 'test_empty_load', sync=True) as ckpt_dir:
        save(state_dict, ckpt_dir)
        torch.distributed.barrier()
        loaded_state_dict = load(state_dict, ckpt_dir)
        assert loaded_state_dict['common'] == 'common-value'
        if Utils.rank <= 2:
            assert loaded_state_dict.keys() == {'common'}
        else:
            assert loaded_state_dict.keys() == {'common', 'a'}
            # BUG FIX: this comparison previously lacked `assert`, making it a
            # no-op — the loaded tensor values were never actually checked.
            assert loaded_state_dict['a'].cpu().numpy().tolist() == [
                3,
                3,
            ]  # rank 3 held the main replica so did the saving
    Utils.destroy_model_parallel()
class TestNonStrictLoad:
def setup_method(self, method):
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -5,7 +5,6 @@ from unittest import mock
import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
......@@ -14,8 +13,9 @@ from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
SimpleTextGenerationController,
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_model import GPTModel
......@@ -60,7 +60,7 @@ class TestMCoreEngine:
inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_wrapper_config)
self.mock_tokenizer = mock.Mock()
text_generation_controller = SimpleTextGenerationController(
text_generation_controller = TextGenerationController(
inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer
)
......@@ -85,7 +85,7 @@ class TestMCoreEngine:
prompts = ["sample" * (i + 1) for i in range(self.batch_size)]
results: List[InferenceRequest] = self.mcore_engine.generate(
prompts, common_inference_params=CommonInferenceParams(num_tokens_to_generate=10)
prompts, sampling_params=SamplingParams(num_tokens_to_generate=10)
)
for result in results:
......@@ -110,9 +110,7 @@ class TestMCoreEngine:
prompts = ["" for i in range(self.batch_size)]
results: List[InferenceRequest] = self.mcore_engine.generate(
prompts,
add_BOS=True,
common_inference_params=CommonInferenceParams(num_tokens_to_generate=10),
prompts, add_BOS=True, sampling_params=SamplingParams(num_tokens_to_generate=10)
)
for result in results:
......
File mode changed from 100755 to 100644
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment