Unverified commit ec090c24, authored by Cyrus Leung and committed by GitHub
Browse files

[Refactor] Call renderer for online IO processor request (#34490)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent eea3024f
...@@ -500,7 +500,7 @@ class LLM: ...@@ -500,7 +500,7 @@ class LLM:
engine_prompts: Sequence[DictPrompt | TokPrompt] = [ engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt engine_prompt
for prompt, param in zip(seq_prompts, seq_params) for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_completion( for engine_prompt in self._preprocess_cmpl(
[prompt], [prompt],
tokenization_kwargs=merge_kwargs( tokenization_kwargs=merge_kwargs(
tokenization_kwargs, tokenization_kwargs,
...@@ -509,7 +509,7 @@ class LLM: ...@@ -509,7 +509,7 @@ class LLM:
) )
] ]
else: else:
engine_prompts = self._preprocess_completion( engine_prompts = self._preprocess_cmpl(
seq_prompts, seq_prompts,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
...@@ -889,7 +889,7 @@ class LLM: ...@@ -889,7 +889,7 @@ class LLM:
add_special_tokens=not model_config.is_encoder_decoder, add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs) ).with_kwargs(tokenization_kwargs)
def _preprocess_completion( def _preprocess_cmpl(
self, self,
prompts: Sequence[PromptType], prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
...@@ -901,7 +901,7 @@ class LLM: ...@@ -901,7 +901,7 @@ class LLM:
Refer to [LLM.generate][] for a complete description of the arguments. Refer to [LLM.generate][] for a complete description of the arguments.
Returns: Returns:
A list of `TokensPrompts` objects containing the tokenized prompt A list of `TokPrompt` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs. after chat template interpolation, and the raw multi-modal inputs.
""" """
renderer = self.renderer renderer = self.renderer
...@@ -943,7 +943,7 @@ class LLM: ...@@ -943,7 +943,7 @@ class LLM:
Refer to [LLM.chat][] for a complete description of the arguments. Refer to [LLM.chat][] for a complete description of the arguments.
Returns: Returns:
A list of `TokensPrompts` objects containing the tokenized prompt A list of `TokPrompt` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs. after chat template interpolation, and the raw multi-modal inputs.
""" """
renderer = self.renderer renderer = self.renderer
...@@ -1823,11 +1823,11 @@ class LLM: ...@@ -1823,11 +1823,11 @@ class LLM:
if any(param.truncate_prompt_tokens is not None for param in seq_params): if any(param.truncate_prompt_tokens is not None for param in seq_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens` # TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let # Then, move the code from the `else` block to the top and let
# `self._preprocess_completion` handle prompt normalization # `self._preprocess_cmpl` handle prompt normalization
engine_prompts: Sequence[DictPrompt | TokPrompt] = [ engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt engine_prompt
for prompt, param in zip(seq_prompts, seq_params) for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_completion( for engine_prompt in self._preprocess_cmpl(
[prompt], [prompt],
tokenization_kwargs=merge_kwargs( tokenization_kwargs=merge_kwargs(
tokenization_kwargs, tokenization_kwargs,
...@@ -1836,7 +1836,7 @@ class LLM: ...@@ -1836,7 +1836,7 @@ class LLM:
) )
] ]
else: else:
engine_prompts = self._preprocess_completion( engine_prompts = self._preprocess_cmpl(
seq_prompts, seq_prompts,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
......
...@@ -5,7 +5,7 @@ import json ...@@ -5,7 +5,7 @@ import json
import sys import sys
import time import time
import traceback import traceback
from collections.abc import AsyncGenerator, Callable, Mapping from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
...@@ -959,15 +959,22 @@ class OpenAIServing: ...@@ -959,15 +959,22 @@ class OpenAIServing:
prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None, prompt_embeds: bytes | list[bytes] | None,
) -> list[TokPrompt]: ) -> list[TokPrompt]:
renderer = self.renderer
model_config = self.model_config
prompts = list[SingletonPrompt | bytes]() prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds)) prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None: if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input)) prompts.extend(prompt_to_seq(prompt_input))
return await self._preprocess_cmpl(request, prompts)
async def _preprocess_cmpl(
self,
request: RendererRequest,
prompts: Sequence[PromptType | bytes],
) -> list[TokPrompt]:
renderer = self.renderer
model_config = self.model_config
parsed_prompts = [ parsed_prompts = [
( (
prompt prompt
......
...@@ -100,6 +100,18 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic ...@@ -100,6 +100,18 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
data: T data: T
task: PoolingTask = "plugin" task: PoolingTask = "plugin"
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
    """Assemble the tokenization parameters for this request.

    Args:
        model_config: Model configuration supplying ``max_model_len``,
            ``encoder_config`` and ``is_encoder_decoder``.

    Returns:
        A ``TokenizeParams`` whose total-token budget is the model's
        ``max_model_len``, whose output-token budget is fixed at 0, and
        whose truncation setting comes from this request's
        ``truncate_prompt_tokens``.
    """
    # A missing (falsy) encoder config is treated as an empty mapping so
    # the lookup below falls back to its default.
    enc_cfg = model_config.encoder_config
    if not enc_cfg:
        enc_cfg = {}

    total_budget = model_config.max_model_len
    truncation = self.truncate_prompt_tokens
    lower_case = enc_cfg.get("do_lower_case", False)
    # Special tokens are added only when the model is not encoder-decoder.
    with_special_tokens = not model_config.is_encoder_decoder

    return TokenizeParams(
        max_total_tokens=total_budget,
        max_output_tokens=0,
        truncate_prompt_tokens=truncation,
        do_lower_case=lower_case,
        add_special_tokens=with_special_tokens,
        max_total_tokens_param="max_model_len",
    )
class IOProcessorResponse(OpenAIBaseModel, Generic[T]): class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
request_id: str | None = None request_id: str | None = None
......
...@@ -6,7 +6,7 @@ import json ...@@ -6,7 +6,7 @@ import json
import time import time
from collections.abc import AsyncGenerator, Callable, Sequence from collections.abc import AsyncGenerator, Callable, Sequence
from functools import partial from functools import partial
from typing import Any, Final, Literal, cast from typing import Final, Literal, cast
import jinja2 import jinja2
from fastapi import Request from fastapi import Request
...@@ -108,7 +108,10 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -108,7 +108,10 @@ class OpenAIServingPooling(OpenAIServing):
raw_prompts = await self.io_processor.pre_process_async( raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id prompt=validated_prompt, request_id=request_id
) )
engine_prompts = prompt_to_seq(raw_prompts) engine_prompts = await self._preprocess_cmpl(
request,
prompt_to_seq(raw_prompts),
)
elif isinstance(request, PoolingChatRequest): elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template( error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template, request_chat_template=request.chat_template,
...@@ -146,11 +149,10 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -146,11 +149,10 @@ class OpenAIServingPooling(OpenAIServing):
pooling_params = self.io_processor.merge_pooling_params() pooling_params = self.io_processor.merge_pooling_params()
if pooling_params.task is None: if pooling_params.task is None:
pooling_params.task = "plugin" pooling_params.task = "plugin"
tokenization_kwargs: dict[str, Any] = {}
else: else:
pooling_params = request.to_pooling_params() # type: ignore pooling_params = request.to_pooling_params() # type: ignore
tok_params = request.build_tok_params(self.model_config) # type: ignore
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs() tokenization_kwargs = tok_params.get_encode_kwargs()
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment