Unverified Commit 62de4f42 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Frontend] Resettle pooling entrypoints (#29634)


Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
parent 83805a60
...@@ -149,6 +149,7 @@ mkdocs.yaml @hmellor ...@@ -149,6 +149,7 @@ mkdocs.yaml @hmellor
/examples/*/pooling/ @noooop /examples/*/pooling/ @noooop
/tests/models/*/pooling* @noooop /tests/models/*/pooling* @noooop
/tests/entrypoints/pooling @noooop /tests/entrypoints/pooling @noooop
/vllm/entrypoints/pooling @aarnphm @chaunceyjiang @noooop
/vllm/config/pooler.py @noooop /vllm/config/pooler.py @noooop
/vllm/pooling_params.py @noooop /vllm/pooling_params.py @noooop
/vllm/model_executor/layers/pooler.py @noooop /vllm/model_executor/layers/pooler.py @noooop
......
...@@ -77,7 +77,7 @@ The `parse_request` method is used for validating the user prompt and converting ...@@ -77,7 +77,7 @@ The `parse_request` method is used for validating the user prompt and converting
The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference. The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output. The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters. The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py). The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/pooling/pooling/serving.py).
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples. An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples.
......
...@@ -351,7 +351,7 @@ The following extra parameters are supported by default: ...@@ -351,7 +351,7 @@ The following extra parameters are supported by default:
??? code ??? code
```python ```python
--8<-- "vllm/entrypoints/openai/protocol.py:embedding-extra-params" --8<-- "vllm/entrypoints/pooling/embed/protocol.py:embedding-extra-params"
``` ```
For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead: For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead:
...@@ -359,7 +359,7 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s ...@@ -359,7 +359,7 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
??? code ??? code
```python ```python
--8<-- "vllm/entrypoints/openai/protocol.py:chat-embedding-extra-params" --8<-- "vllm/entrypoints/pooling/embed/protocol.py:chat-embedding-extra-params"
``` ```
### Transcriptions API ### Transcriptions API
...@@ -629,7 +629,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported. ...@@ -629,7 +629,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.
The following extra parameters are supported: The following extra parameters are supported:
```python ```python
--8<-- "vllm/entrypoints/openai/protocol.py:classification-extra-params" --8<-- "vllm/entrypoints/pooling/classify/protocol.py:classification-extra-params"
``` ```
### Score API ### Score API
...@@ -834,7 +834,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported. ...@@ -834,7 +834,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.
The following extra parameters are supported: The following extra parameters are supported:
```python ```python
--8<-- "vllm/entrypoints/openai/protocol.py:score-extra-params" --8<-- "vllm/entrypoints/pooling/score/protocol.py:score-extra-params"
``` ```
### Re-rank API ### Re-rank API
...@@ -915,7 +915,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported. ...@@ -915,7 +915,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.
The following extra parameters are supported: The following extra parameters are supported:
```python ```python
--8<-- "vllm/entrypoints/openai/protocol.py:rerank-extra-params" --8<-- "vllm/entrypoints/pooling/score/protocol.py:rerank-extra-params"
``` ```
## Ray Serve LLM ## Ray Serve LLM
......
...@@ -7,7 +7,7 @@ import tempfile ...@@ -7,7 +7,7 @@ import tempfile
import pytest import pytest
from vllm.entrypoints.openai.protocol import BatchRequestOutput from vllm.entrypoints.openai.run_batch import BatchRequestOutput
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ClassificationResponse, PoolingResponse from vllm.entrypoints.pooling.classify.protocol import ClassificationResponse
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
DTYPE = "float32" # Use float32 to avoid NaN issue DTYPE = "float32" # Use float32 to avoid NaN issue
......
...@@ -7,7 +7,7 @@ import pytest ...@@ -7,7 +7,7 @@ import pytest
import requests import requests
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ClassificationResponse from vllm.entrypoints.pooling.classify.protocol import ClassificationResponse
VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls" VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
MAXIMUM_VIDEOS = 1 MAXIMUM_VIDEOS = 1
......
...@@ -15,10 +15,8 @@ import torch.nn.functional as F ...@@ -15,10 +15,8 @@ import torch.nn.functional as F
from tests.models.language.pooling.embed_utils import run_embedding_correctness_test from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
from tests.models.utils import check_embeddings_close from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
EmbeddingResponse, from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
PoolingResponse,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import (
......
...@@ -11,7 +11,7 @@ from tests.conftest import HfRunner ...@@ -11,7 +11,7 @@ from tests.conftest import HfRunner
from tests.models.language.pooling.embed_utils import run_embedding_correctness_test from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
from tests.models.utils import EmbedModelInfo from tests.models.utils import EmbedModelInfo
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_rocm(): if current_platform.is_rocm():
......
...@@ -15,7 +15,7 @@ import pytest ...@@ -15,7 +15,7 @@ import pytest
import pytest_asyncio import pytest_asyncio
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_rocm(): if current_platform.is_rocm():
......
...@@ -8,7 +8,7 @@ import requests ...@@ -8,7 +8,7 @@ import requests
from transformers import AutoProcessor from transformers import AutoProcessor
from tests.utils import VLLM_PATH, RemoteOpenAIServer from tests.utils import VLLM_PATH, RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.multimodal.utils import encode_image_base64, fetch_image from vllm.multimodal.utils import encode_image_base64, fetch_image
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full" MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
......
...@@ -11,7 +11,7 @@ import torch ...@@ -11,7 +11,7 @@ import torch
from tests.models.utils import check_embeddings_close from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import PoolingResponse from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE, EMBED_DTYPE_TO_TORCH_DTYPE,
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
from vllm.entrypoints.pooling.score.protocol import RerankResponse
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_rocm(): if current_platform.is_rocm():
......
...@@ -9,7 +9,7 @@ import torch.nn.functional as F ...@@ -9,7 +9,7 @@ import torch.nn.functional as F
from torch import tensor from torch import tensor
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ScoreResponse from vllm.entrypoints.pooling.score.protocol import ScoreResponse
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_rocm(): if current_platform.is_rocm():
......
...@@ -18,7 +18,10 @@ from einops import rearrange ...@@ -18,7 +18,10 @@ from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorRequest, IOProcessorResponse from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
IOProcessorResponse,
)
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
......
...@@ -7,7 +7,7 @@ import requests ...@@ -7,7 +7,7 @@ import requests
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
......
...@@ -14,7 +14,7 @@ import socket ...@@ -14,7 +14,7 @@ import socket
import tempfile import tempfile
import uuid import uuid
from argparse import Namespace from argparse import Namespace
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from http import HTTPStatus from http import HTTPStatus
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
...@@ -54,29 +54,16 @@ from vllm.entrypoints.openai.orca_metrics import metrics_header ...@@ -54,29 +54,16 @@ from vllm.entrypoints.openai.orca_metrics import metrics_header
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
ClassificationRequest,
ClassificationResponse,
CompletionRequest, CompletionRequest,
CompletionResponse, CompletionResponse,
DetokenizeRequest, DetokenizeRequest,
DetokenizeResponse, DetokenizeResponse,
EmbeddingBytesResponse,
EmbeddingRequest,
EmbeddingResponse,
ErrorInfo, ErrorInfo,
ErrorResponse, ErrorResponse,
GenerateRequest, GenerateRequest,
GenerateResponse, GenerateResponse,
IOProcessorResponse,
PoolingBytesResponse,
PoolingRequest,
PoolingResponse,
RerankRequest,
RerankResponse,
ResponsesRequest, ResponsesRequest,
ResponsesResponse, ResponsesResponse,
ScoreRequest,
ScoreResponse,
StreamingResponsesResponse, StreamingResponsesResponse,
TokenizeRequest, TokenizeRequest,
TokenizeResponse, TokenizeResponse,
...@@ -86,17 +73,13 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -86,17 +73,13 @@ from vllm.entrypoints.openai.protocol import (
TranslationResponse, TranslationResponse,
) )
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_classification import ServingClassification
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import ( from vllm.entrypoints.openai.serving_models import (
BaseModelPath, BaseModelPath,
OpenAIServingModels, OpenAIServingModels,
) )
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.entrypoints.openai.serving_tokens import ServingTokens from vllm.entrypoints.openai.serving_tokens import ServingTokens
from vllm.entrypoints.openai.serving_transcription import ( from vllm.entrypoints.openai.serving_transcription import (
...@@ -104,6 +87,11 @@ from vllm.entrypoints.openai.serving_transcription import ( ...@@ -104,6 +87,11 @@ from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranslation, OpenAIServingTranslation,
) )
from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import ( from vllm.entrypoints.utils import (
cli_env_setup, cli_env_setup,
...@@ -254,15 +242,6 @@ async def build_async_engine_client_from_engine_args( ...@@ -254,15 +242,6 @@ async def build_async_engine_client_from_engine_args(
async_llm.shutdown() async_llm.shutdown()
async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
raise RequestValidationError(
errors=["Unsupported Media Type: Only 'application/json' is allowed"]
)
router = APIRouter() router = APIRouter()
...@@ -324,26 +303,6 @@ def completion(request: Request) -> OpenAIServingCompletion | None: ...@@ -324,26 +303,6 @@ def completion(request: Request) -> OpenAIServingCompletion | None:
return request.app.state.openai_serving_completion return request.app.state.openai_serving_completion
def pooling(request: Request) -> OpenAIServingPooling | None:
return request.app.state.openai_serving_pooling
def embedding(request: Request) -> OpenAIServingEmbedding | None:
return request.app.state.openai_serving_embedding
def score(request: Request) -> ServingScores | None:
return request.app.state.openai_serving_scores
def classify(request: Request) -> ServingClassification | None:
return request.app.state.openai_serving_classification
def rerank(request: Request) -> ServingScores | None:
return request.app.state.openai_serving_scores
def tokenization(request: Request) -> OpenAIServingTokenization: def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization return request.app.state.openai_serving_tokenization
...@@ -817,166 +776,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -817,166 +776,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return StreamingResponse(content=generator, media_type="text/event-stream") return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
"/v1/embeddings",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_embedding(
request: EmbeddingRequest,
raw_request: Request,
):
handler = embedding(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API"
)
try:
generator = await handler.create_embedding(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())
elif isinstance(generator, EmbeddingBytesResponse):
return StreamingResponse(
content=generator.body,
headers={"metadata": generator.metadata},
media_type=generator.media_type,
)
assert_never(generator)
@router.post(
"/pooling",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_pooling(request: PoolingRequest, raw_request: Request):
handler = pooling(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Pooling API"
)
try:
generator = await handler.create_pooling(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
return JSONResponse(content=generator.model_dump())
elif isinstance(generator, PoolingBytesResponse):
return StreamingResponse(
content=generator.body,
headers={"metadata": generator.metadata},
media_type=generator.media_type,
)
assert_never(generator)
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
handler = classify(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Classification API"
)
try:
generator = await handler.create_classify(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, ClassificationResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post(
"/score",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Score API"
)
try:
generator = await handler.create_score(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, ScoreResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post(
"/v1/score",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning(
"To indicate that Score API is not part of standard OpenAI API, we "
"have moved it to `/score`. Please update your client accordingly."
)
return await create_score(request, raw_request)
@router.post( @router.post(
"/v1/audio/transcriptions", "/v1/audio/transcriptions",
responses={ responses={
...@@ -1055,70 +854,6 @@ async def create_translations( ...@@ -1055,70 +854,6 @@ async def create_translations(
return StreamingResponse(content=generator, media_type="text/event-stream") return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
"/rerank",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def do_rerank(request: RerankRequest, raw_request: Request):
handler = rerank(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Rerank (Score) API"
)
try:
generator = await handler.do_rerank(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, RerankResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post(
"/v1/rerank",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
logger.warning_once(
"To indicate that the rerank API is not part of the standard OpenAI"
" API, we have located it at `/rerank`. Please update your client "
"accordingly. (Note: Conforms to JinaAI rerank API)"
)
return await do_rerank(request, raw_request)
@router.post(
"/v2/rerank",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
if envs.VLLM_SERVER_DEV_MODE: if envs.VLLM_SERVER_DEV_MODE:
logger.warning( logger.warning(
"SECURITY WARNING: Development endpoints are enabled! " "SECURITY WARNING: Development endpoints are enabled! "
...@@ -1285,30 +1020,6 @@ async def is_scaling_elastic_ep(raw_request: Request): ...@@ -1285,30 +1020,6 @@ async def is_scaling_elastic_ep(raw_request: Request):
return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep}) return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep})
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
GetHandlerFn = Callable[[Request], OpenAIServing | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
(ChatCompletionRequest, (chat, create_chat_completion)),
(CompletionRequest, (completion, create_completion)),
(EmbeddingRequest, (embedding, create_embedding)),
(ClassificationRequest, (classify, create_classify)),
(ScoreRequest, (score, create_score)),
(RerankRequest, (rerank, do_rerank)),
(PoolingRequest, (pooling, create_pooling)),
]
# NOTE: Construct the TypeAdapters only once
INVOCATION_VALIDATORS = [
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
]
@router.post( @router.post(
"/inference/v1/generate", "/inference/v1/generate",
dependencies=[Depends(validate_json_request)], dependencies=[Depends(validate_json_request)],
...@@ -1653,12 +1364,16 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -1653,12 +1364,16 @@ def build_app(args: Namespace) -> FastAPI:
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
register_sagemaker_routes(router) register_sagemaker_routes(router)
app.include_router(router) app.include_router(router)
app.root_path = args.root_path app.root_path = args.root_path
mount_metrics(app) mount_metrics(app)
from vllm.entrypoints.pooling import register_pooling_api_routers
register_pooling_api_routers(app)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=args.allowed_origins, allow_origins=args.allowed_origins,
......
This diff is collapsed.
...@@ -7,29 +7,35 @@ from argparse import Namespace ...@@ -7,29 +7,35 @@ from argparse import Namespace
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from http import HTTPStatus from http import HTTPStatus
from io import StringIO from io import StringIO
from typing import Any, TypeAlias
import aiohttp import aiohttp
import torch import torch
from prometheus_client import start_http_server from prometheus_client import start_http_server
from pydantic import TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from tqdm import tqdm from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
BatchRequestInput, ChatCompletionRequest,
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse, ChatCompletionResponse,
EmbeddingResponse,
ErrorResponse, ErrorResponse,
RerankResponse, OpenAIBaseModel,
ScoreResponse,
) )
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.serving_score import ServingScores from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest, EmbeddingResponse
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
RerankResponse,
ScoreRequest,
ScoreResponse,
)
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -39,6 +45,84 @@ from vllm.version import __version__ as VLLM_VERSION ...@@ -39,6 +45,84 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
BatchRequestInputBody: TypeAlias = (
ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
)
class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
"""
# A developer-provided per-request id that will be used to match outputs to
# inputs. Must be unique for each request in a batch.
custom_id: str
# The HTTP method to be used for the request. Currently only POST is
# supported.
method: str
# The OpenAI API relative URL to be used for the request. Currently
# /v1/chat/completions is supported.
url: str
# The parameters of the request.
body: BatchRequestInputBody
@field_validator("body", mode="plain")
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
# Use url to disambiguate models
url: str = info.data["url"]
if url == "/v1/chat/completions":
return ChatCompletionRequest.model_validate(value)
if url == "/v1/embeddings":
return TypeAdapter(EmbeddingRequest).validate_python(value)
if url.endswith("/score"):
return ScoreRequest.model_validate(value)
if url.endswith("/rerank"):
return RerankRequest.model_validate(value)
return TypeAdapter(BatchRequestInputBody).validate_python(value)
class BatchResponseData(OpenAIBaseModel):
# HTTP status code of the response.
status_code: int = 200
# An unique identifier for the API request.
request_id: str
# The body of the response.
body: (
ChatCompletionResponse
| EmbeddingResponse
| ScoreResponse
| RerankResponse
| None
) = None
class BatchRequestOutput(OpenAIBaseModel):
"""
The per-line object of the batch output and error files
"""
id: str
# A developer-provided per-request id that will be used to match outputs to
# inputs.
custom_id: str
response: BatchResponseData | None
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Any | None
def make_arg_parser(parser: FlexibleArgumentParser): def make_arg_parser(parser: FlexibleArgumentParser):
parser.add_argument( parser.add_argument(
"-i", "-i",
......
...@@ -18,6 +18,28 @@ from pydantic import ConfigDict, TypeAdapter ...@@ -18,6 +18,28 @@ from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers from starlette.datastructures import Headers
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
if sys.version_info >= (3, 12): if sys.version_info >= (3, 12):
from typing import TypedDict from typing import TypedDict
else: else:
...@@ -45,29 +67,16 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -45,29 +67,16 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse,
CompletionRequest, CompletionRequest,
CompletionResponse, CompletionResponse,
DetokenizeRequest, DetokenizeRequest,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
ErrorInfo, ErrorInfo,
ErrorResponse, ErrorResponse,
FunctionCall, FunctionCall,
FunctionDefinition, FunctionDefinition,
GenerateRequest, GenerateRequest,
GenerateResponse, GenerateResponse,
IOProcessorRequest,
PoolingResponse,
RerankRequest,
ResponsesRequest, ResponsesRequest,
ScoreRequest,
ScoreResponse,
TokenizeChatRequest, TokenizeChatRequest,
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeResponse, TokenizeResponse,
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeVar from typing import TypeVar
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponseChoice, ChatCompletionResponseChoice,
...@@ -35,3 +38,12 @@ def maybe_filter_parallel_tool_calls( ...@@ -35,3 +38,12 @@ def maybe_filter_parallel_tool_calls(
] ]
return choice return choice
async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
raise RequestValidationError(
errors=["Unsupported Media Type: Only 'application/json' is allowed"]
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment