Unverified commit 22b64948, authored by wang.yuqi and committed by GitHub
Browse files

[Frontend][last/5] Make pooling entrypoints request schema consensus. (#31127)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 7c233dbb
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from typing import Any, TypeAlias from typing import TypeAlias
from pydantic import Field from pydantic import Field
...@@ -78,11 +78,6 @@ class EmbeddingCompletionRequest( ...@@ -78,11 +78,6 @@ class EmbeddingCompletionRequest(
class EmbeddingChatRequest( class EmbeddingChatRequest(
PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
): ):
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {} encoder_config = model_config.encoder_config or {}
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from typing import Any, Generic, TypeAlias, TypeVar from typing import Generic, TypeAlias, TypeVar
from pydantic import Field from pydantic import Field
...@@ -65,11 +65,6 @@ class PoolingChatRequest( ...@@ -65,11 +65,6 @@ class PoolingChatRequest(
): ):
task: PoolingTask | None = None task: PoolingTask | None = None
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {} encoder_config = model_config.encoder_config or {}
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from typing import Any, TypeAlias from typing import TypeAlias
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
...@@ -23,13 +23,6 @@ from vllm.utils import random_uuid ...@@ -23,13 +23,6 @@ from vllm.utils import random_uuid
class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
# --8<-- [start:score-extra-params]
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
# --8<-- [end:score-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {} encoder_config = model_config.encoder_config or {}
...@@ -106,13 +99,6 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin): ...@@ -106,13 +99,6 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
documents: ScoreInputs documents: ScoreInputs
top_n: int = Field(default_factory=lambda: 0) top_n: int = Field(default_factory=lambda: 0)
# --8<-- [start:rerank-extra-params]
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
# --8<-- [end:rerank-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {} encoder_config = model_config.encoder_config or {}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def print_embeddings(embeds: list[float], max_items: int = 4) -> None:
    """Print a one-line preview of an embedding vector.

    At most ``max_items`` leading values are shown. When the vector is
    longer than that, the preview ends with ``...`` in place of the
    remaining values. The full length is always reported as ``size``.

    Args:
        embeds: The embedding values to preview.
        max_items: Maximum number of leading values to show before
            truncating. Defaults to 4.

    Raises:
        ValueError: If ``max_items`` is negative.
    """
    if max_items < 0:
        raise ValueError(f"max_items must be non-negative, got {max_items}")

    if len(embeds) > max_items:
        # Render the leading slice, drop its closing bracket, and append the
        # ellipsis marker so the result still reads like a list literal.
        head = str(embeds[:max_items])[:-1]
        # With nothing shown there is no value to separate the ellipsis from.
        separator = ", " if max_items else ""
        embeds_trimmed = f"{head}{separator}...]"
    else:
        # Short enough to print in full.
        embeds_trimmed = embeds

    print(f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment