Unverified Commit 4e8c3f1c authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Frontend][last/5] Improve pooling entrypoints | clean up. (#39675)


Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: default avatarwang.yuqi <noooop@126.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 5e5afafa
......@@ -4,12 +4,12 @@ from collections.abc import Sequence
from typing import Any
from vllm import PoolingParams, PoolingRequestOutput
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.plugins.io_processors import get_io_processor
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from ..base.io_processor import PoolingIOProcessor
from ..typing import OfflineInputsContext, OfflineOutputsContext, PoolingServeContext
from .protocol import IOProcessorRequest, IOProcessorResponse
......@@ -17,6 +17,8 @@ logger = init_logger(__name__)
class PluginWithoutIOProcessorPlugins(PoolingIOProcessor):
# Some models, such as Terratorch (tests/models/test_terratorch.py),
# use plugin tasks in the pooler but do not use IO Processor plugins.
name = "plugin"
......
......@@ -8,7 +8,11 @@ from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
from ..base.protocol import (
ChatRequestMixin,
ClassifyRequestMixin,
CompletionRequestMixin,
......@@ -16,9 +20,6 @@ from vllm.entrypoints.pooling.base.protocol import (
EncodingRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
class PoolingCompletionRequest(
......
......@@ -11,27 +11,28 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.pooling.protocol import (
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.utils.serial_utils import EmbedDType, Endianness
from ..base.io_processor import PoolingIOProcessor
from ..base.serving import PoolingServingBase
from ..factories import init_pooling_io_processors
from ..typing import AnyPoolingRequest, PoolingServeContext
from ..utils import (
encode_pooling_bytes,
encode_pooling_output_base64,
encode_pooling_output_float,
get_json_response_cls,
)
from .protocol import (
IOProcessorRequest,
PoolingBytesResponse,
PoolingRequest,
PoolingResponse,
PoolingResponseData,
)
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, PoolingServeContext
from vllm.entrypoints.pooling.utils import (
encode_pooling_bytes,
encode_pooling_output_base64,
encode_pooling_output_float,
get_json_response_cls,
)
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.utils.serial_utils import EmbedDType, Endianness
logger = init_logger(__name__)
......
......@@ -7,12 +7,6 @@ from typing import Any, TypeAlias
import torch.nn.functional as F
from vllm import PoolingParams, PoolingRequestOutput, TokensPrompt
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.typing import (
OfflineInputsContext,
OfflineOutputsContext,
PoolingServeContext,
)
from vllm.inputs import EngineInput
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import safe_apply_chat_template
......@@ -20,6 +14,12 @@ from vllm.tasks import PoolingTask
from vllm.utils.mistral import is_mistral_tokenizer
from ...chat_utils import ChatTemplateResolutionError
from ..base.io_processor import PoolingIOProcessor
from ..typing import (
OfflineInputsContext,
OfflineOutputsContext,
PoolingServeContext,
)
from .protocol import RerankRequest, ScoreRequest, ScoringRequest
from .typing import ScoreData, ScoreInput, ScoringData
from .utils import (
......
......@@ -8,18 +8,16 @@ from pydantic import BaseModel, Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ClassifyRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
from ..base.protocol import ClassifyRequestMixin, PoolingBasicRequestMixin
from .typing import ScoreContentPartParam, ScoreInput
class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
class ScoringRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
# --8<-- [start:scoring-common-params]
max_tokens_per_query: int = Field(
default=0,
description=(
......@@ -37,6 +35,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
"applies to the combined query+document)."
),
)
# --8<-- [end:scoring-common-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
......@@ -57,14 +56,16 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
)
class ScoreDataRequest(ScoreRequestMixin):
class ScoreDataRequest(ScoringRequestMixin):
data_1: ScoreInput | list[ScoreInput]
data_2: ScoreInput | list[ScoreInput]
class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
class ScoreQueriesDocumentsRequest(ScoringRequestMixin):
# --8<-- [start:score-request-params]
queries: ScoreInput | list[ScoreInput]
documents: ScoreInput | list[ScoreInput]
# --8<-- [end:score-request-params]
@property
def data_1(self):
......@@ -75,7 +76,7 @@ class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
return self.documents
class ScoreQueriesItemsRequest(ScoreRequestMixin):
class ScoreQueriesItemsRequest(ScoringRequestMixin):
queries: ScoreInput | list[ScoreInput]
items: ScoreInput | list[ScoreInput]
......@@ -88,7 +89,7 @@ class ScoreQueriesItemsRequest(ScoreRequestMixin):
return self.items
class ScoreTextRequest(ScoreRequestMixin):
class ScoreTextRequest(ScoringRequestMixin):
text_1: ScoreInput | list[ScoreInput]
text_2: ScoreInput | list[ScoreInput]
......@@ -109,10 +110,12 @@ ScoreRequest: TypeAlias = (
)
class RerankRequest(ScoreRequestMixin):
class RerankRequest(ScoringRequestMixin):
# --8<-- [start:rerank-request-params]
query: ScoreInput
documents: ScoreInput | list[ScoreInput]
top_n: int = Field(default_factory=lambda: 0)
# --8<-- [end:rerank-request-params]
ScoringRequest: TypeAlias = ScoreRequest | RerankRequest
......
......@@ -6,8 +6,6 @@ from fastapi.responses import JSONResponse, Response
from vllm import PoolingParams
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.v1.pool.late_interaction import (
......@@ -15,6 +13,8 @@ from vllm.v1.pool.late_interaction import (
build_late_interaction_query_params,
)
from ..base.io_processor import PoolingIOProcessor
from ..base.serving import PoolingServing
from .io_processor import ScoringIOProcessors, ScoringServeContext
from .protocol import (
RerankDocument,
......@@ -90,7 +90,7 @@ class ServingScores(PoolingServing):
ctx.request.top_n if ctx.request.top_n > 0 else len(final_res_batch),
)
else:
raise NotImplementedError("")
raise ValueError(f"Invalid {self.request_id_prefix} request type")
def _request_output_to_score_response(
self,
......
......@@ -9,29 +9,30 @@ from fastapi import Request
from pydantic import ConfigDict
from vllm import PoolingParams, PoolingRequestOutput, PromptType
from vllm.entrypoints.pooling.classify.protocol import (
from vllm.inputs import DataPrompt, EngineInput
from vllm.lora.request import LoRARequest
from .classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
from .embed.protocol import (
CohereEmbedRequest,
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
from .pooling.protocol import (
IOProcessorRequest,
PoolingBytesResponse,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse
from vllm.entrypoints.pooling.scoring.typing import ScoringData
from vllm.inputs import DataPrompt, EngineInput
from vllm.lora.request import LoRARequest
from .scoring.protocol import ScoringRequest, ScoringResponse
from .scoring.typing import ScoringData
PoolingCompletionLikeRequest: TypeAlias = (
EmbeddingCompletionRequest
......@@ -65,6 +66,8 @@ PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
model_config = ConfigDict(arbitrary_types_allowed=True)
request: PoolingRequestT
raw_request: Request | None = None
model_name: str
......@@ -74,14 +77,14 @@ class PoolingServeContext(Generic[PoolingRequestT]):
lora_request: LoRARequest | None = None
engine_inputs: Sequence[EngineInput] | None = None
prompt_request_ids: list[str] | None = None
intermediates: Any | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
## for Long Text Embedding with Chunked Processing
original_engine_inputs: Sequence[EngineInput] | None = None
## for bi-encoder & late-interaction
n_queries: int | None = None
......
......@@ -13,12 +13,13 @@ from fastapi.responses import JSONResponse, Response
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.generate.factories import get_generate_invocation_types
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.entrypoints.pooling.factories import get_pooling_invocation_types
from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tasks import SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
......@@ -27,80 +28,6 @@ GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServingBase | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
def get_invocation_types(
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
):
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
if "generate" in supported_tasks:
from vllm.entrypoints.openai.chat_completion.api_router import (
chat,
create_chat_completion,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.completion.api_router import (
completion,
create_completion,
)
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
INVOCATION_TYPES += [
(ChatCompletionRequest, (chat, create_chat_completion)),
(CompletionRequest, (completion, create_completion)),
]
if "embed" in supported_tasks:
from vllm.entrypoints.pooling.embed.api_router import (
create_embedding,
embedding,
)
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
INVOCATION_TYPES += [
(EmbeddingRequest, (embedding, create_embedding)),
]
if "classify" in supported_tasks:
from vllm.entrypoints.pooling.classify.api_router import (
classify,
create_classify,
)
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
INVOCATION_TYPES += [
(ClassificationRequest, (classify, create_classify)),
]
if enable_scoring_api(supported_tasks, model_config):
from vllm.entrypoints.pooling.scoring.api_router import do_rerank, rerank
from vllm.entrypoints.pooling.scoring.protocol import RerankRequest
INVOCATION_TYPES += [
(RerankRequest, (rerank, do_rerank)),
]
from vllm.entrypoints.pooling.scoring.api_router import create_score, score
from vllm.entrypoints.pooling.scoring.protocol import ScoreRequest
INVOCATION_TYPES += [
(ScoreRequest, (score, create_score)),
]
if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
INVOCATION_TYPES += [
(PoolingRequest, (pooling, create_pooling)),
]
return INVOCATION_TYPES
def attach_router(
app: FastAPI,
supported_tasks: tuple["SupportedTask", ...],
......@@ -109,7 +36,10 @@ def attach_router(
router = APIRouter()
# NOTE: Construct the TypeAdapters only once
INVOCATION_TYPES = get_invocation_types(supported_tasks, model_config)
INVOCATION_TYPES = get_generate_invocation_types(
supported_tasks, model_config
) + get_pooling_invocation_types(supported_tasks, model_config)
INVOCATION_VALIDATORS = [
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment