Unverified Commit ec090c24 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Call renderer for online IO processor request (#34490)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent eea3024f
......@@ -500,7 +500,7 @@ class LLM:
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt
for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_completion(
for engine_prompt in self._preprocess_cmpl(
[prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
......@@ -509,7 +509,7 @@ class LLM:
)
]
else:
engine_prompts = self._preprocess_completion(
engine_prompts = self._preprocess_cmpl(
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
)
......@@ -889,7 +889,7 @@ class LLM:
add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs)
def _preprocess_completion(
def _preprocess_cmpl(
self,
prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
......@@ -901,7 +901,7 @@ class LLM:
Refer to [LLM.generate][] for a complete description of the arguments.
Returns:
A list of `TokensPrompts` objects containing the tokenized prompt
A list of `TokPrompt` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
renderer = self.renderer
......@@ -943,7 +943,7 @@ class LLM:
Refer to [LLM.chat][] for a complete description of the arguments.
Returns:
A list of `TokensPrompts` objects containing the tokenized prompt
A list of `TokPrompt` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
renderer = self.renderer
......@@ -1823,11 +1823,11 @@ class LLM:
if any(param.truncate_prompt_tokens is not None for param in seq_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let
# `self._preprocess_completion` handle prompt normalization
# `self._preprocess_cmpl` handle prompt normalization
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt
for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_completion(
for engine_prompt in self._preprocess_cmpl(
[prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
......@@ -1836,7 +1836,7 @@ class LLM:
)
]
else:
engine_prompts = self._preprocess_completion(
engine_prompts = self._preprocess_cmpl(
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
)
......
......@@ -5,7 +5,7 @@ import json
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Mapping
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
......@@ -959,15 +959,22 @@ class OpenAIServing:
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokPrompt]:
renderer = self.renderer
model_config = self.model_config
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
return await self._preprocess_cmpl(request, prompts)
async def _preprocess_cmpl(
self,
request: RendererRequest,
prompts: Sequence[PromptType | bytes],
) -> list[TokPrompt]:
renderer = self.renderer
model_config = self.model_config
parsed_prompts = [
(
prompt
......
......@@ -100,6 +100,18 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
data: T
task: PoolingTask = "plugin"
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=not model_config.is_encoder_decoder,
max_total_tokens_param="max_model_len",
)
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
request_id: str | None = None
......
......@@ -6,7 +6,7 @@ import json
import time
from collections.abc import AsyncGenerator, Callable, Sequence
from functools import partial
from typing import Any, Final, Literal, cast
from typing import Final, Literal, cast
import jinja2
from fastapi import Request
......@@ -108,7 +108,10 @@ class OpenAIServingPooling(OpenAIServing):
raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id
)
engine_prompts = prompt_to_seq(raw_prompts)
engine_prompts = await self._preprocess_cmpl(
request,
prompt_to_seq(raw_prompts),
)
elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template,
......@@ -146,11 +149,10 @@ class OpenAIServingPooling(OpenAIServing):
pooling_params = self.io_processor.merge_pooling_params()
if pooling_params.task is None:
pooling_params.task = "plugin"
tokenization_kwargs: dict[str, Any] = {}
else:
pooling_params = request.to_pooling_params() # type: ignore
tok_params = request.build_tok_params(self.model_config) # type: ignore
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
for i, engine_prompt in enumerate(engine_prompts):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment