Commit edb359cc (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Renderer] Define `render_cmpl` and `render_chat` (#34039)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 6ed5eda3
...@@ -170,8 +170,6 @@ def run_test( ...@@ -170,8 +170,6 @@ def run_test(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"size_factors", "size_factors",
[ [
# No image
[],
# Single-scale # Single-scale
[1.0], [1.0],
# Single-scale, batched # Single-scale, batched
......
...@@ -375,7 +375,6 @@ def test_qwen2_vl_image_embeddings_input( ...@@ -375,7 +375,6 @@ def test_qwen2_vl_image_embeddings_input(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"size_factors", "size_factors",
[ [
[],
# Single-scale # Single-scale
[0.5], [0.5],
# Single-scale, batched # Single-scale, batched
......
...@@ -100,8 +100,6 @@ def run_awq_test( ...@@ -100,8 +100,6 @@ def run_awq_test(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"size_factors", "size_factors",
[ [
# No image
[],
# Single-scale # Single-scale
[1.0], [1.0],
# Single-scale, batched # Single-scale, batched
......
...@@ -73,7 +73,7 @@ from vllm.outputs import ( ...@@ -73,7 +73,7 @@ from vllm.outputs import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import DictPrompt, SingletonDictPrompt, TokPrompt from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import ( from vllm.renderers.inputs.preprocess import (
conversation_to_seq, conversation_to_seq,
extract_prompt_components, extract_prompt_components,
...@@ -805,7 +805,7 @@ class LLM: ...@@ -805,7 +805,7 @@ class LLM:
self, self,
prompts: Sequence[PromptType], prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> list[DictPrompt | TokPrompt]: ) -> Sequence[DictPrompt | TokPrompt]:
""" """
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`. a format that can be passed to `_add_request`.
...@@ -819,22 +819,12 @@ class LLM: ...@@ -819,22 +819,12 @@ class LLM:
renderer = self.llm_engine.renderer renderer = self.llm_engine.renderer
model_config = self.model_config model_config = self.model_config
parsed_prompts = [
parse_model_prompt(model_config, prompt) for prompt in prompts
]
tok_params = self._get_cmpl_tok_params(tokenization_kwargs) tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
engine_prompts = list[DictPrompt | TokPrompt]() return renderer.render_cmpl(parsed_prompts, tok_params)
for prompt in prompts:
parsed_prompt = parse_model_prompt(model_config, prompt)
in_prompt = renderer.render_prompt(parsed_prompt)
# Some MM models have non-default `add_special_tokens`
# TODO: Move multi-modal processor into tokenization
engine_prompts.append(
in_prompt
if model_config.is_multimodal_model
else renderer.tokenize_prompt(in_prompt, tok_params)
)
return engine_prompts
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None): def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config model_config = self.model_config
...@@ -857,7 +847,7 @@ class LLM: ...@@ -857,7 +847,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[DictPrompt | TokPrompt]: ) -> Sequence[TokPrompt]:
""" """
Convert a list of conversations into prompts so that they can then Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs. be used as input for other LLM APIs.
...@@ -885,16 +875,12 @@ class LLM: ...@@ -885,16 +875,12 @@ class LLM:
) )
tok_params = self._get_chat_tok_params(tokenization_kwargs) tok_params = self._get_chat_tok_params(tokenization_kwargs)
engine_prompts = list[DictPrompt | TokPrompt]() _, engine_prompts = renderer.render_chat(
for conversation in conversations: conversations,
_, in_prompt = renderer.render_messages(conversation, chat_params) chat_params,
if mm_processor_kwargs is not None: tok_params,
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
"encoder_prompt", in_prompt
) )
target_prompt["mm_processor_kwargs"] = mm_processor_kwargs # type: ignore
engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params))
return engine_prompts return engine_prompts
...@@ -1743,7 +1729,7 @@ class LLM: ...@@ -1743,7 +1729,7 @@ class LLM:
# TODO: Remove this after deprecating `param.truncate_prompt_tokens` # TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let # Then, move the code from the `else` block to the top and let
# `self._preprocess_completion` handle prompt normalization # `self._preprocess_completion` handle prompt normalization
engine_prompts = [ engine_prompts: Sequence[DictPrompt | TokPrompt] = [
engine_prompt engine_prompt
for prompt, param in zip(seq_prompts, seq_params) for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_completion( for engine_prompt in self._preprocess_completion(
......
...@@ -106,7 +106,6 @@ from vllm.pooling_params import PoolingParams ...@@ -106,7 +106,6 @@ from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import TokPrompt from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import ( from vllm.renderers.inputs.preprocess import (
SingletonDictPrompt,
extract_prompt_components, extract_prompt_components,
extract_prompt_len, extract_prompt_len,
parse_model_prompt, parse_model_prompt,
...@@ -963,8 +962,6 @@ class OpenAIServing: ...@@ -963,8 +962,6 @@ class OpenAIServing:
renderer = self.renderer renderer = self.renderer
model_config = self.model_config model_config = self.model_config
tok_params = request.build_tok_params(model_config)
prompts = list[SingletonPrompt | bytes]() prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds)) prompts.extend(prompt_to_seq(prompt_embeds))
...@@ -979,22 +976,17 @@ class OpenAIServing: ...@@ -979,22 +976,17 @@ class OpenAIServing:
) )
for prompt in prompts for prompt in prompts
] ]
in_prompts = await renderer.render_prompts_async(parsed_prompts) tok_params = request.build_tok_params(model_config)
extra_items = { return await renderer.render_cmpl_async(
parsed_prompts,
tok_params,
prompt_extras={
k: v k: v
for k in ("mm_processor_kwargs", "cache_salt") for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None if (v := getattr(request, k, None)) is not None
} },
for in_prompt in in_prompts:
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
) )
target_prompt.update(extra_items) # type: ignore
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
return engine_prompts
async def _preprocess_chat( async def _preprocess_chat(
self, self,
...@@ -1023,21 +1015,16 @@ class OpenAIServing: ...@@ -1023,21 +1015,16 @@ class OpenAIServing:
default_template, default_template_content_format default_template, default_template_content_format
).with_defaults(default_template_kwargs) ).with_defaults(default_template_kwargs)
conversation, in_prompt = await renderer.render_messages_async( (conversation,), (engine_prompt,) = await renderer.render_chat_async(
messages, chat_params [messages],
) chat_params,
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore tok_params,
"encoder_prompt", in_prompt prompt_extras={
)
extra_items = {
k: v k: v
for k in ("mm_processor_kwargs", "cache_salt") for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None if (v := getattr(request, k, None)) is not None
} },
target_prompt.update(extra_items) # type: ignore )
engine_prompt = await renderer.tokenize_prompt_async(target_prompt, tok_params)
# tool parsing is done only if a tool_parser has been set and if # tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
......
...@@ -225,16 +225,20 @@ class PromptComponents(NamedTuple): ...@@ -225,16 +225,20 @@ class PromptComponents(NamedTuple):
embeds: "torch.Tensor | None" = None embeds: "torch.Tensor | None" = None
def extract_prompt_components( def extract_target_prompt(model_config: "ModelConfig", prompt: object):
model_config: "ModelConfig", return (
prompt: object,
) -> PromptComponents:
target_prompt = (
parse_enc_dec_prompt(prompt)["encoder_prompt"] parse_enc_dec_prompt(prompt)["encoder_prompt"]
if model_config.is_encoder_decoder if model_config.is_encoder_decoder
else parse_dec_only_prompt(prompt) else parse_dec_only_prompt(prompt)
) )
def extract_prompt_components(
model_config: "ModelConfig",
prompt: object,
) -> PromptComponents:
target_prompt = extract_target_prompt(model_config, prompt)
return PromptComponents( return PromptComponents(
text=target_prompt.get("prompt"), text=target_prompt.get("prompt"),
token_ids=target_prompt.get("prompt_token_ids"), # type: ignore[arg-type] token_ids=target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
...@@ -243,11 +247,7 @@ def extract_prompt_components( ...@@ -243,11 +247,7 @@ def extract_prompt_components(
def extract_prompt_len(model_config: "ModelConfig", prompt: object): def extract_prompt_len(model_config: "ModelConfig", prompt: object):
target_prompt = ( target_prompt = extract_target_prompt(model_config, prompt)
parse_enc_dec_prompt(prompt)["encoder_prompt"]
if model_config.is_encoder_decoder
else parse_dec_only_prompt(prompt)
)
return length_from_prompt_token_ids_or_embeds( return length_from_prompt_token_ids_or_embeds(
target_prompt.get("prompt_token_ids"), # type: ignore[arg-type] target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
......
...@@ -16,6 +16,7 @@ from .inputs import ( ...@@ -16,6 +16,7 @@ from .inputs import (
EncoderDecoderTokPrompt, EncoderDecoderTokPrompt,
TokPrompt, TokPrompt,
) )
from .inputs.preprocess import extract_target_prompt
from .params import ChatParams, TokenizeParams from .params import ChatParams, TokenizeParams
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -277,3 +278,109 @@ class BaseRenderer(ABC): ...@@ -277,3 +278,109 @@ class BaseRenderer(ABC):
return await asyncio.gather( return await asyncio.gather(
*(self.tokenize_prompt_async(prompt, params) for prompt in prompts) *(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
) )
# Step 3: Add extra keys to the prompts
def _apply_prompt_extras(
    self,
    prompts: Sequence[DictPrompt | TokPrompt],
    prompt_extras: dict[str, Any] | None,
):
    """
    Merge `prompt_extras` into the target prompt of every item in `prompts`.

    The target prompt is whatever `extract_target_prompt` returns for
    `self.config` and the given prompt; it is updated in place.
    Nothing happens when `prompt_extras` is `None` or empty.
    """
    if prompt_extras:
        for item in prompts:
            extract_target_prompt(self.config, item).update(prompt_extras)  # type: ignore[arg-type]
# Top-level methods
def render_cmpl(
    self,
    prompts: Sequence[DictPrompt | bytes],
    tok_params: TokenizeParams,
    *,
    prompt_extras: dict[str, Any] | None = None,
):
    """
    Render completion-style prompts (synchronous version).

    The prompts are passed through `render_prompts`. Unless
    `self.config.is_multimodal_model` is set, the result is then
    tokenized with `tokenize_prompts` using `tok_params`.
    Finally, `prompt_extras` is applied to the resulting prompts,
    which are returned.
    """
    engine_prompts = self.render_prompts(prompts)

    # NOTE: Some MM models have non-default `add_special_tokens`
    # so we handle tokenization in multi-modal processor
    if not self.config.is_multimodal_model:
        engine_prompts = self.tokenize_prompts(engine_prompts, tok_params)

    self._apply_prompt_extras(engine_prompts, prompt_extras)

    # TODO: Apply multi-modal processor
    return engine_prompts
async def render_cmpl_async(
    self,
    prompts: Sequence[DictPrompt | bytes],
    tok_params: TokenizeParams,
    *,
    prompt_extras: dict[str, Any] | None = None,
):
    """
    Render completion-style prompts (asynchronous version).

    The prompts are passed through `render_prompts_async`, tokenized
    with `tokenize_prompts_async` using `tok_params`, and then have
    `prompt_extras` applied before being returned.
    """
    rendered = await self.render_prompts_async(prompts)

    # NOTE: MM data cannot be passed to online Completions API
    # so we don't have the special case that is in the offline version
    engine_prompts = await self.tokenize_prompts_async(rendered, tok_params)
    self._apply_prompt_extras(engine_prompts, prompt_extras)

    # TODO: Apply multi-modal processor
    return engine_prompts
def render_chat(
    self,
    conversations: Sequence[list["ChatCompletionMessageParam"]],
    chat_params: ChatParams,
    tok_params: TokenizeParams,
    *,
    prompt_extras: dict[str, Any] | None = None,
):
    """
    Render chat conversations (synchronous version).

    Each conversation is passed to `render_messages` together with
    `chat_params`, yielding a (conversation, prompt) pair. The prompts
    are tokenized with `tokenize_prompts` using `tok_params` and then
    have `prompt_extras` applied.

    Returns a pair of lists: the rendered conversations and the
    tokenized prompts, both in the order of `conversations`.
    """
    out_conversations: list[list["ConversationMessage"]] = []
    dict_prompts: list[DictPrompt] = []
    for messages in conversations:
        conv, prompt = self.render_messages(messages, chat_params)
        out_conversations.append(conv)
        dict_prompts.append(prompt)

    engine_prompts = self.tokenize_prompts(dict_prompts, tok_params)
    self._apply_prompt_extras(engine_prompts, prompt_extras)

    # TODO: Apply multi-modal processor
    return out_conversations, engine_prompts
async def render_chat_async(
    self,
    conversations: Sequence[list["ChatCompletionMessageParam"]],
    chat_params: ChatParams,
    tok_params: TokenizeParams,
    *,
    prompt_extras: dict[str, Any] | None = None,
):
    """
    Render chat conversations (asynchronous version).

    One `render_messages_async` awaitable is created per conversation
    and all of them are awaited together via `asyncio.gather`, each
    yielding a (conversation, prompt) pair. The prompts are tokenized
    with `tokenize_prompts_async` using `tok_params` and then have
    `prompt_extras` applied.

    Returns a pair of lists: the rendered conversations and the
    tokenized prompts, both in the order of `conversations`.
    """
    results = await asyncio.gather(
        *(
            self.render_messages_async(messages, chat_params)
            for messages in conversations
        )
    )

    out_conversations: list[list["ConversationMessage"]] = [
        conv for conv, _ in results
    ]
    dict_prompts: list[DictPrompt] = [prompt for _, prompt in results]

    engine_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)
    self._apply_prompt_extras(engine_prompts, prompt_extras)

    # TODO: Apply multi-modal processor
    return out_conversations, engine_prompts
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment