Commit 99a0c39e authored by xingjinliang's avatar xingjinliang
Browse files

同步最新代码

parent 50fe58fa
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -5,7 +5,6 @@ from unittest import mock
import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
......@@ -14,8 +13,9 @@ from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
SimpleTextGenerationController,
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_model import GPTModel
......@@ -60,7 +60,7 @@ class TestMCoreEngine:
inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_wrapper_config)
self.mock_tokenizer = mock.Mock()
text_generation_controller = SimpleTextGenerationController(
text_generation_controller = TextGenerationController(
inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer
)
......@@ -85,7 +85,7 @@ class TestMCoreEngine:
prompts = ["sample" * (i + 1) for i in range(self.batch_size)]
results: List[InferenceRequest] = self.mcore_engine.generate(
prompts, common_inference_params=CommonInferenceParams(num_tokens_to_generate=10)
prompts, sampling_params=SamplingParams(num_tokens_to_generate=10)
)
for result in results:
......@@ -110,9 +110,7 @@ class TestMCoreEngine:
prompts = ["" for i in range(self.batch_size)]
results: List[InferenceRequest] = self.mcore_engine.generate(
prompts,
add_BOS=True,
common_inference_params=CommonInferenceParams(num_tokens_to_generate=10),
prompts, add_BOS=True, sampling_params=SamplingParams(num_tokens_to_generate=10)
)
for result in results:
......
File mode changed from 100755 to 100644
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.sampling_params import SamplingParams
class TestCommonInferenceParams:
class TestSamplingParams:
def test_inference_params(self):
inference_parameters = CommonInferenceParams()
inference_parameters = SamplingParams()
inference_parameters.add_attributes({"min_tokens": 45})
assert (
inference_parameters.min_tokens == 45
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment