Unverified commit 62de4f42, authored by wang.yuqi and committed by GitHub
Browse files

[Frontend] Resettle pooling entrypoints (#29634)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 83805a60
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import FastAPI
def register_pooling_api_routers(app: FastAPI):
    """Attach the classify, embed, score and pooling routers to ``app``.

    The router modules are imported when this function is called rather
    than when the enclosing module is imported.
    """
    from vllm.entrypoints.pooling.classify.api_router import router as classify_router
    from vllm.entrypoints.pooling.embed.api_router import router as embed_router
    from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
    from vllm.entrypoints.pooling.score.api_router import router as score_router

    # Registration order: classify, embed, score, pooling.
    for api_router in (classify_router, embed_router, score_router, pooling_router):
        app.include_router(api_router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import load_aware_call, with_cancellation
# Router that collects the classification endpoint declared in this module.
router = APIRouter()
def classify(request: Request) -> ServingClassification | None:
    """Return the classification serving handler stored on the app state.

    The result may be ``None``; the endpoint below treats that as
    "Classification API not supported".
    """
    app_state = request.app.state
    return app_state.openai_serving_classification
# POST /classify: delegate the request to the classification handler and
# translate its result into an HTTP response.
# (Documented with `#` comments so the route's OpenAPI description is unchanged.)
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
    serving = classify(raw_request)
    if serving is None:
        # No classification handler is configured: build the error through
        # the tokenization server object kept on the app state.
        tokenization = raw_request.app.state.openai_serving_tokenization
        return tokenization.create_error_response(
            message="The model does not support Classification API"
        )

    try:
        result = await serving.create_classify(request, raw_request)
    except Exception as exc:
        # Any failure raised by the handler is surfaced as an HTTP 500.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        # Handler-reported errors carry their own status code.
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, ClassificationResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Annotated, Any, TypeAlias
from pydantic import (
Field,
)
from vllm import PoolingParams
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
from vllm.utils import random_uuid
# Classification request whose input is plain text: one string or a list of
# strings. The `--8<--` markers delimit the extra parameters and must stay.
class ClassificationCompletionRequest(OpenAIBaseModel):
    model: str | None = None
    input: list[str] | str
    # When given, must be >= -1.
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    user: str | None = None

    # --8<-- [start:classification-extra-params]
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; default: 0). "
            "Any priority other than 0 will raise an error if the served model does "
            "not use priority scheduling."
        ),
    )
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) "
            "will be added to the prompt."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. "
            "If the caller does not set it, a random_uuid will be generated. "
            "This id is used through out the inference process and return in response."
        ),
    )
    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )
    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )
    use_activation: bool | None = Field(
        default=None,
        description=(
            "Whether to use activation for classification outputs. "
            "Default is True."
        ),
    )
    # --8<-- [end:classification-extra-params]

    def to_pooling_params(self):
        """Build the ``PoolingParams`` for this request.

        Carries over the truncation setting and the activation flag that
        ``get_use_activation`` derives from this request.
        """
        truncation = self.truncate_prompt_tokens
        activation_flag = get_use_activation(self)
        return PoolingParams(
            truncate_prompt_tokens=truncation,
            use_activation=activation_flag,
        )
# Classification request whose input is a list of chat messages.
# The `--8<--` markers delimit the extra parameters and must stay.
class ClassificationChatRequest(OpenAIBaseModel):
    model: str | None = None
    messages: list[ChatCompletionMessageParam]
    # When given, must be >= -1.
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    user: str | None = None

    # --8<-- [start:chat-classification-extra-params]
    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config "
            "of the model."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt on top "
            "of what is added by the chat template. "
            "For most models, the chat template takes care of adding the special "
            "tokens so this should be set to false (as is the default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer allowed, "
            "so you must provide a chat template if the tokenizer does not "
            "define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; default: 0). "
            "Any priority other than 0 will raise an error if the served model does "
            "not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. "
            "If the caller does not set it, a random_uuid will be generated. "
            "This id is used through out the inference process and return in response."
        ),
    )
    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )
    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )
    use_activation: bool | None = Field(
        default=None,
        description=(
            "Whether to use activation for classification outputs. "
            "Default is True."
        ),
    )
    # --8<-- [end:chat-classification-extra-params]

    def to_pooling_params(self):
        """Build the ``PoolingParams`` for this request.

        Carries over the truncation setting and the activation flag that
        ``get_use_activation`` derives from this request.
        """
        truncation = self.truncate_prompt_tokens
        activation_flag = get_use_activation(self)
        return PoolingParams(
            truncate_prompt_tokens=truncation,
            use_activation=activation_flag,
        )
# Union accepted by the classification endpoint: either the plain-text
# (completion-style) request or the chat-messages request.
ClassificationRequest: TypeAlias = (
    ClassificationCompletionRequest | ClassificationChatRequest
)
# One element of `ClassificationResponse.data`.
class ClassificationData(OpenAIBaseModel):
    index: int
    # May be None, but has no default and so must always be supplied.
    label: str | None
    # presumably one probability per class — TODO confirm against the serving code
    probs: list[float]
    num_classes: int
# Successful result of a classification request; the /classify endpoint
# serializes it with model_dump() into a JSON response.
class ClassificationResponse(OpenAIBaseModel):
    # Defaults to "classify-" followed by a fresh random_uuid().
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    # Defaults to the integer Unix time at which the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo
...@@ -13,11 +13,6 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption ...@@ -13,11 +13,6 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationData,
ClassificationRequest,
ClassificationResponse,
ErrorResponse, ErrorResponse,
UsageInfo, UsageInfo,
) )
...@@ -27,6 +22,13 @@ from vllm.entrypoints.openai.serving_engine import ( ...@@ -27,6 +22,13 @@ from vllm.entrypoints.openai.serving_engine import (
ServeContext, ServeContext,
) )
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationData,
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput from vllm.outputs import ClassificationOutput, PoolingRequestOutput
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.utils import load_aware_call, with_cancellation
# Router that collects the embeddings endpoint declared in this module.
router = APIRouter()
def embedding(request: Request) -> OpenAIServingEmbedding | None:
    """Return the embedding serving handler stored on the app state.

    The result may be ``None``; the endpoint below treats that as
    "Embeddings API not supported".
    """
    app_state = request.app.state
    return app_state.openai_serving_embedding
# POST /v1/embeddings: delegate the request to the embedding handler and
# translate its result into a JSON or streaming (bytes) HTTP response.
# (Documented with `#` comments so the route's OpenAPI description is unchanged.)
@router.post(
    "/v1/embeddings",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_embedding(
    request: EmbeddingRequest,
    raw_request: Request,
):
    serving = embedding(raw_request)
    if serving is None:
        # No embedding handler is configured: build the error through the
        # tokenization server object kept on the app state.
        tokenization = raw_request.app.state.openai_serving_tokenization
        return tokenization.create_error_response(
            message="The model does not support Embeddings API"
        )

    try:
        result = await serving.create_embedding(request, raw_request)
    except Exception as exc:
        # Any failure raised by the handler is surfaced as an HTTP 500.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        # Handler-reported errors carry their own status code.
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())
    if isinstance(result, EmbeddingBytesResponse):
        # Bytes payload: stream the body, with the metadata string sent in
        # the "metadata" response header.
        return StreamingResponse(
            content=result.body,
            headers={"metadata": result.metadata},
            media_type=result.media_type,
        )

    assert_never(result)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Annotated, Any, TypeAlias
from pydantic import (
Field,
model_validator,
)
from vllm import PoolingParams
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
# Embedding request whose input is text or integer lists, single or batched.
# The `--8<--` markers delimit the extra parameters and must stay.
class EmbeddingCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str | None = None
    input: list[int] | list[list[int]] | str | list[str]
    encoding_format: EncodingFormat = "float"
    dimensions: int | None = None
    user: str | None = None
    # When given, must be >= -1.
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:embedding-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: bool | None = Field(
        default=None,
        description="Whether to normalize the embeddings outputs. Default is True.",
    )
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            # The trailing space keeps the two sentences from running together
            # in the published description.
            "base64 encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    # --8<-- [end:embedding-extra-params]

    def to_pooling_params(self):
        """Build the ``PoolingParams`` for this request.

        Carries over the truncation, dimensions and normalize settings.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
        )
# Embedding request whose input is a list of chat messages.
# The `--8<--` markers delimit the extra parameters and must stay.
class EmbeddingChatRequest(OpenAIBaseModel):
    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    encoding_format: EncodingFormat = "float"
    dimensions: int | None = None
    user: str | None = None
    # When given, must be >= -1.
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:chat-embedding-extra-params]
    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: bool | None = Field(
        default=None,
        description="Whether to normalize the embeddings outputs. Default is True.",
    )
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            # The trailing space keeps the two sentences from running together
            # in the published description.
            "base64 encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    # --8<-- [end:chat-embedding-extra-params]

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject input that enables two mutually exclusive prompt options.

        Raises:
            ValueError: if both ``continue_final_message`` and
                ``add_generation_prompt`` are truthy in the raw input.
        """
        # A "before" validator sees the raw input, which is not guaranteed
        # to be a dict. Leave anything else to pydantic's own validation
        # instead of failing on `.get` with an AttributeError.
        if not isinstance(data, dict):
            return data
        # `continue_final_message` is not a declared field of this class;
        # it is only read from the raw input here.
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    def to_pooling_params(self):
        """Build the ``PoolingParams`` for this request.

        Carries over the truncation, dimensions and normalize settings.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
        )
# Union accepted by the embeddings endpoint: either the completion-style
# request or the chat-messages request.
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
# One element of `EmbeddingResponse.data`.
class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    # Either a list of floats or a string.
    # presumably the string form is an encoded (e.g. base64) embedding — TODO confirm
    embedding: list[float] | str
# JSON result of an embedding request; the /v1/embeddings endpoint serializes
# it with model_dump().
class EmbeddingResponse(OpenAIBaseModel):
    # Defaults to "embd-" followed by a fresh random_uuid().
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Defaults to the integer Unix time at which the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo
# Binary result of an embedding request. The /v1/embeddings endpoint streams
# `body`, sends `metadata` in the "metadata" response header and uses
# `media_type` as the response media type.
class EmbeddingBytesResponse(OpenAIBaseModel):
    body: list[bytes]
    metadata: str
    media_type: str = "application/octet-stream"
...@@ -13,12 +13,6 @@ from vllm.engine.protocol import EngineClient ...@@ -13,12 +13,6 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, ErrorResponse,
UsageInfo, UsageInfo,
) )
...@@ -29,6 +23,14 @@ from vllm.entrypoints.openai.serving_engine import ( ...@@ -29,6 +23,14 @@ from vllm.entrypoints.openai.serving_engine import (
TextTokensPrompt, TextTokensPrompt,
) )
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
)
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorResponse,
PoolingBytesResponse,
PoolingRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.utils import load_aware_call, with_cancellation
# Router that collects the pooling endpoint declared in this module.
router = APIRouter()
def pooling(request: Request) -> OpenAIServingPooling | None:
    """Return the pooling serving handler stored on the app state.

    The result may be ``None``; the endpoint below treats that as
    "Pooling API not supported".
    """
    app_state = request.app.state
    return app_state.openai_serving_pooling
# POST /pooling: delegate the request to the pooling handler and translate
# its result into a JSON or streaming (bytes) HTTP response.
# (Documented with `#` comments so the route's OpenAPI description is unchanged.)
@router.post(
    "/pooling",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_pooling(request: PoolingRequest, raw_request: Request):
    serving = pooling(raw_request)
    if serving is None:
        # No pooling handler is configured: build the error through the
        # tokenization server object kept on the app state.
        tokenization = raw_request.app.state.openai_serving_tokenization
        return tokenization.create_error_response(
            message="The model does not support Pooling API"
        )

    try:
        result = await serving.create_pooling(request, raw_request)
    except Exception as exc:
        # Any failure raised by the handler is surfaced as an HTTP 500.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        # Handler-reported errors carry their own status code.
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, (PoolingResponse, IOProcessorResponse)):
        return JSONResponse(content=result.model_dump())
    if isinstance(result, PoolingBytesResponse):
        # Bytes payload: stream the body, with the metadata string sent in
        # the "metadata" response header.
        return StreamingResponse(
            content=result.body,
            headers={"metadata": result.metadata},
            media_type=result.media_type,
        )

    assert_never(result)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Generic, TypeAlias, TypeVar
from pydantic import (
Field,
)
from vllm import PoolingParams
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingChatRequest,
EmbeddingCompletionRequest,
)
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
# Pooling request built on the completion-style embedding request: it adds
# an optional task selector and the activation flags.
class PoolingCompletionRequest(EmbeddingCompletionRequest):
    task: PoolingTask | None = None

    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )
    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )
    use_activation: bool | None = Field(
        default=None,
        description=(
            "Whether to use activation for classification outputs. "
            "If it is a classify or token_classify task, the default is True; "
            "for other tasks, this value should be None."
        ),
    )

    def to_pooling_params(self):
        """Build the ``PoolingParams`` for this request.

        Carries over truncation, dimensions and normalize, plus the
        activation flag that ``get_use_activation`` derives from this request.
        """
        truncation = self.truncate_prompt_tokens
        dims = self.dimensions
        normalize_flag = self.normalize
        activation_flag = get_use_activation(self)
        return PoolingParams(
            truncate_prompt_tokens=truncation,
            dimensions=dims,
            normalize=normalize_flag,
            use_activation=activation_flag,
        )
# Pooling request built on the chat-style embedding request: it adds an
# optional task selector and the activation flags.
class PoolingChatRequest(EmbeddingChatRequest):
    task: PoolingTask | None = None

    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )
    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )
    use_activation: bool | None = Field(
        default=None,
        description=(
            "Whether to use activation for classification outputs. "
            "If it is a classify or token_classify task, the default is True; "
            "for other tasks, this value should be None."
        ),
    )

    def to_pooling_params(self):
        """Build the ``PoolingParams`` for this request.

        Carries over truncation, dimensions and normalize, plus the
        activation flag that ``get_use_activation`` derives from this request.
        """
        truncation = self.truncate_prompt_tokens
        dims = self.dimensions
        normalize_flag = self.normalize
        activation_flag = get_use_activation(self)
        return PoolingParams(
            truncate_prompt_tokens=truncation,
            dimensions=dims,
            normalize=normalize_flag,
            use_activation=activation_flag,
        )
# Type of the `data` payload carried by IOProcessorRequest / IOProcessorResponse.
T = TypeVar("T")
# Pooling request whose payload (`data`) has a generic type; `task` defaults
# to "plugin".
class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
    model: str | None = None

    priority: int = Field(default=0)
    """
    The priority of the request (lower means earlier handling;
    default: 0). Any priority other than 0 will raise an error
    if the served model does not use priority scheduling.
    """
    data: T

    task: PoolingTask = "plugin"
    encoding_format: EncodingFormat = "float"
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            # The trailing space keeps the two sentences from running together
            # in the published description.
            "base64 encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )

    def to_pooling_params(self):
        """Return a default-constructed ``PoolingParams``.

        No request field is forwarded.
        """
        return PoolingParams()
# Response counterpart of IOProcessorRequest; the /pooling endpoint
# serializes it with model_dump() into a JSON response.
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
    request_id: str | None = None
    """
    The request_id associated with this response
    """
    # Defaults to the integer Unix time at which the instance is created.
    created_at: int = Field(default_factory=lambda: int(time.time()))

    data: T
    """
    When using plugins IOProcessor plugins, the actual output is generated
    by the plugin itself. Hence, we use a generic type for the response data
    """
# Union accepted by the pooling endpoint: completion-style, chat-style or
# IOProcessor request.
PoolingRequest: TypeAlias = (
    PoolingCompletionRequest | PoolingChatRequest | IOProcessorRequest
)
# One element of `PoolingResponse.data`.
class PoolingResponseData(OpenAIBaseModel):
    index: int
    object: str = "pooling"
    # A nested list of floats, a flat list of floats, or a string.
    # presumably the string form is an encoded payload — TODO confirm
    data: list[list[float]] | list[float] | str
# JSON result of a pooling request; the /pooling endpoint serializes it with
# model_dump().
class PoolingResponse(OpenAIBaseModel):
    # Defaults to "pool-" followed by a fresh random_uuid().
    id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
    object: str = "list"
    # Defaults to the integer Unix time at which the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[PoolingResponseData]
    usage: UsageInfo
# Binary result of a pooling request. The /pooling endpoint streams `body`,
# sends `metadata` in the "metadata" response header and uses `media_type`
# as the response media type.
class PoolingBytesResponse(OpenAIBaseModel):
    body: list[bytes]
    metadata: str
    media_type: str = "application/octet-stream"
...@@ -16,6 +16,11 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption ...@@ -16,6 +16,11 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ErrorResponse, ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest, IOProcessorRequest,
IOProcessorResponse, IOProcessorResponse,
PoolingBytesResponse, PoolingBytesResponse,
...@@ -24,10 +29,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -24,10 +29,7 @@ from vllm.entrypoints.openai.protocol import (
PoolingRequest, PoolingRequest,
PoolingResponse, PoolingResponse,
PoolingResponseData, PoolingResponseData,
UsageInfo,
) )
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger from vllm.logger import init_logger
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
RerankResponse,
ScoreRequest,
ScoreResponse,
)
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger
# Router that collects the score and rerank endpoints declared in this module.
router = APIRouter()
# Module logger; used below for the notices emitted by the `/v1/score` and
# `/v1/rerank` alias routes.
logger = init_logger(__name__)
def score(request: Request) -> ServingScores | None:
    """Return the scoring handler stored on the app state.

    The result may be ``None``; the score endpoint treats that as
    "Score API not supported".
    """
    app_state = request.app.state
    return app_state.openai_serving_scores
def rerank(request: Request) -> ServingScores | None:
    """Return the handler used for reranking from the app state.

    This reads the same ``openai_serving_scores`` attribute as ``score``.
    The result may be ``None``; the rerank endpoint treats that as
    "not supported".
    """
    app_state = request.app.state
    return app_state.openai_serving_scores
# POST /score: delegate the request to the scoring handler and translate its
# result into an HTTP response.
# (Documented with `#` comments so the route's OpenAPI description is unchanged.)
@router.post(
    "/score",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_score(request: ScoreRequest, raw_request: Request):
    serving = score(raw_request)
    if serving is None:
        # No scoring handler is configured: build the error through the
        # tokenization server object kept on the app state.
        tokenization = raw_request.app.state.openai_serving_tokenization
        return tokenization.create_error_response(
            message="The model does not support Score API"
        )

    try:
        result = await serving.create_score(request, raw_request)
    except Exception as exc:
        # Any failure raised by the handler is surfaced as an HTTP 500.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        # Handler-reported errors carry their own status code.
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, ScoreResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
# POST /v1/score: alias of `/score`; logs a relocation notice and forwards
# the request to create_score.
# NOTE(review): create_score is itself wrapped by with_cancellation and
# load_aware_call, so this alias applies both a second time, whereas
# do_rerank_v1 omits load_aware_call. Confirm which is intended.
@router.post(
    "/v1/score",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_score_v1(request: ScoreRequest, raw_request: Request):
    # Use warning_once, as the `/v1/rerank` alias does, so the notice is not
    # repeated in the log for every request to this route.
    logger.warning_once(
        "To indicate that Score API is not part of standard OpenAI API, we "
        "have moved it to `/score`. Please update your client accordingly."
    )

    return await create_score(request, raw_request)
# POST /rerank: delegate the request to the scoring handler's rerank method
# and translate its result into an HTTP response.
# (Documented with `#` comments so the route's OpenAPI description is unchanged.)
@router.post(
    "/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def do_rerank(request: RerankRequest, raw_request: Request):
    serving = rerank(raw_request)
    if serving is None:
        # No scoring handler is configured: build the error through the
        # tokenization server object kept on the app state.
        tokenization = raw_request.app.state.openai_serving_tokenization
        return tokenization.create_error_response(
            message="The model does not support Rerank (Score) API"
        )

    try:
        result = await serving.do_rerank(request, raw_request)
    except Exception as exc:
        # Any failure raised by the handler is surfaced as an HTTP 500.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        # Handler-reported errors carry their own status code.
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, RerankResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
# POST /v1/rerank: alias of `/rerank`; logs a one-time relocation notice and
# forwards the request to do_rerank.
# (Documented with `#` comments so the route's OpenAPI description is unchanged.)
@router.post(
    "/v1/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    relocation_notice = (
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client "
        "accordingly. (Note: Conforms to JinaAI rerank API)"
    )
    logger.warning_once(relocation_notice)

    response = await do_rerank(request, raw_request)
    return response
# POST /v2/rerank: alias of `/rerank`; forwards the request to do_rerank
# without logging anything.
# (Documented with `#` comments so the route's OpenAPI description is unchanged.)
@router.post(
    "/v2/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    response = await do_rerank(request, raw_request)
    return response
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Annotated, Any
from pydantic import (
BaseModel,
Field,
)
from vllm import PoolingParams
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
from vllm.utils import random_uuid
# Request body of the score endpoint: two inputs (`text_1`, `text_2`), each a
# string, a list of strings, or a multi-modal parameter.
# The `--8<--` markers delimit the extra parameters and must stay.
class ScoreRequest(OpenAIBaseModel):
    model: str | None = None
    text_1: list[str] | str | ScoreMultiModalParam
    text_2: list[str] | str | ScoreMultiModalParam
    # When given, must be >= -1.
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:score-extra-params]
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; default: 0). "
            "Any priority other than 0 will raise an error if the served model does "
            "not use priority scheduling."
        ),
    )
    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )
    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )
    use_activation: bool | None = Field(
        default=None,
        description=(
            "Whether to use activation for classification outputs. "
            "Default is True."
        ),
    )
    # --8<-- [end:score-extra-params]

    def to_pooling_params(self):
        """Build the ``PoolingParams`` for this request.

        Carries over the truncation setting and the activation flag that
        ``get_use_activation`` derives from this request.
        """
        truncation = self.truncate_prompt_tokens
        activation_flag = get_use_activation(self)
        return PoolingParams(
            truncate_prompt_tokens=truncation,
            use_activation=activation_flag,
        )
class RerankRequest(OpenAIBaseModel):
    # Request schema accepted by the rerank endpoints (`do_rerank`): one
    # `query` plus the `documents` to rank, each either plain text or a
    # `ScoreMultiModalParam`.
    # `softmax` and `activation` are legacy switches (their own descriptions
    # point callers at `use_activation`).
    model: str | None = None
    query: str | ScoreMultiModalParam
    documents: list[str] | ScoreMultiModalParam
    # Defaults to 0 through a factory. How a value of 0 is interpreted is up to
    # the consumer of this request and is not visible in this class.
    top_n: int = Field(default_factory=lambda: 0)
    # Integers below -1 fail validation (`ge=-1`); None leaves the limit unset.
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:rerank-extra-params]
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request "
            "(lower means earlier handling; default: 0). "
            "Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    softmax: bool | None = Field(
        default=None,
        description=(
            "softmax will be deprecated, "
            "please use use_activation instead."
        ),
    )
    activation: bool | None = Field(
        default=None,
        description=(
            "activation will be deprecated, "
            "please use use_activation instead."
        ),
    )
    use_activation: bool | None = Field(
        default=None,
        description=(
            "Whether to use activation for classification outputs. Default is True."
        ),
    )
    # --8<-- [end:rerank-extra-params]

    def to_pooling_params(self):
        """Build the `PoolingParams` for this request.

        Only the truncation limit and the activation flag are forwarded. The
        flag is whatever `get_use_activation` resolves from this request.

        Returns:
            A new `PoolingParams` instance.
        """
        activation_flag = get_use_activation(self)
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=activation_flag,
        )
class RerankDocument(BaseModel):
    # Document payload carried by a rerank result. Both fields are optional
    # and default to None: `text` holds a plain string, `multi_modal` holds a
    # `ScoreContentPartParam`.
    text: str | None = None
    multi_modal: ScoreContentPartParam | None = None
class RerankResult(BaseModel):
    # One entry of a rerank response; all three fields are required.
    # NOTE(review): `index` presumably refers to the document's position in
    # the request's `documents` — confirm against the serving code.
    index: int
    document: RerankDocument
    relevance_score: float
class RerankUsage(BaseModel):
    # Usage block of a rerank response; exposes a single required counter.
    total_tokens: int
class RerankResponse(OpenAIBaseModel):
    # Top-level rerank response. Every field is required; no defaults are
    # generated here, so the caller supplies `id` and `model` explicitly.
    id: str
    model: str
    usage: RerankUsage
    results: list[RerankResult]
class ScoreResponseData(OpenAIBaseModel):
    # One scored item inside a score response. `object` defaults to the
    # literal tag "score"; `index` and `score` are required.
    index: int
    object: str = "score"
    score: float
class ScoreResponse(OpenAIBaseModel):
    # Top-level score response: a "list" envelope around `ScoreResponseData`
    # items plus usage accounting. `id` and `created` are generated per
    # instance when not supplied.
    # NOTE(review): the generated id carries an "embd-" prefix — confirm this
    # is intended for score responses rather than a leftover from embeddings.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Whole seconds from `time.time()` at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ScoreResponseData]
    usage: UsageInfo
...@@ -11,6 +11,11 @@ from vllm.engine.protocol import EngineClient ...@@ -11,6 +11,11 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ErrorResponse, ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.score.protocol import (
RerankDocument, RerankDocument,
RerankRequest, RerankRequest,
RerankResponse, RerankResponse,
...@@ -19,10 +24,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -19,10 +24,7 @@ from vllm.entrypoints.openai.protocol import (
ScoreRequest, ScoreRequest,
ScoreResponse, ScoreResponse,
ScoreResponseData, ScoreResponseData,
UsageInfo,
) )
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.score_utils import ( from vllm.entrypoints.score_utils import (
ScoreContentPartParam, ScoreContentPartParam,
ScoreMultiModalParam, ScoreMultiModalParam,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json import json
from collections.abc import Awaitable, Callable
from http import HTTPStatus from http import HTTPStatus
from typing import Any
import model_hosting_container_standards.sagemaker as sagemaker_standards import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic import pydantic
...@@ -9,12 +11,56 @@ from fastapi import APIRouter, Depends, HTTPException, Request ...@@ -9,12 +11,56 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
INVOCATION_VALIDATORS,
base, base,
chat,
completion,
create_chat_completion,
create_completion,
health, health,
validate_json_request, validate_json_request,
) )
from vllm.entrypoints.openai.protocol import ErrorResponse from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
ErrorResponse,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.pooling.classify.api_router import classify, create_classify
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
from vllm.entrypoints.pooling.embed.api_router import create_embedding, embedding
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
from vllm.entrypoints.pooling.score.api_router import (
create_score,
do_rerank,
rerank,
score,
)
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
# Given the raw HTTP request, returns the serving handler (annotated as
# possibly None).
GetHandlerFn = Callable[[Request], OpenAIServing | None]
# Awaitable endpoint taking the parsed request model and the raw HTTP request.
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]

# Each entry pairs a request model with (handler getter, endpoint).
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
    (ChatCompletionRequest, (chat, create_chat_completion)),
    (CompletionRequest, (completion, create_completion)),
    (EmbeddingRequest, (embedding, create_embedding)),
    (ClassificationRequest, (classify, create_classify)),
    (ScoreRequest, (score, create_score)),
    (RerankRequest, (rerank, do_rerank)),
    (PoolingRequest, (pooling, create_pooling)),
]

# Same order as INVOCATION_TYPES, with each request model wrapped in a
# pydantic TypeAdapter.
# NOTE: Construct the TypeAdapters only once
INVOCATION_VALIDATORS = [
    (pydantic.TypeAdapter(model_cls), (resolve_handler, endpoint_fn))
    for model_cls, (resolve_handler, endpoint_fn) in INVOCATION_TYPES
]
def register_sagemaker_routes(router: APIRouter): def register_sagemaker_routes(router: APIRouter):
......
...@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, Sequence ...@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment