Unverified commit 328cbb27, authored by wang.yuqi, committed by GitHub
Browse files

[Frontend][2/n] Make pooling entrypoints request schema consensus | ChatRequest (#32574)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 64e3d67a
...@@ -10,9 +10,9 @@ from pydantic import ( ...@@ -10,9 +10,9 @@ from pydantic import (
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config.pooler import get_use_activation from vllm.config.pooler import get_use_activation
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
CompletionRequestMixin, CompletionRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
...@@ -45,48 +45,8 @@ class ClassificationCompletionRequest(PoolingBasicRequestMixin, CompletionReques ...@@ -45,48 +45,8 @@ class ClassificationCompletionRequest(PoolingBasicRequestMixin, CompletionReques
) )
class ClassificationChatRequest(PoolingBasicRequestMixin): class ClassificationChatRequest(PoolingBasicRequestMixin, ChatRequestMixin):
messages: list[ChatCompletionMessageParam]
# --8<-- [start:chat-classification-extra-params] # --8<-- [start:chat-classification-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)
chat_template: str | None = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field( mm_processor_kwargs: dict[str, Any] | None = Field(
default=None, default=None,
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
......
...@@ -86,8 +86,8 @@ class ClassificationMixin(OpenAIServing): ...@@ -86,8 +86,8 @@ class ClassificationMixin(OpenAIServing):
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"), getattr(self, "chat_template_content_format", "auto"),
), ),
add_generation_prompt=False, add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=False, continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens, add_special_tokens=chat_request.add_special_tokens,
) )
ctx.engine_prompts = engine_prompts ctx.engine_prompts = engine_prompts
......
...@@ -5,13 +5,12 @@ from typing import Any, TypeAlias ...@@ -5,13 +5,12 @@ from typing import Any, TypeAlias
from pydantic import ( from pydantic import (
Field, Field,
model_validator,
) )
from vllm import PoolingParams from vllm import PoolingParams
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
CompletionRequestMixin, CompletionRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
...@@ -57,57 +56,11 @@ class EmbeddingCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixi ...@@ -57,57 +56,11 @@ class EmbeddingCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixi
) )
class EmbeddingChatRequest(PoolingBasicRequestMixin): class EmbeddingChatRequest(PoolingBasicRequestMixin, ChatRequestMixin):
messages: list[ChatCompletionMessageParam]
encoding_format: EncodingFormat = "float" encoding_format: EncodingFormat = "float"
dimensions: int | None = None dimensions: int | None = None
# --8<-- [start:chat-embedding-extra-params] # --8<-- [start:chat-embedding-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
continue_final_message: bool = Field(
default=False,
description=(
"If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
'This allows you to "prefill" part of the model\'s response for it. '
"Cannot be used at the same time as `add_generation_prompt`."
),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)
chat_template: str | None = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field( mm_processor_kwargs: dict[str, Any] | None = Field(
default=None, default=None,
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
...@@ -134,16 +87,6 @@ class EmbeddingChatRequest(PoolingBasicRequestMixin): ...@@ -134,16 +87,6 @@ class EmbeddingChatRequest(PoolingBasicRequestMixin):
) )
# --8<-- [end:chat-embedding-extra-params] # --8<-- [end:chat-embedding-extra-params]
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get("add_generation_prompt"):
raise ValueError(
"Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True."
)
return data
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
......
...@@ -144,10 +144,8 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -144,10 +144,8 @@ class OpenAIServingPooling(OpenAIServing):
request.messages, request.messages,
chat_template=request.chat_template or self.chat_template, chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format, chat_template_content_format=self.chat_template_content_format,
# In pooling requests, we are not generating tokens, add_generation_prompt=request.add_generation_prompt,
# so there is no need to append extra tokens to the input continue_final_message=request.continue_final_message,
add_generation_prompt=False,
continue_final_message=False,
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
elif isinstance(request, PoolingCompletionRequest): elif isinstance(request, PoolingCompletionRequest):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment