Unverified commit 5d89a0c8, authored by ptarasiewiczNV, committed by GitHub

fix: Create default sampling params only once during initialization (#982)

parent af9ee90e
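
The change hoists the get_diff_sampling_param() call out of the per-request parsing path and into the processor's __init__, so the default sampling parameters are computed once and reused for every request. Below is a minimal, runnable sketch of that pattern; StubModelConfig and its values are illustrative stand-ins, and only the names max_model_len and get_diff_sampling_param() mirror the diff:

class StubModelConfig:
    # Illustrative stand-in for the model config used in the diff; only
    # max_model_len and get_diff_sampling_param() mirror names from it.
    max_model_len = 4096

    def get_diff_sampling_param(self):
        # In vLLM this derives non-default sampling parameters from the
        # model's generation config; a fixed dict stands in here.
        return {"temperature": 0.7, "top_p": 0.9}


class Processor:
    def __init__(self, model_config):
        self.model_config = model_config
        # Computed once at construction; before this commit the equivalent
        # call ran on every request.
        self.default_sampling_params = model_config.get_diff_sampling_param()

    def parse_request(self, prompt_token_ids):
        default_max_tokens = self.model_config.max_model_len - len(prompt_token_ids)
        # Reuse the cached defaults rather than recomputing them.
        return default_max_tokens, self.default_sampling_params


processor = Processor(StubModelConfig())
print(processor.parse_request([1, 2, 3]))  # (4093, {'temperature': 0.7, 'top_p': 0.9})

This stays correct as long as the model config is fixed after initialization, which this commit appears to assume.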
@@ -63,6 +63,7 @@ class Processor(ProcessMixIn):
         class_name = self.__class__.__name__
         self.engine_args = parse_vllm_args(class_name, "")
         self.model_config = self.engine_args.create_model_config()
+        self.default_sampling_params = self.model_config.get_diff_sampling_param()
         self.tokenizer = self._create_tokenizer(self.engine_args)
         self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
         self.completions_processor = CompletionsProcessor(

@@ -29,6 +29,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_engine import RequestPrompt
 from vllm.inputs.data import TokensPrompt
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer

@@ -38,6 +39,7 @@ class ProcessMixInRequired(Protocol):
     chat_processor: "ChatProcessor | None"
     completions_processor: "CompletionsProcessor | None"
     model_config: ModelConfig
+    default_sampling_params: SamplingParams

@@ -50,6 +52,7 @@ class ProcessMixIn(ProcessMixInRequired):
     chat_processor: "ChatProcessor | None"
     completions_processor: "CompletionsProcessor | None"
     model_config: ModelConfig
+    default_sampling_params: SamplingParams

     def __init__(self):
         pass

@@ -76,11 +79,10 @@ class ProcessMixIn(ProcessMixInRequired):
         default_max_tokens = self.model_config.max_model_len - len(
             preprocess_result.engine_prompt["prompt_token_ids"]
         )
-        default_sampling_params = self.model_config.get_diff_sampling_param()
         sampling_params = request.to_sampling_params(
             default_max_tokens,
             self.model_config.logits_processor_pattern,
-            default_sampling_params,
+            self.default_sampling_params,
         )
         return (
             request,
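
The SamplingParams annotation is added both to ProcessMixInRequired (a typing.Protocol) and to the concrete ProcessMixIn, so static type checkers can require the cached attribute on anything used as a processor. A small self-contained illustration of that structural-typing idea follows; every name except Protocol is hypothetical:

from typing import Protocol


class HasSamplingDefaults(Protocol):
    # Structural requirement: any conforming class must expose this attribute.
    default_sampling_params: dict


class CachingProcessor:
    def __init__(self):
        self.default_sampling_params = {"temperature": 0.7}


def describe(p: HasSamplingDefaults) -> dict:
    # Type checkers accept CachingProcessor here without explicit inheritance,
    # because it structurally satisfies the protocol.
    return p.default_sampling_params


print(describe(CachingProcessor()))  # {'temperature': 0.7}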