Unverified Commit 0c492b78, authored by Cyrus Leung, committed by GitHub
Browse files

[Deprecation] Remove fallbacks for Embeddings API (#18795)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 0f0926b4
...@@ -797,17 +797,12 @@ class ModelConfig: ...@@ -797,17 +797,12 @@ class ModelConfig:
else: else:
# Aliases # Aliases
if task_option == "embedding": if task_option == "embedding":
preferred_task = self._get_preferred_task( msg = ("The 'embedding' task has been renamed to "
architectures, supported_tasks) "'embed', please use the new name. The old name "
if preferred_task != "embed": "will be removed in v1.0.")
msg = ("The 'embedding' task will be restricted to " warnings.warn(msg, DeprecationWarning, stacklevel=2)
"embedding models in a future release. Please "
"pass `--task classify`, `--task score`, or " task_option = "embed"
"`--task reward` explicitly for other pooling "
"models.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
task_option = preferred_task or "embed"
if task_option not in supported_tasks: if task_option not in supported_tasks:
msg = ( msg = (
......
...@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager ...@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
from functools import partial from functools import partial
from http import HTTPStatus from http import HTTPStatus
from json import JSONDecodeError from json import JSONDecodeError
from typing import Annotated, Optional, Union from typing import Annotated, Optional
import prometheus_client import prometheus_client
import regex as re import regex as re
...@@ -59,9 +59,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -59,9 +59,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingChatRequest, EmbeddingChatRequest,
EmbeddingCompletionRequest, EmbeddingCompletionRequest,
EmbeddingRequest, EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse, ErrorResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoRAAdapterRequest, LoadLoRAAdapterRequest,
PoolingChatRequest, PoolingChatRequest,
PoolingCompletionRequest, PoolingCompletionRequest,
...@@ -627,37 +625,10 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -627,37 +625,10 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
async def create_embedding(request: EmbeddingRequest, raw_request: Request): async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request) handler = embedding(raw_request)
if handler is None: if handler is None:
fallback_handler = pooling(raw_request) return base(raw_request).create_error_response(
if fallback_handler is None: message="The model does not support Embeddings API")
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
logger.warning( generator = await handler.create_embedding(request, raw_request)
"Embeddings API will become exclusive to embedding models "
"in a future release. To return the hidden states directly, "
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
generator: Union[ErrorResponse, EmbeddingResponse]
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
object=res.object,
created=res.created,
model=res.model,
data=[
EmbeddingResponseData(
index=d.index,
embedding=d.data, # type: ignore
) for d in res.data
],
usage=res.usage,
)
else:
generator = res
else:
generator = await handler.create_embedding(request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
......
...@@ -7,7 +7,7 @@ from dataclasses import dataclass ...@@ -7,7 +7,7 @@ from dataclasses import dataclass
from typing import Any, Generic, Optional, Union from typing import Any, Generic, Optional, Union
import torch import torch
from typing_extensions import TypeVar, deprecated from typing_extensions import TypeVar
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -76,14 +76,6 @@ class PoolingOutput: ...@@ -76,14 +76,6 @@ class PoolingOutput:
return (isinstance(other, self.__class__) and bool( return (isinstance(other, self.__class__) and bool(
(self.data == other.data).all())) (self.data == other.data).all()))
@property
@deprecated("`LLM.encode()` now stores raw outputs in the `data` "
"attribute. To return embeddings, use `LLM.embed()`. "
"To return class probabilities, use `LLM.classify()` "
"and access the `probs` attribute. ")
def embedding(self) -> list[float]:
return self.data.tolist()
class RequestOutput: class RequestOutput:
"""The output data of a completion request to the LLM. """The output data of a completion request to the LLM.
...@@ -506,12 +498,6 @@ class ScoringOutput: ...@@ -506,12 +498,6 @@ class ScoringOutput:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"ScoringOutput(score={self.score})" return f"ScoringOutput(score={self.score})"
@property
@deprecated("`LLM.score()` now returns scalar scores. "
"Please access it via the `score` attribute. ")
def embedding(self) -> list[float]:
return [self.score]
class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]): class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment