Commit 4e8c3f1c (unverified), authored by wang.yuqi, committed by GitHub
Browse files

[Frontend][last/5] Improve pooling entrypoints | clean up. (#39675)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 5e5afafa
...@@ -4,12 +4,12 @@ from collections.abc import Sequence ...@@ -4,12 +4,12 @@ from collections.abc import Sequence
from typing import Any from typing import Any
from vllm import PoolingParams, PoolingRequestOutput from vllm import PoolingParams, PoolingRequestOutput
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.inputs import EngineInput from vllm.inputs import EngineInput
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from ..base.io_processor import PoolingIOProcessor
from ..typing import OfflineInputsContext, OfflineOutputsContext, PoolingServeContext from ..typing import OfflineInputsContext, OfflineOutputsContext, PoolingServeContext
from .protocol import IOProcessorRequest, IOProcessorResponse from .protocol import IOProcessorRequest, IOProcessorResponse
...@@ -17,6 +17,8 @@ logger = init_logger(__name__) ...@@ -17,6 +17,8 @@ logger = init_logger(__name__)
class PluginWithoutIOProcessorPlugins(PoolingIOProcessor): class PluginWithoutIOProcessorPlugins(PoolingIOProcessor):
# Some models, such as Terratorch (tests/models/test_terratorch.py),
# use plugin tasks in the pooler but do not use IO Processor plugins.
name = "plugin" name = "plugin"
......
...@@ -8,7 +8,11 @@ from pydantic import Field ...@@ -8,7 +8,11 @@ from pydantic import Field
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
from ..base.protocol import (
ChatRequestMixin, ChatRequestMixin,
ClassifyRequestMixin, ClassifyRequestMixin,
CompletionRequestMixin, CompletionRequestMixin,
...@@ -16,9 +20,6 @@ from vllm.entrypoints.pooling.base.protocol import ( ...@@ -16,9 +20,6 @@ from vllm.entrypoints.pooling.base.protocol import (
EncodingRequestMixin, EncodingRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
class PoolingCompletionRequest( class PoolingCompletionRequest(
......
...@@ -11,27 +11,28 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse ...@@ -11,27 +11,28 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor from vllm.logger import init_logger
from vllm.entrypoints.pooling.base.serving import PoolingServingBase from vllm.outputs import PoolingRequestOutput
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors from vllm.tasks import SupportedTask
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.utils.serial_utils import EmbedDType, Endianness
from ..base.io_processor import PoolingIOProcessor
from ..base.serving import PoolingServingBase
from ..factories import init_pooling_io_processors
from ..typing import AnyPoolingRequest, PoolingServeContext
from ..utils import (
encode_pooling_bytes,
encode_pooling_output_base64,
encode_pooling_output_float,
get_json_response_cls,
)
from .protocol import (
IOProcessorRequest, IOProcessorRequest,
PoolingBytesResponse, PoolingBytesResponse,
PoolingRequest, PoolingRequest,
PoolingResponse, PoolingResponse,
PoolingResponseData, PoolingResponseData,
) )
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, PoolingServeContext
from vllm.entrypoints.pooling.utils import (
encode_pooling_bytes,
encode_pooling_output_base64,
encode_pooling_output_float,
get_json_response_cls,
)
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.utils.serial_utils import EmbedDType, Endianness
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -7,12 +7,6 @@ from typing import Any, TypeAlias ...@@ -7,12 +7,6 @@ from typing import Any, TypeAlias
import torch.nn.functional as F import torch.nn.functional as F
from vllm import PoolingParams, PoolingRequestOutput, TokensPrompt from vllm import PoolingParams, PoolingRequestOutput, TokensPrompt
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.typing import (
OfflineInputsContext,
OfflineOutputsContext,
PoolingServeContext,
)
from vllm.inputs import EngineInput from vllm.inputs import EngineInput
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.renderers.hf import safe_apply_chat_template from vllm.renderers.hf import safe_apply_chat_template
...@@ -20,6 +14,12 @@ from vllm.tasks import PoolingTask ...@@ -20,6 +14,12 @@ from vllm.tasks import PoolingTask
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
from ...chat_utils import ChatTemplateResolutionError from ...chat_utils import ChatTemplateResolutionError
from ..base.io_processor import PoolingIOProcessor
from ..typing import (
OfflineInputsContext,
OfflineOutputsContext,
PoolingServeContext,
)
from .protocol import RerankRequest, ScoreRequest, ScoringRequest from .protocol import RerankRequest, ScoreRequest, ScoringRequest
from .typing import ScoreData, ScoreInput, ScoringData from .typing import ScoreData, ScoreInput, ScoringData
from .utils import ( from .utils import (
......
...@@ -8,18 +8,16 @@ from pydantic import BaseModel, Field ...@@ -8,18 +8,16 @@ from pydantic import BaseModel, Field
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ClassifyRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.utils import random_uuid from vllm.utils import random_uuid
from ..base.protocol import ClassifyRequestMixin, PoolingBasicRequestMixin
from .typing import ScoreContentPartParam, ScoreInput from .typing import ScoreContentPartParam, ScoreInput
class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): class ScoringRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
# --8<-- [start:scoring-common-params]
max_tokens_per_query: int = Field( max_tokens_per_query: int = Field(
default=0, default=0,
description=( description=(
...@@ -37,6 +35,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): ...@@ -37,6 +35,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
"applies to the combined query+document)." "applies to the combined query+document)."
), ),
) )
# --8<-- [end:scoring-common-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {} encoder_config = model_config.encoder_config or {}
...@@ -57,14 +56,16 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): ...@@ -57,14 +56,16 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
) )
class ScoreDataRequest(ScoreRequestMixin): class ScoreDataRequest(ScoringRequestMixin):
data_1: ScoreInput | list[ScoreInput] data_1: ScoreInput | list[ScoreInput]
data_2: ScoreInput | list[ScoreInput] data_2: ScoreInput | list[ScoreInput]
class ScoreQueriesDocumentsRequest(ScoreRequestMixin): class ScoreQueriesDocumentsRequest(ScoringRequestMixin):
# --8<-- [start:score-request-params]
queries: ScoreInput | list[ScoreInput] queries: ScoreInput | list[ScoreInput]
documents: ScoreInput | list[ScoreInput] documents: ScoreInput | list[ScoreInput]
# --8<-- [end:score-request-params]
@property @property
def data_1(self): def data_1(self):
...@@ -75,7 +76,7 @@ class ScoreQueriesDocumentsRequest(ScoreRequestMixin): ...@@ -75,7 +76,7 @@ class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
return self.documents return self.documents
class ScoreQueriesItemsRequest(ScoreRequestMixin): class ScoreQueriesItemsRequest(ScoringRequestMixin):
queries: ScoreInput | list[ScoreInput] queries: ScoreInput | list[ScoreInput]
items: ScoreInput | list[ScoreInput] items: ScoreInput | list[ScoreInput]
...@@ -88,7 +89,7 @@ class ScoreQueriesItemsRequest(ScoreRequestMixin): ...@@ -88,7 +89,7 @@ class ScoreQueriesItemsRequest(ScoreRequestMixin):
return self.items return self.items
class ScoreTextRequest(ScoreRequestMixin): class ScoreTextRequest(ScoringRequestMixin):
text_1: ScoreInput | list[ScoreInput] text_1: ScoreInput | list[ScoreInput]
text_2: ScoreInput | list[ScoreInput] text_2: ScoreInput | list[ScoreInput]
...@@ -109,10 +110,12 @@ ScoreRequest: TypeAlias = ( ...@@ -109,10 +110,12 @@ ScoreRequest: TypeAlias = (
) )
class RerankRequest(ScoreRequestMixin): class RerankRequest(ScoringRequestMixin):
# --8<-- [start:rerank-request-params]
query: ScoreInput query: ScoreInput
documents: ScoreInput | list[ScoreInput] documents: ScoreInput | list[ScoreInput]
top_n: int = Field(default_factory=lambda: 0) top_n: int = Field(default_factory=lambda: 0)
# --8<-- [end:rerank-request-params]
ScoringRequest: TypeAlias = ScoreRequest | RerankRequest ScoringRequest: TypeAlias = ScoreRequest | RerankRequest
......
...@@ -6,8 +6,6 @@ from fastapi.responses import JSONResponse, Response ...@@ -6,8 +6,6 @@ from fastapi.responses import JSONResponse, Response
from vllm import PoolingParams from vllm import PoolingParams
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.v1.pool.late_interaction import ( from vllm.v1.pool.late_interaction import (
...@@ -15,6 +13,8 @@ from vllm.v1.pool.late_interaction import ( ...@@ -15,6 +13,8 @@ from vllm.v1.pool.late_interaction import (
build_late_interaction_query_params, build_late_interaction_query_params,
) )
from ..base.io_processor import PoolingIOProcessor
from ..base.serving import PoolingServing
from .io_processor import ScoringIOProcessors, ScoringServeContext from .io_processor import ScoringIOProcessors, ScoringServeContext
from .protocol import ( from .protocol import (
RerankDocument, RerankDocument,
...@@ -90,7 +90,7 @@ class ServingScores(PoolingServing): ...@@ -90,7 +90,7 @@ class ServingScores(PoolingServing):
ctx.request.top_n if ctx.request.top_n > 0 else len(final_res_batch), ctx.request.top_n if ctx.request.top_n > 0 else len(final_res_batch),
) )
else: else:
raise NotImplementedError("") raise ValueError(f"Invalid {self.request_id_prefix} request type")
def _request_output_to_score_response( def _request_output_to_score_response(
self, self,
......
...@@ -9,29 +9,30 @@ from fastapi import Request ...@@ -9,29 +9,30 @@ from fastapi import Request
from pydantic import ConfigDict from pydantic import ConfigDict
from vllm import PoolingParams, PoolingRequestOutput, PromptType from vllm import PoolingParams, PoolingRequestOutput, PromptType
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.inputs import DataPrompt, EngineInput
from vllm.lora.request import LoRARequest
from .classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
ClassificationCompletionRequest, ClassificationCompletionRequest,
ClassificationResponse, ClassificationResponse,
) )
from vllm.entrypoints.pooling.embed.protocol import ( from .embed.protocol import (
CohereEmbedRequest, CohereEmbedRequest,
EmbeddingBytesResponse, EmbeddingBytesResponse,
EmbeddingChatRequest, EmbeddingChatRequest,
EmbeddingCompletionRequest, EmbeddingCompletionRequest,
EmbeddingResponse, EmbeddingResponse,
) )
from vllm.entrypoints.pooling.pooling.protocol import ( from .pooling.protocol import (
IOProcessorRequest, IOProcessorRequest,
PoolingBytesResponse, PoolingBytesResponse,
PoolingChatRequest, PoolingChatRequest,
PoolingCompletionRequest, PoolingCompletionRequest,
PoolingResponse, PoolingResponse,
) )
from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse from .scoring.protocol import ScoringRequest, ScoringResponse
from vllm.entrypoints.pooling.scoring.typing import ScoringData from .scoring.typing import ScoringData
from vllm.inputs import DataPrompt, EngineInput
from vllm.lora.request import LoRARequest
PoolingCompletionLikeRequest: TypeAlias = ( PoolingCompletionLikeRequest: TypeAlias = (
EmbeddingCompletionRequest EmbeddingCompletionRequest
...@@ -65,6 +66,8 @@ PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest) ...@@ -65,6 +66,8 @@ PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
@dataclass(kw_only=True) @dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]): class PoolingServeContext(Generic[PoolingRequestT]):
model_config = ConfigDict(arbitrary_types_allowed=True)
request: PoolingRequestT request: PoolingRequestT
raw_request: Request | None = None raw_request: Request | None = None
model_name: str model_name: str
...@@ -74,14 +77,14 @@ class PoolingServeContext(Generic[PoolingRequestT]): ...@@ -74,14 +77,14 @@ class PoolingServeContext(Generic[PoolingRequestT]):
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_inputs: Sequence[EngineInput] | None = None engine_inputs: Sequence[EngineInput] | None = None
prompt_request_ids: list[str] | None = None prompt_request_ids: list[str] | None = None
intermediates: Any | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None None
) )
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list) final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True) ## for Long Text Embedding with Chunked Processing
original_engine_inputs: Sequence[EngineInput] | None = None
## for bi-encoder & late-interaction ## for bi-encoder & late-interaction
n_queries: int | None = None n_queries: int | None = None
......
...@@ -13,12 +13,13 @@ from fastapi.responses import JSONResponse, Response ...@@ -13,12 +13,13 @@ from fastapi.responses import JSONResponse, Response
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.generate.factories import get_generate_invocation_types
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.base.serving import PoolingServingBase from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.utils import enable_scoring_api from vllm.entrypoints.pooling.factories import get_pooling_invocation_types
from vllm.entrypoints.serve.instrumentator.basic import base from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13) # (requires typing_extensions >= 4.13)
...@@ -27,80 +28,6 @@ GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServingBase | None] ...@@ -27,80 +28,6 @@ GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServingBase | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
def get_invocation_types(
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
):
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
if "generate" in supported_tasks:
from vllm.entrypoints.openai.chat_completion.api_router import (
chat,
create_chat_completion,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.completion.api_router import (
completion,
create_completion,
)
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
INVOCATION_TYPES += [
(ChatCompletionRequest, (chat, create_chat_completion)),
(CompletionRequest, (completion, create_completion)),
]
if "embed" in supported_tasks:
from vllm.entrypoints.pooling.embed.api_router import (
create_embedding,
embedding,
)
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
INVOCATION_TYPES += [
(EmbeddingRequest, (embedding, create_embedding)),
]
if "classify" in supported_tasks:
from vllm.entrypoints.pooling.classify.api_router import (
classify,
create_classify,
)
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
INVOCATION_TYPES += [
(ClassificationRequest, (classify, create_classify)),
]
if enable_scoring_api(supported_tasks, model_config):
from vllm.entrypoints.pooling.scoring.api_router import do_rerank, rerank
from vllm.entrypoints.pooling.scoring.protocol import RerankRequest
INVOCATION_TYPES += [
(RerankRequest, (rerank, do_rerank)),
]
from vllm.entrypoints.pooling.scoring.api_router import create_score, score
from vllm.entrypoints.pooling.scoring.protocol import ScoreRequest
INVOCATION_TYPES += [
(ScoreRequest, (score, create_score)),
]
if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
INVOCATION_TYPES += [
(PoolingRequest, (pooling, create_pooling)),
]
return INVOCATION_TYPES
def attach_router( def attach_router(
app: FastAPI, app: FastAPI,
supported_tasks: tuple["SupportedTask", ...], supported_tasks: tuple["SupportedTask", ...],
...@@ -109,7 +36,10 @@ def attach_router( ...@@ -109,7 +36,10 @@ def attach_router(
router = APIRouter() router = APIRouter()
# NOTE: Construct the TypeAdapters only once # NOTE: Construct the TypeAdapters only once
INVOCATION_TYPES = get_invocation_types(supported_tasks, model_config) INVOCATION_TYPES = get_generate_invocation_types(
supported_tasks, model_config
) + get_pooling_invocation_types(supported_tasks, model_config)
INVOCATION_VALIDATORS = [ INVOCATION_VALIDATORS = [
(pydantic.TypeAdapter(request_type), (get_handler, endpoint)) (pydantic.TypeAdapter(request_type), (get_handler, endpoint))
for request_type, (get_handler, endpoint) in INVOCATION_TYPES for request_type, (get_handler, endpoint) in INVOCATION_TYPES
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment