"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "841d53aaa8d674f2c9f72503e77f75e5ffa79c71"
Unverified commit fafca38a, authored by Walter Beller-Morales, committed by GitHub
Browse files

[BugFix][Frontend] apply task instruction as system prompt in cohere v2/embed (#38362)


Signed-off-by: walterbm <walter.beller.morales@gmail.com>
parent aa4eb0db
...@@ -57,16 +57,25 @@ def _openai_embed( ...@@ -57,16 +57,25 @@ def _openai_embed(
return [item["embedding"] for item in resp.json()["data"]] return [item["embedding"] for item in resp.json()["data"]]
def _cosine_sim(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity between vectors *a* and *b*.

    A zero-norm vector has no direction, so the similarity is reported as
    ``0.0`` rather than dividing by zero (which would produce ``nan`` plus a
    NumPy ``RuntimeWarning`` and an unhelpful assertion message downstream).
    """
    va, vb = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    denom = float(np.linalg.norm(va) * np.linalg.norm(vb))
    if denom == 0.0:
        return 0.0
    return float(np.dot(va, vb) / denom)
def test_single_text_parity(server: RemoteOpenAIServer): def test_single_text_parity(server: RemoteOpenAIServer):
"""A single text should produce identical embeddings via both APIs.""" """A single text should produce equivalent embeddings via both APIs."""
texts = ["the quick brown fox jumps over the lazy dog"] texts = ["the quick brown fox jumps over the lazy dog"]
v2 = _cohere_embed(server, texts) v2 = _cohere_embed(server, texts)
v1 = _openai_embed(server, texts) v1 = _openai_embed(server, texts)
np.testing.assert_allclose(v2[0], v1[0], rtol=1e-5) # Full-suite BF16 runs can introduce tiny numerical drift even when both
# endpoints are functionally equivalent, so compare semantic equivalence
# instead of exact elementwise equality.
cos = _cosine_sim(v2[0], v1[0])
assert cos > 0.9999, f"single-text parity failed, cosine={cos}"
def test_batch_parity(server: RemoteOpenAIServer): def test_batch_parity(server: RemoteOpenAIServer):
"""A batch of texts should produce identical embeddings via both APIs, """A batch of texts should produce equivalent embeddings via both APIs,
in the same order.""" in the same order."""
texts = [ texts = [
"machine learning", "machine learning",
...@@ -76,8 +85,18 @@ def test_batch_parity(server: RemoteOpenAIServer): ...@@ -76,8 +85,18 @@ def test_batch_parity(server: RemoteOpenAIServer):
v2 = _cohere_embed(server, texts) v2 = _cohere_embed(server, texts)
v1 = _openai_embed(server, texts) v1 = _openai_embed(server, texts)
assert len(v2) == len(v1) == 3 assert len(v2) == len(v1) == 3
similarities = np.array(
[[_cosine_sim(v2_emb, v1_emb) for v1_emb in v1] for v2_emb in v2]
)
for i in range(3): for i in range(3):
np.testing.assert_allclose(v2[i], v1[i], rtol=1e-5, err_msg=f"index {i}") assert int(np.argmax(similarities[i])) == i, (
f"batch parity order mismatch at index {i}: "
f"similarities={similarities[i].tolist()}"
)
assert similarities[i, i] > 0.9999, (
f"batch parity failed at index {i}, cosine={similarities[i, i]}"
)
def test_token_count_parity(server: RemoteOpenAIServer): def test_token_count_parity(server: RemoteOpenAIServer):
......
...@@ -6,8 +6,11 @@ import pytest ...@@ -6,8 +6,11 @@ import pytest
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedContent,
CohereEmbedInput,
CohereEmbedRequest, CohereEmbedRequest,
) )
from vllm.entrypoints.pooling.typing import PoolingServeContext
class TestResolveTruncation: class TestResolveTruncation:
...@@ -206,3 +209,116 @@ class TestValidateInputType: ...@@ -206,3 +209,116 @@ class TestValidateInputType:
handler = self._make_handler({"a": "", "b": ""}) handler = self._make_handler({"a": "", "b": ""})
with pytest.raises(ValueError, match="Supported values: a, b"): with pytest.raises(ValueError, match="Supported values: a, b"):
handler._validate_input_type("z") handler._validate_input_type("z")
class TestPreProcessCohereOnline:
    """Routing tests for ``EmbedIOProcessor._pre_process_cohere_online``.

    Every test creates a processor without running ``__init__``, replaces
    its collaborators with stubs set directly on the instance, and then
    checks which preprocessing path a text-only request was sent down.
    """

    @staticmethod
    def _make_context(**request_kwargs) -> PoolingServeContext[CohereEmbedRequest]:
        """Wrap a ``CohereEmbedRequest`` built from *request_kwargs*."""
        cohere_request = CohereEmbedRequest(model="test", **request_kwargs)
        return PoolingServeContext(
            request=cohere_request,
            model_name="test",
            request_id="embd-test",
        )

    @staticmethod
    def _make_handler():
        """Create an ``EmbedIOProcessor`` instance, bypassing ``__init__``."""
        processor = object.__new__(EmbedIOProcessor)
        # Input-type validation is not what these tests exercise.
        processor._validate_input_type = lambda _input_type: None
        return processor

    def test_text_only_without_task_prefix_uses_completion_path(self):
        """No task prefix: the raw texts go through the completion path."""
        processor = self._make_handler()
        context = self._make_context(texts=["hello"])
        recorded: list[tuple[str, object]] = []

        # ``prompt_input``/``prompt_embeds`` are passed by keyword by the
        # code under test, so the stub must keep those parameter names.
        def fake_completion(request, prompt_input, prompt_embeds):
            recorded.append(("completion", prompt_input))
            return ["completion"]

        def forbid_chat(*_args, **_kwargs):
            return pytest.fail("text-only request should not require chat rendering")

        processor._get_task_instruction_prefix = lambda _input_type: None
        processor._has_chat_template = lambda: False
        processor._preprocess_completion_online = fake_completion
        processor._batch_render_chat = forbid_chat

        processor._pre_process_cohere_online(context)

        assert context.engine_inputs == ["completion"]
        assert recorded == [("completion", ["hello"])]

    def test_text_only_falls_back_to_prefixed_completion_without_template(self):
        """Prefix but no chat template: prefixed texts use the completion path."""
        processor = self._make_handler()
        context = self._make_context(texts=["hello"], input_type="query")
        recorded: list[tuple[str, object]] = []

        def fake_completion(request, prompt_input, prompt_embeds):
            recorded.append(("completion", prompt_input))
            return ["fallback"]

        def forbid_chat(*_args, **_kwargs):
            return pytest.fail("chat rendering should be skipped without a template")

        processor._get_task_instruction_prefix = lambda _input_type: "query: "
        processor._has_chat_template = lambda: False
        processor._batch_render_chat = forbid_chat
        processor._preprocess_completion_online = fake_completion

        processor._pre_process_cohere_online(context)

        assert context.engine_inputs == ["fallback"]
        assert recorded == [("completion", ["query: hello"])]

    def test_text_only_with_template_uses_chat_path(self):
        """Prefix and chat template: texts are rendered through the chat path."""
        processor = self._make_handler()
        context = self._make_context(texts=["hello"], input_type="query")
        recorded: list[tuple[str, object]] = []

        def fake_batch_render_chat(
            request,
            all_messages,
            truncate_prompt_tokens,
            truncation_side,
        ):
            captured = {
                "request": request,
                "all_messages": all_messages,
                "truncate_prompt_tokens": truncate_prompt_tokens,
                "truncation_side": truncation_side,
            }
            recorded.append(("chat", captured))
            return ["chat"]

        def forbid_completion(*_args, **_kwargs):
            return pytest.fail(
                "completion path should be skipped when a template exists"
            )

        processor._get_task_instruction_prefix = lambda _input_type: "query: "
        processor._has_chat_template = lambda: True
        processor._batch_render_chat = fake_batch_render_chat
        processor._preprocess_completion_online = forbid_completion

        processor._pre_process_cohere_online(context)

        assert context.engine_inputs == ["chat"]

        # The expected conversation is produced by the processor's own
        # message builder from an equivalent single-text input.
        hello_input = CohereEmbedInput(
            content=[CohereEmbedContent(type="text", text="hello")]
        )
        expected_call = (
            "chat",
            {
                "request": context.request,
                "all_messages": [
                    processor._mixed_input_to_messages(
                        hello_input,
                        task_prefix="query: ",
                    )
                ],
                "truncate_prompt_tokens": -1,
                "truncation_side": None,
            },
        )
        assert recorded == [expected_call]
...@@ -18,6 +18,7 @@ from vllm.entrypoints.chat_utils import ( ...@@ -18,6 +18,7 @@ from vllm.entrypoints.chat_utils import (
) )
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedContent,
CohereEmbedInput, CohereEmbedInput,
CohereEmbedRequest, CohereEmbedRequest,
EmbeddingChatRequest, EmbeddingChatRequest,
...@@ -28,6 +29,7 @@ from vllm.inputs import EngineInput, tokens_input ...@@ -28,6 +29,7 @@ from vllm.inputs import EngineInput, tokens_input
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.renderers import merge_kwargs from vllm.renderers import merge_kwargs
from vllm.renderers.hf import resolve_chat_template
from vllm.utils.collection_utils import chunk_list from vllm.utils.collection_utils import chunk_list
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
...@@ -284,13 +286,27 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -284,13 +286,27 @@ class EmbedIOProcessor(PoolingIOProcessor):
) -> list[ChatCompletionMessageParam]: ) -> list[ChatCompletionMessageParam]:
"""Build chat messages from a mixed text+image input. """Build chat messages from a mixed text+image input.
When *task_prefix* is given, it is prepended to each text part. When *task_prefix* is given, it is used as the system prompt.
""" """
messages: list[ChatCompletionMessageParam] = []
if task_prefix is not None:
messages.append(
CustomChatCompletionMessageParam(
role="system",
content=[
ChatCompletionContentPartTextParam(
type="text", text=task_prefix
)
],
)
)
parts: list[ChatCompletionContentPartParam] = [] parts: list[ChatCompletionContentPartParam] = []
for item in inp.content: for item in inp.content:
if item.type == "text" and item.text is not None: if item.type == "text" and item.text is not None:
text = task_prefix + item.text if task_prefix else item.text parts.append(
parts.append(ChatCompletionContentPartTextParam(type="text", text=text)) ChatCompletionContentPartTextParam(type="text", text=item.text)
)
elif item.type == "image_url" and item.image_url is not None: elif item.type == "image_url" and item.image_url is not None:
parts.append( parts.append(
ChatCompletionContentPartImageParam( ChatCompletionContentPartImageParam(
...@@ -298,7 +314,8 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -298,7 +314,8 @@ class EmbedIOProcessor(PoolingIOProcessor):
image_url=ImageURL(url=item.image_url["url"]), image_url=ImageURL(url=item.image_url["url"]),
) )
) )
return [CustomChatCompletionMessageParam(role="user", content=parts)] messages.append(CustomChatCompletionMessageParam(role="user", content=parts))
return messages
@staticmethod @staticmethod
def _check_cohere_max_tokens( def _check_cohere_max_tokens(
...@@ -346,9 +363,11 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -346,9 +363,11 @@ class EmbedIOProcessor(PoolingIOProcessor):
def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None: def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None:
"""Convert a ``CohereEmbedRequest`` into engine prompts. """Convert a ``CohereEmbedRequest`` into engine prompts.
For texts, a single batched completion request path is used. If a model has a chat template, the task instructions are rendered
For images and mixed inputs, conversations are batch-rendered as a system prompt. Otherwise, they are simply prepended to the input text.
through the chat template in one ``render_chat`` call.
Images and mixed inputs are always batch-rendered through the chat
template in one ``render_chat`` call.
""" """
request = ctx.request request = ctx.request
assert isinstance(request, CohereEmbedRequest) assert isinstance(request, CohereEmbedRequest)
...@@ -363,42 +382,91 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -363,42 +382,91 @@ class EmbedIOProcessor(PoolingIOProcessor):
self._validate_input_type(input_type) self._validate_input_type(input_type)
if request.images is not None: if request.images is not None:
all_messages: list[list[ChatCompletionMessageParam]] = [ input: list[CohereEmbedInput] = [
[ CohereEmbedInput(
CustomChatCompletionMessageParam( content=[
role="user", CohereEmbedContent(type="image_url", image_url={"url": uri})
content=[{"type": "image_url", "image_url": {"url": uri}}], ]
) )
]
for uri in request.images for uri in request.images
] ]
ctx.engine_inputs = self._batch_render_chat(
request, all_messages, truncate_prompt_tokens, truncation_side
)
elif request.inputs is not None: elif request.inputs is not None:
input = request.inputs
else:
texts = request.texts or []
task_prefix = self._get_task_instruction_prefix(input_type) task_prefix = self._get_task_instruction_prefix(input_type)
if task_prefix is None:
ctx.engine_inputs = self._preprocess_cohere_text_completion(
request,
texts,
truncate_prompt_tokens,
truncation_side,
)
return
all_messages = [ all_messages = [
self._mixed_input_to_messages(inp, task_prefix=task_prefix) self._mixed_input_to_messages(
for inp in request.inputs CohereEmbedInput(
content=[CohereEmbedContent(type="text", text=text)]
),
task_prefix=task_prefix,
)
for text in texts
] ]
ctx.engine_inputs = self._batch_render_chat( if self._has_chat_template():
request, all_messages, truncate_prompt_tokens, truncation_side ctx.engine_inputs = self._batch_render_chat(
) request,
all_messages,
truncate_prompt_tokens,
truncation_side,
)
else:
ctx.engine_inputs = self._preprocess_cohere_text_completion(
request,
self._apply_task_instruction(texts, input_type),
truncate_prompt_tokens,
truncation_side,
)
return
else: task_prefix = self._get_task_instruction_prefix(input_type)
prefixed = self._apply_task_instruction(request.texts or [], input_type) all_messages = [
proxy = EmbeddingCompletionRequest( self._mixed_input_to_messages(inp, task_prefix=task_prefix) for inp in input
model=request.model, ]
input=prefixed, ctx.engine_inputs = self._batch_render_chat(
dimensions=request.output_dimension, request, all_messages, truncate_prompt_tokens, truncation_side
encoding_format="float", )
truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side, def _has_chat_template(self) -> bool:
) return (
ctx.engine_inputs = self._preprocess_completion_online( resolve_chat_template(
proxy, prompt_input=proxy.input, prompt_embeds=None self.renderer.tokenizer,
chat_template=self.chat_template,
tools=None,
model_config=self.model_config,
) )
is not None
)
def _preprocess_cohere_text_completion(
    self,
    request: CohereEmbedRequest,
    texts: list[str],
    truncate_prompt_tokens: int | None,
    truncation_side: Literal["left", "right"] | None,
) -> list[EngineInput]:
    """Send Cohere text inputs through the completion preprocessing path.

    The Cohere request is re-expressed as an ``EmbeddingCompletionRequest``
    whose input is *texts*; that proxy request is then handed to
    ``_preprocess_completion_online`` and its result is returned.
    """
    completion_request = EmbeddingCompletionRequest(
        truncation_side=truncation_side,
        truncate_prompt_tokens=truncate_prompt_tokens,
        encoding_format="float",
        dimensions=request.output_dimension,
        input=texts,
        model=request.model,
    )
    engine_inputs = self._preprocess_completion_online(
        completion_request,
        prompt_input=completion_request.input,
        prompt_embeds=None,
    )
    return engine_inputs
def _batch_render_chat( def _batch_render_chat(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment