Unverified commit 038914b7, authored by Cyrus Leung and committed by GitHub
Browse files

[Refactor] Move `task` outside of `PoolingParams.verify` (#33796)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
parent d2f4a71c
......@@ -70,6 +70,7 @@ steps:
- vllm/
- tests/test_inputs.py
- tests/test_outputs.py
- tests/test_pooling_params.py
- tests/multimodal
- tests/renderers
- tests/standalone_tests/lazy_imports.py
......@@ -82,6 +83,7 @@ steps:
- python3 standalone_tests/lazy_imports.py
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s test_pooling_params.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers
- pytest -v -s tokenizers_
......
......@@ -63,6 +63,7 @@ steps:
- vllm/
- tests/test_inputs.py
- tests/test_outputs.py
- tests/test_pooling_params.py
- tests/multimodal
- tests/renderers
- tests/standalone_tests/lazy_imports.py
......@@ -75,6 +76,7 @@ steps:
- python3 standalone_tests/lazy_imports.py
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s test_pooling_params.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers
- pytest -v -s tokenizers_
......
......@@ -122,6 +122,7 @@ steps:
- vllm/
- tests/test_inputs.py
- tests/test_outputs.py
- tests/test_pooling_params.py
- tests/multimodal
- tests/renderers
- tests/standalone_tests/lazy_imports.py
......@@ -134,6 +135,7 @@ steps:
- python3 standalone_tests/lazy_imports.py
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s test_pooling_params.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers
- pytest -v -s tokenizers_
......
......@@ -469,6 +469,4 @@ async def test_pooling_not_supported(
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(
f"Task {task} is not supported"
)
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
......@@ -757,6 +757,4 @@ async def test_pooling_not_supported(
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(
f"Task {task} is not supported"
)
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
......@@ -138,17 +138,17 @@ def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str):
"""Test that ColBERT model does not support 'embed' task."""
task = "embed"
text = "What is the capital of France?"
pooling_response = requests.post(
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": text,
"task": "embed",
"task": task,
},
)
# Should return error
assert pooling_response.status_code == 400
assert "Task embed is not supported" in pooling_response.text
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
......@@ -232,6 +232,4 @@ async def test_pooling_not_supported(
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(
f"Task {task} is not supported"
)
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
......@@ -27,35 +27,24 @@ class MockModelConfig:
pooler_config: PoolerConfig
def test_task():
pooling_params = PoolingParams()
pooling_params.verify(task="score")
pooling_params = PoolingParams(task="score")
pooling_params.verify(task="score")
with pytest.raises(ValueError):
pooling_params.verify(task="classify")
def test_embed():
task = "embed"
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=None)
pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=True)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=True)
pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=False)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=False)
pooling_params.verify(model_config)
invalid_parameters = classify_parameters + step_pooling_parameters
for p in set(invalid_parameters) - set(embed_parameters):
with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, **{p: True})
pooling_params.verify(model_config)
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
......@@ -63,7 +52,6 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
task = "embed"
model_config = ModelConfig(
model_info.name,
task="auto",
tokenizer=model_info.name,
tokenizer_mode="auto",
trust_remote_code=False,
......@@ -71,37 +59,39 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
dtype="float16",
)
pooling_params = PoolingParams(dimensions=None)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, dimensions=None)
pooling_params.verify(model_config)
with pytest.raises(ValueError):
pooling_params = PoolingParams(dimensions=1)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, dimensions=1)
pooling_params.verify(model_config)
if model_info.is_matryoshka:
assert model_info.matryoshka_dimensions is not None
pooling_params = PoolingParams(dimensions=model_info.matryoshka_dimensions[0])
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(
task=task, dimensions=model_info.matryoshka_dimensions[0]
)
pooling_params.verify(model_config)
@pytest.mark.parametrize("task", ["score", "classify"])
def test_classify(task):
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=None)
pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=True)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=True)
pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=False)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=False)
pooling_params.verify(model_config)
invalid_parameters = embed_parameters + step_pooling_parameters
for p in set(invalid_parameters) - set(classify_parameters):
with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, **{p: True})
pooling_params.verify(model_config)
@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
......@@ -111,14 +101,14 @@ def test_token_embed(pooling_type: str):
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
)
pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=None)
pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=True)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=True)
pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=False)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=False)
pooling_params.verify(model_config)
invalid_parameters = classify_parameters
if pooling_type != "STEP":
......@@ -126,8 +116,8 @@ def test_token_embed(pooling_type: str):
for p in set(invalid_parameters) - set(embed_parameters):
with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, **{p: True})
pooling_params.verify(model_config)
@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
......@@ -137,14 +127,14 @@ def test_token_classify(pooling_type: str):
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
)
pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=None)
pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=True)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=True)
pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=False)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, use_activation=False)
pooling_params.verify(model_config)
invalid_parameters = embed_parameters
if pooling_type != "STEP":
......@@ -152,5 +142,5 @@ def test_token_classify(pooling_type: str):
for p in set(invalid_parameters) - set(classify_parameters):
with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(task=task, **{p: True})
pooling_params.verify(model_config)
......@@ -1135,11 +1135,12 @@ class LLM:
# Use default pooling params.
pooling_params = PoolingParams()
if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
for param in as_iter(pooling_params):
param.verify(pooling_task, model_config)
if param.task is None:
param.task = pooling_task
elif param.task != pooling_task:
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
raise ValueError(msg)
self._validate_and_add_requests(
prompts=prompts,
......@@ -1472,8 +1473,9 @@ class LLM:
if pooling_params is None:
pooling_params = PoolingParams(task="score")
elif pooling_params.task is None:
pooling_params.task = "score"
pooling_params.verify("score", model_config)
pooling_params_list = list[PoolingParams]()
prompts = list[PromptType]()
......@@ -1836,6 +1838,7 @@ class LLM:
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
supported_tasks=self.supported_tasks,
)
self.llm_engine.add_request(
......
......@@ -68,7 +68,6 @@ def init_pooling_state(
OpenAIServingPooling(
engine_client,
state.openai_serving_models,
supported_tasks=supported_tasks,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
......@@ -76,7 +75,7 @@ def init_pooling_state(
log_error_stack=args.log_error_stack,
)
)
if any(task in POOLING_TASKS for task in supported_tasks)
if any(t in supported_tasks for t in POOLING_TASKS)
else None
)
state.openai_serving_embedding = (
......
......@@ -6,19 +6,15 @@ from typing import Annotated, Any
from pydantic import Field, model_validator
from vllm import PoolingParams
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.logger import init_logger
from vllm.renderers import ChatParams, merge_kwargs
from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
logger = init_logger(__name__)
class PoolingBasicRequestMixin(OpenAIBaseModel):
# --8<-- [start:pooling-common-params]
......@@ -185,20 +181,6 @@ class EmbedRequestMixin(EncodingRequestMixin):
)
# --8<-- [end:embed-extra-params]
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
dimensions=self.dimensions,
use_activation=self.use_activation,
truncate_prompt_tokens=getattr(self, "truncate_prompt_tokens", None),
)
class ClassifyRequestMixin(OpenAIBaseModel):
# --8<-- [start:classify-extra-params]
......@@ -208,9 +190,3 @@ class ClassifyRequestMixin(OpenAIBaseModel):
"`None` uses the pooler's default, which is `True` in most cases.",
)
# --8<-- [end:classify-extra-params]
def to_pooling_params(self):
return PoolingParams(
use_activation=self.use_activation,
truncate_prompt_tokens=getattr(self, "truncate_prompt_tokens", None),
)
......@@ -6,6 +6,7 @@ from typing import Any, TypeAlias
from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
......@@ -14,9 +15,12 @@ from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
logger = init_logger(__name__)
class ClassificationCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
......@@ -33,6 +37,13 @@ class ClassificationCompletionRequest(
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
task="classify",
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
class ClassificationChatRequest(
PoolingBasicRequestMixin, ChatRequestMixin, ClassifyRequestMixin
......@@ -55,6 +66,13 @@ class ClassificationChatRequest(
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
task="classify",
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
ClassificationRequest: TypeAlias = (
ClassificationCompletionRequest | ClassificationChatRequest
......
......@@ -22,7 +22,6 @@ from vllm.entrypoints.pooling.classify.protocol import (
)
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput
from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
......@@ -159,18 +158,3 @@ class ServingClassification(OpenAIServing):
)
return await self.handle(ctx) # type: ignore[return-value]
def _create_pooling_params(
self,
ctx: ClassificationServeContext,
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
try:
pooling_params.verify("classify", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params
......@@ -5,6 +5,7 @@ from typing import Any, TypeAlias
from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
......@@ -13,9 +14,12 @@ from vllm.entrypoints.pooling.base.protocol import (
EmbedRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
logger = init_logger(__name__)
def _get_max_total_output_tokens(
model_config: ModelConfig,
......@@ -55,6 +59,21 @@ class EmbeddingCompletionRequest(
max_output_tokens_param="max_model_len - max_embed_len",
)
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task="embed",
dimensions=self.dimensions,
use_activation=self.use_activation,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
class EmbeddingChatRequest(
PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
......@@ -82,6 +101,21 @@ class EmbeddingChatRequest(
max_output_tokens_param="max_model_len - max_embed_len",
)
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task="embed",
dimensions=self.dimensions,
use_activation=self.use_activation,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
......
......@@ -424,12 +424,6 @@ class OpenAIServingEmbedding(OpenAIServing):
if isinstance(pooling_params, ErrorResponse):
return pooling_params
# Verify and set the task for pooling params
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
......@@ -463,8 +457,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return None
except Exception as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
async def _collect_batch(
self,
......@@ -634,7 +627,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return None
except Exception as e:
return self.create_error_response(str(e))
return self.create_error_response(e)
async def create_embedding(
self,
......@@ -661,18 +654,3 @@ class OpenAIServingEmbedding(OpenAIServing):
)
return await self.handle(ctx) # type: ignore[return-value]
def _create_pooling_params(
self,
ctx: EmbeddingServeContext,
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params
......@@ -53,6 +53,7 @@ class PoolingCompletionRequest(
self.use_activation = self.normalize
return PoolingParams(
task=self.task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
dimensions=self.dimensions,
......@@ -90,6 +91,7 @@ class PoolingChatRequest(
self.use_activation = self.normalize
return PoolingParams(
task=self.task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
dimensions=self.dimensions,
......@@ -104,7 +106,7 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
task: PoolingTask = "plugin"
def to_pooling_params(self):
return PoolingParams()
return PoolingParams(task=self.task)
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
......
......@@ -35,7 +35,6 @@ from vllm.entrypoints.pooling.utils import (
)
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import PoolingTask, SupportedTask
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
......@@ -48,7 +47,6 @@ class OpenAIServingPooling(OpenAIServing):
engine_client: EngineClient,
models: OpenAIServingModels,
*,
supported_tasks: tuple[SupportedTask, ...],
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
......@@ -62,7 +60,6 @@ class OpenAIServingPooling(OpenAIServing):
log_error_stack=log_error_stack,
)
self.supported_tasks = supported_tasks
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
......@@ -152,32 +149,6 @@ class OpenAIServingPooling(OpenAIServing):
else:
pooling_params = request.to_pooling_params()
pooling_task: PoolingTask
if request.task is None:
if "token_embed" in self.supported_tasks:
pooling_task = "token_embed"
elif "token_classify" in self.supported_tasks:
pooling_task = "token_classify"
elif "plugin" in self.supported_tasks:
pooling_task = "plugin"
else:
return self.create_error_response(
f"pooling_task must be one of {self.supported_tasks}."
)
else:
pooling_task = request.task
if pooling_task not in self.supported_tasks:
return self.create_error_response(
f"Task {pooling_task} is not supported, it"
f" must be one of {self.supported_tasks}."
)
try:
pooling_params.verify(pooling_task, self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
......@@ -212,8 +183,7 @@ class OpenAIServingPooling(OpenAIServing):
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
result_generator = merge_async_iterators(*generators)
......@@ -251,8 +221,7 @@ class OpenAIServingPooling(OpenAIServing):
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
return response
......
......@@ -18,6 +18,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreInputs,
)
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
......@@ -40,8 +41,9 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
def to_pooling_params(self, task: PoolingTask = "score"):
return PoolingParams(
task=task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
......@@ -122,6 +124,13 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self, task: PoolingTask = "score"):
return PoolingParams(
task=task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
class RerankDocument(BaseModel):
text: str | None = None
......
......@@ -118,12 +118,7 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params()
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
pooling_params = request.to_pooling_params("embed")
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
......@@ -223,19 +218,7 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
# Use token_embed task for late interaction models
from vllm import PoolingParams
pooling_params = PoolingParams(
task="token_embed",
truncate_prompt_tokens=request.truncate_prompt_tokens,
use_activation=request.use_activation,
)
try:
pooling_params.verify("token_embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
pooling_params = request.to_pooling_params("token_embed")
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
......@@ -358,12 +341,7 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
default_pooling_params = request.to_pooling_params()
try:
default_pooling_params.verify("score", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
default_pooling_params = request.to_pooling_params("score")
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
......@@ -497,8 +475,7 @@ class ServingScores(OpenAIServing):
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
async def do_rerank(
self, request: RerankRequest, raw_request: Request | None = None
......@@ -542,8 +519,7 @@ class ServingScores(OpenAIServing):
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return self.create_error_response(e)
def request_output_to_score_response(
self,
......
......@@ -72,15 +72,7 @@ class PoolingParams(
"""Returns a deep copy of the PoolingParams instance."""
return deepcopy(self)
def verify(
self, task: PoolingTask, model_config: "ModelConfig | None" = None
) -> None:
if self.task is None:
self.task = task
elif self.task != task:
msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
raise ValueError(msg)
def verify(self, model_config: "ModelConfig") -> None:
# plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify
if self.task == "plugin":
......@@ -167,7 +159,7 @@ class PoolingParams(
if mds is not None:
if self.dimensions not in mds:
raise ValueError(
f'Model "{model_config.served_model_name}" '
f"Model {model_config.served_model_name!r} "
f"only supports {str(mds)} matryoshka dimensions, "
f"use other output dimensions will "
f"lead to poor results."
......@@ -179,7 +171,7 @@ class PoolingParams(
if self.use_activation is None:
self.use_activation = True
else:
raise ValueError(f"Unknown pooling task: {self.task}")
raise ValueError(f"Unknown pooling task: {self.task!r}")
def _verify_valid_parameters(self):
assert self.task is not None, "task must be set"
......@@ -194,7 +186,7 @@ class PoolingParams(
if invalid_parameters:
raise ValueError(
f"Task {self.task} only supports {valid_parameters} "
f"Task {self.task!r} only supports {valid_parameters} "
f"parameters, does not support "
f"{invalid_parameters} parameters"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment