Commit 0d99ae1f authored by silencealiang

add

parent c271aaae
Pipeline #2498 canceled with stages
-from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.sampling_params import SamplingParams
-class TestCommonInferenceParams:
+class TestSamplingParams:
     def test_inference_params(self):
-        inference_parameters = CommonInferenceParams()
+        inference_parameters = SamplingParams()
         inference_parameters.add_attributes({"min_tokens": 45})
         assert (
             inference_parameters.min_tokens == 45
......
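For context, a minimal sketch (not part of the diff) of the renamed parameter class as the updated test uses it; the class, method, and field names are taken from the hunk above, while default construction is assumed to behave as the old CommonInferenceParams did:

    # Sketch only: SamplingParams replaces CommonInferenceParams in this commit.
    from megatron.core.inference.sampling_params import SamplingParams

    params = SamplingParams()                   # default-constructed, as in the test
    params.add_attributes({"min_tokens": 45})   # add_attributes() is exercised by the test above
    assert params.min_tokens == 45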
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -2,8 +2,8 @@ from typing import Dict
 import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.inference_request import InferenceRequest, Status
+from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.inference.scheduler import Scheduler
@@ -25,7 +25,7 @@ class TestScheduler:
     def test_scheduler(self):
         prompt = "sample prompt"
         prompt_tokens = torch.randn(5)
-        inference_parameters = CommonInferenceParams()
+        inference_parameters = SamplingParams()
         for i in range(self.max_batch_size):
             self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
......
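A minimal usage sketch of the renamed call site in the scheduler test; the Scheduler constructor argument below is an assumption, while the add_request(prompt, prompt_tokens, inference_parameters) call mirrors the hunk above:

    import torch
    from megatron.core.inference.sampling_params import SamplingParams
    from megatron.core.inference.scheduler import Scheduler

    scheduler = Scheduler(max_batch_size=4)      # assumed constructor signature
    request_params = SamplingParams()            # replaces CommonInferenceParams()
    # enqueue one request, as the test does inside its loop
    scheduler.add_request("sample prompt", torch.randn(5), request_params)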
@@ -10,7 +10,6 @@ import numpy as np
 import pytest
 import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.inference_request import InferenceRequest, Status
 from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
     InferenceWrapperConfig,
@@ -18,6 +17,7 @@ from megatron.core.inference.model_inference_wrappers.inference_wrapper_config i
 from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import (
     T5InferenceWrapper,
 )
+from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
     EncoderDecoderTextGenerationController,
 )
@@ -126,7 +126,7 @@ class TestEncoderDecoderTextGenerationController:
                 request_id=i,
                 prompt=prompt,
                 encoder_prompt=encoder_prompt,
-                inference_parameters=CommonInferenceParams(num_tokens_to_generate=10),
+                inference_parameters=SamplingParams(num_tokens_to_generate=10),
                 arrival_time=time.time(),
                 prompt_tokens=prompt_tokens,
                 status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS,
......
@@ -9,7 +9,6 @@ from unittest import mock
 import pytest
 import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.inference_request import InferenceRequest, Status
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
@@ -17,8 +16,9 @@ from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper
 from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
     InferenceWrapperConfig,
 )
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
-    SimpleTextGenerationController,
+from megatron.core.inference.sampling_params import SamplingParams
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+    TextGenerationController,
 )
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
 from megatron.core.models.gpt.gpt_model import GPTModel
@@ -28,7 +28,7 @@ from megatron.core.transformer.transformer_config import TransformerConfig
 from tests.unit_tests.test_utilities import Utils
-class TestSimpleTextGenerationController:
+class TestTextGenerationController:
     def setup_method(self, method):
         Utils.initialize_model_parallel(
@@ -67,7 +67,7 @@ class TestSimpleTextGenerationController:
         self.mock_tokenizer = mock.Mock()
-        self.text_generation_controller = SimpleTextGenerationController(
+        self.text_generation_controller = TextGenerationController(
             inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer
         )
@@ -78,7 +78,7 @@ class TestSimpleTextGenerationController:
         with pytest.raises(AssertionError) as aerror:
             self.text_generation_controller.sample_from_logits(
                 last_token_logits=None,
-                common_inference_params=CommonInferenceParams(top_k=2, top_p=0.4),
+                sampling_params=SamplingParams(top_k=2, top_p=0.4),
                 vocab_size=self.vocab_size,
             )
         assert str(aerror.value) == 'Cannot have top-p and top-k both greater than zero'
@@ -86,7 +86,7 @@ class TestSimpleTextGenerationController:
        with pytest.raises(AssertionError) as aerror:
             self.text_generation_controller.sample_from_logits(
                 last_token_logits=None,
-                common_inference_params=CommonInferenceParams(top_p=1.4, top_k=0),
+                sampling_params=SamplingParams(top_p=1.4, top_k=0),
                 vocab_size=self.vocab_size,
             )
         assert str(aerror.value) == 'top-p should be in (0,1]'
@@ -94,7 +94,7 @@ class TestSimpleTextGenerationController:
         with pytest.raises(AssertionError) as aerror:
             self.text_generation_controller.sample_from_logits(
                 last_token_logits=torch.randn(self.batch_size, 1),
-                common_inference_params=CommonInferenceParams(top_k=self.vocab_size + 10),
+                sampling_params=SamplingParams(top_k=self.vocab_size + 10),
                 vocab_size=self.vocab_size,
             )
         assert str(aerror.value) == 'top-k is larger than logit size.'
@@ -103,14 +103,14 @@ class TestSimpleTextGenerationController:
             torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda()
         )
         sampled_logits = self.text_generation_controller.sample_from_logits(
-            last_token_logits, CommonInferenceParams(top_k=1), self.vocab_size
+            last_token_logits, SamplingParams(top_k=1), self.vocab_size
         )
         assert torch.all(
             sampled_logits.cpu() == torch.ones(self.batch_size) * self.vocab_size - 1
         ), f"The sampled logits should all be {self.vocab_size} but its {sampled_logits}"
         sampled_logits = self.text_generation_controller.sample_from_logits(
-            last_token_logits, CommonInferenceParams(top_k=2), self.vocab_size
+            last_token_logits, SamplingParams(top_k=2), self.vocab_size
        )
         assert torch.all(
             sampled_logits >= self.vocab_size - 2
@@ -120,7 +120,7 @@ class TestSimpleTextGenerationController:
         top_p = 0.3
         expected_min_value = l[l.softmax(dim=-1).cumsum(dim=-1) > top_p][0].item()
         sampled_logits = self.text_generation_controller.sample_from_logits(
-            last_token_logits, CommonInferenceParams(top_p=top_p, top_k=0), self.vocab_size
+            last_token_logits, SamplingParams(top_p=top_p, top_k=0), self.vocab_size
         )
         assert torch.all(
             sampled_logits >= expected_min_value
@@ -131,7 +131,7 @@ class TestSimpleTextGenerationController:
         expected_min_value = l[l.div_(temperature).softmax(dim=-1).cumsum(dim=-1) > top_p][0].item()
         sampled_logits = self.text_generation_controller.sample_from_logits(
             last_token_logits,
-            CommonInferenceParams(top_p=top_p, temperature=temperature, top_k=0),
+            SamplingParams(top_p=top_p, temperature=temperature, top_k=0),
             self.vocab_size,
         )
         assert torch.all(
@@ -154,7 +154,7 @@ class TestSimpleTextGenerationController:
             inference_request = InferenceRequest(
                 request_id=i,
                 prompt=prompt,
-                inference_parameters=CommonInferenceParams(num_tokens_to_generate=10),
+                inference_parameters=SamplingParams(num_tokens_to_generate=10),
                 arrival_time=time.time(),
                 prompt_tokens=torch.randint(
                     low=0, high=self.vocab_size - 1, size=(len(prompt),)
......
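For intuition about the greedy-sampling assertion in the hunks above, a standalone illustration in plain PyTorch (not the library's implementation) of what sample_from_logits with SamplingParams(top_k=1) is expected to return for the ascending logits the test constructs:

    import torch

    vocab_size, batch_size = 16, 4
    # ascending logits, as in the test: the logit value equals the token id
    last_token_logits = torch.arange(0, vocab_size).repeat(batch_size, 1).float()

    # top_k=1 keeps only the largest logit per row, so sampling degenerates to argmax;
    # the real test calls controller.sample_from_logits(last_token_logits, SamplingParams(top_k=1), vocab_size)
    sampled = last_token_logits.argmax(dim=-1)
    assert torch.all(sampled == vocab_size - 1)  # matches the test's expected token id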
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644