Commit 99a0c39e authored by xingjinliang

Sync latest code

parent 50fe58fa
@@ -2,8 +2,8 @@ from typing import Dict
 import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.inference_request import InferenceRequest, Status
+from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.inference.scheduler import Scheduler
@@ -25,7 +25,7 @@ class TestScheduler:
     def test_scheduler(self):
         prompt = "sample prompt"
         prompt_tokens = torch.randn(5)
-        inference_parameters = CommonInferenceParams()
+        inference_parameters = SamplingParams()
         for i in range(self.max_batch_size):
             self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
...
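For reference, a minimal sketch of driving the scheduler with the renamed class, mirroring the test above. The Scheduler(max_batch_size=...) constructor argument is an assumption; add_request and SamplingParams are taken from the hunk itself.

import torch
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.scheduler import Scheduler

# Queue a few requests with SamplingParams (formerly CommonInferenceParams).
scheduler = Scheduler(max_batch_size=4)  # assumed constructor signature
prompt = "sample prompt"
prompt_tokens = torch.randn(5)
sampling_params = SamplingParams()  # library defaults for top_k/top_p/temperature
for _ in range(4):
    scheduler.add_request(prompt, prompt_tokens, sampling_params)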
@@ -10,7 +10,6 @@ import numpy as np
 import pytest
 import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.inference_request import InferenceRequest, Status
 from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
     InferenceWrapperConfig,
@@ -18,6 +17,7 @@ from megatron.core.inference.model_inference_wrappers.inference_wrapper_config i
 from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import (
     T5InferenceWrapper,
 )
+from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
     EncoderDecoderTextGenerationController,
 )
@@ -126,7 +126,7 @@ class TestEncoderDecoderTextGenerationController:
                 request_id=i,
                 prompt=prompt,
                 encoder_prompt=encoder_prompt,
-                inference_parameters=CommonInferenceParams(num_tokens_to_generate=10),
+                inference_parameters=SamplingParams(num_tokens_to_generate=10),
                 arrival_time=time.time(),
                 prompt_tokens=prompt_tokens,
                 status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS,
...
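The same substitution appears where the encoder-decoder test builds requests. A minimal sketch of constructing such a request with the new class; field names come from the hunk above, the concrete values are illustrative only.

import time
import torch
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams

request = InferenceRequest(
    request_id=0,
    prompt="sample prompt",
    encoder_prompt="encoder prompt",  # illustrative value
    inference_parameters=SamplingParams(num_tokens_to_generate=10),
    arrival_time=time.time(),
    prompt_tokens=torch.randint(low=0, high=100, size=(5,)),  # illustrative tokens
    status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS,
)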
@@ -9,7 +9,6 @@ from unittest import mock
 import pytest
 import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.inference_request import InferenceRequest, Status
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
@@ -17,8 +16,9 @@ from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper
 from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
     InferenceWrapperConfig,
 )
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
-    SimpleTextGenerationController,
+from megatron.core.inference.sampling_params import SamplingParams
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+    TextGenerationController,
 )
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
 from megatron.core.models.gpt.gpt_model import GPTModel
@@ -28,7 +28,7 @@ from megatron.core.transformer.transformer_config import TransformerConfig
 from tests.unit_tests.test_utilities import Utils
-class TestSimpleTextGenerationController:
+class TestTextGenerationController:
     def setup_method(self, method):
         Utils.initialize_model_parallel(
@@ -67,7 +67,7 @@ class TestSimpleTextGenerationController:
         self.mock_tokenizer = mock.Mock()
-        self.text_generation_controller = SimpleTextGenerationController(
+        self.text_generation_controller = TextGenerationController(
             inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer
         )
@@ -78,7 +78,7 @@ class TestSimpleTextGenerationController:
         with pytest.raises(AssertionError) as aerror:
             self.text_generation_controller.sample_from_logits(
                 last_token_logits=None,
-                common_inference_params=CommonInferenceParams(top_k=2, top_p=0.4),
+                sampling_params=SamplingParams(top_k=2, top_p=0.4),
                 vocab_size=self.vocab_size,
             )
         assert str(aerror.value) == 'Cannot have top-p and top-k both greater than zero'
@@ -86,7 +86,7 @@ class TestSimpleTextGenerationController:
         with pytest.raises(AssertionError) as aerror:
             self.text_generation_controller.sample_from_logits(
                 last_token_logits=None,
-                common_inference_params=CommonInferenceParams(top_p=1.4, top_k=0),
+                sampling_params=SamplingParams(top_p=1.4, top_k=0),
                 vocab_size=self.vocab_size,
             )
         assert str(aerror.value) == 'top-p should be in (0,1]'
@@ -94,7 +94,7 @@ class TestSimpleTextGenerationController:
         with pytest.raises(AssertionError) as aerror:
             self.text_generation_controller.sample_from_logits(
                 last_token_logits=torch.randn(self.batch_size, 1),
-                common_inference_params=CommonInferenceParams(top_k=self.vocab_size + 10),
+                sampling_params=SamplingParams(top_k=self.vocab_size + 10),
                 vocab_size=self.vocab_size,
             )
         assert str(aerror.value) == 'top-k is larger than logit size.'
@@ -103,14 +103,14 @@ class TestSimpleTextGenerationController:
             torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda()
         )
         sampled_logits = self.text_generation_controller.sample_from_logits(
-            last_token_logits, CommonInferenceParams(top_k=1), self.vocab_size
+            last_token_logits, SamplingParams(top_k=1), self.vocab_size
         )
         assert torch.all(
             sampled_logits.cpu() == torch.ones(self.batch_size) * self.vocab_size - 1
         ), f"The sampled logits should all be {self.vocab_size} but its {sampled_logits}"
         sampled_logits = self.text_generation_controller.sample_from_logits(
-            last_token_logits, CommonInferenceParams(top_k=2), self.vocab_size
+            last_token_logits, SamplingParams(top_k=2), self.vocab_size
         )
         assert torch.all(
             sampled_logits >= self.vocab_size - 2
@@ -120,7 +120,7 @@ class TestSimpleTextGenerationController:
         top_p = 0.3
         expected_min_value = l[l.softmax(dim=-1).cumsum(dim=-1) > top_p][0].item()
         sampled_logits = self.text_generation_controller.sample_from_logits(
-            last_token_logits, CommonInferenceParams(top_p=top_p, top_k=0), self.vocab_size
+            last_token_logits, SamplingParams(top_p=top_p, top_k=0), self.vocab_size
         )
         assert torch.all(
             sampled_logits >= expected_min_value
@@ -131,7 +131,7 @@ class TestSimpleTextGenerationController:
         expected_min_value = l[l.div_(temperature).softmax(dim=-1).cumsum(dim=-1) > top_p][0].item()
         sampled_logits = self.text_generation_controller.sample_from_logits(
             last_token_logits,
-            CommonInferenceParams(top_p=top_p, temperature=temperature, top_k=0),
+            SamplingParams(top_p=top_p, temperature=temperature, top_k=0),
             self.vocab_size,
         )
         assert torch.all(
@@ -154,7 +154,7 @@ class TestSimpleTextGenerationController:
             inference_request = InferenceRequest(
                 request_id=i,
                 prompt=prompt,
-                inference_parameters=CommonInferenceParams(num_tokens_to_generate=10),
+                inference_parameters=SamplingParams(num_tokens_to_generate=10),
                 arrival_time=time.time(),
                 prompt_tokens=torch.randint(
                     low=0, high=self.vocab_size - 1, size=(len(prompt),)
...
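For context, the assertions above exercise top-k and top-p filtering. Below is a self-contained sketch, in plain torch, of what top_k=1 (greedy) and top_p (nucleus) filtering do to a row of logits; this is the standard technique the test checks, not the controller's actual implementation, and the sizes are illustrative.

import torch

batch_size, vocab_size = 4, 100  # illustrative sizes
logits = torch.arange(0, vocab_size).repeat(batch_size, 1).float()

# top_k=1 (greedy): keep only the single highest logit per row,
# which for these monotonically increasing logits is vocab_size - 1.
greedy = logits.argmax(dim=-1)
assert torch.all(greedy == vocab_size - 1)

# top_p (nucleus): keep the smallest set of tokens whose cumulative
# softmax probability exceeds top_p, always retaining the top token.
top_p = 0.3
sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
keep = cumulative_probs <= top_p
keep[..., 0] = True  # never drop the most likely token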
File mode changed from 100755 to 100644 (16 files)