"vscode:/vscode.git/clone" did not exist on "62d54ba46db25b95de2d21e46f4b57b5502ed747"
Unverified Commit fafca38a authored by Walter Beller-Morales's avatar Walter Beller-Morales Committed by GitHub
Browse files

[BugFix][Frontend] apply task instruction as system prompt in cohere v2/embed (#38362)


Signed-off-by: default avatarwalterbm <walter.beller.morales@gmail.com>
parent aa4eb0db
......@@ -57,16 +57,25 @@ def _openai_embed(
return [item["embedding"] for item in resp.json()["data"]]
def _cosine_sim(a: list[float], b: list[float]) -> float:
va, vb = np.array(a), np.array(b)
return float(np.dot(va, vb) / (np.linalg.norm(va) * np.linalg.norm(vb)))
def test_single_text_parity(server: RemoteOpenAIServer):
"""A single text should produce identical embeddings via both APIs."""
"""A single text should produce equivalent embeddings via both APIs."""
texts = ["the quick brown fox jumps over the lazy dog"]
v2 = _cohere_embed(server, texts)
v1 = _openai_embed(server, texts)
np.testing.assert_allclose(v2[0], v1[0], rtol=1e-5)
# Full-suite BF16 runs can introduce tiny numerical drift even when both
# endpoints are functionally equivalent, so compare semantic equivalence
# instead of exact elementwise equality.
cos = _cosine_sim(v2[0], v1[0])
assert cos > 0.9999, f"single-text parity failed, cosine={cos}"
def test_batch_parity(server: RemoteOpenAIServer):
"""A batch of texts should produce identical embeddings via both APIs,
"""A batch of texts should produce equivalent embeddings via both APIs,
in the same order."""
texts = [
"machine learning",
......@@ -76,8 +85,18 @@ def test_batch_parity(server: RemoteOpenAIServer):
v2 = _cohere_embed(server, texts)
v1 = _openai_embed(server, texts)
assert len(v2) == len(v1) == 3
similarities = np.array(
[[_cosine_sim(v2_emb, v1_emb) for v1_emb in v1] for v2_emb in v2]
)
for i in range(3):
np.testing.assert_allclose(v2[i], v1[i], rtol=1e-5, err_msg=f"index {i}")
assert int(np.argmax(similarities[i])) == i, (
f"batch parity order mismatch at index {i}: "
f"similarities={similarities[i].tolist()}"
)
assert similarities[i, i] > 0.9999, (
f"batch parity failed at index {i}, cosine={similarities[i, i]}"
)
def test_token_count_parity(server: RemoteOpenAIServer):
......
......@@ -6,8 +6,11 @@ import pytest
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedContent,
CohereEmbedInput,
CohereEmbedRequest,
)
from vllm.entrypoints.pooling.typing import PoolingServeContext
class TestResolveTruncation:
......@@ -206,3 +209,116 @@ class TestValidateInputType:
handler = self._make_handler({"a": "", "b": ""})
with pytest.raises(ValueError, match="Supported values: a, b"):
handler._validate_input_type("z")
class TestPreProcessCohereOnline:
"""Unit tests for EmbedIOProcessor._pre_process_cohere_online."""
@staticmethod
def _make_context(**request_kwargs) -> PoolingServeContext[CohereEmbedRequest]:
return PoolingServeContext(
request=CohereEmbedRequest(model="test", **request_kwargs),
model_name="test",
request_id="embd-test",
)
@staticmethod
def _make_handler():
handler = object.__new__(EmbedIOProcessor)
handler._validate_input_type = lambda _input_type: None
return handler
def test_text_only_without_task_prefix_uses_completion_path(self):
handler = self._make_handler()
ctx = self._make_context(texts=["hello"])
calls: list[tuple[str, object]] = []
def preprocess_completion(request, prompt_input, prompt_embeds):
calls.append(("completion", prompt_input))
return ["completion"]
handler._get_task_instruction_prefix = lambda _input_type: None
handler._has_chat_template = lambda: False
handler._preprocess_completion_online = preprocess_completion
handler._batch_render_chat = lambda *_args, **_kwargs: (
pytest.fail("text-only request should not require chat rendering")
)
handler._pre_process_cohere_online(ctx)
assert ctx.engine_inputs == ["completion"]
assert calls == [("completion", ["hello"])]
def test_text_only_falls_back_to_prefixed_completion_without_template(self):
handler = self._make_handler()
ctx = self._make_context(texts=["hello"], input_type="query")
calls: list[tuple[str, object]] = []
def preprocess_completion(request, prompt_input, prompt_embeds):
calls.append(("completion", prompt_input))
return ["fallback"]
handler._get_task_instruction_prefix = lambda _input_type: "query: "
handler._has_chat_template = lambda: False
handler._batch_render_chat = lambda *_args, **_kwargs: (
pytest.fail("chat rendering should be skipped without a template")
)
handler._preprocess_completion_online = preprocess_completion
handler._pre_process_cohere_online(ctx)
assert ctx.engine_inputs == ["fallback"]
assert calls == [("completion", ["query: hello"])]
def test_text_only_with_template_uses_chat_path(self):
handler = self._make_handler()
ctx = self._make_context(texts=["hello"], input_type="query")
calls: list[tuple[str, object]] = []
def batch_render_chat(
request,
all_messages,
truncate_prompt_tokens,
truncation_side,
):
calls.append(
(
"chat",
{
"request": request,
"all_messages": all_messages,
"truncate_prompt_tokens": truncate_prompt_tokens,
"truncation_side": truncation_side,
},
)
)
return ["chat"]
handler._get_task_instruction_prefix = lambda _input_type: "query: "
handler._has_chat_template = lambda: True
handler._batch_render_chat = batch_render_chat
handler._preprocess_completion_online = lambda *_args, **_kwargs: (
pytest.fail("completion path should be skipped when a template exists")
)
handler._pre_process_cohere_online(ctx)
assert ctx.engine_inputs == ["chat"]
assert calls == [
(
"chat",
{
"request": ctx.request,
"all_messages": [
handler._mixed_input_to_messages(
CohereEmbedInput(
content=[CohereEmbedContent(type="text", text="hello")]
),
task_prefix="query: ",
)
],
"truncate_prompt_tokens": -1,
"truncation_side": None,
},
)
]
......@@ -18,6 +18,7 @@ from vllm.entrypoints.chat_utils import (
)
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedContent,
CohereEmbedInput,
CohereEmbedRequest,
EmbeddingChatRequest,
......@@ -28,6 +29,7 @@ from vllm.inputs import EngineInput, tokens_input
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.renderers import merge_kwargs
from vllm.renderers.hf import resolve_chat_template
from vllm.utils.collection_utils import chunk_list
from vllm.utils.mistral import is_mistral_tokenizer
......@@ -284,13 +286,27 @@ class EmbedIOProcessor(PoolingIOProcessor):
) -> list[ChatCompletionMessageParam]:
"""Build chat messages from a mixed text+image input.
When *task_prefix* is given, it is prepended to each text part.
When *task_prefix* is given, it is used as the system prompt.
"""
messages: list[ChatCompletionMessageParam] = []
if task_prefix is not None:
messages.append(
CustomChatCompletionMessageParam(
role="system",
content=[
ChatCompletionContentPartTextParam(
type="text", text=task_prefix
)
],
)
)
parts: list[ChatCompletionContentPartParam] = []
for item in inp.content:
if item.type == "text" and item.text is not None:
text = task_prefix + item.text if task_prefix else item.text
parts.append(ChatCompletionContentPartTextParam(type="text", text=text))
parts.append(
ChatCompletionContentPartTextParam(type="text", text=item.text)
)
elif item.type == "image_url" and item.image_url is not None:
parts.append(
ChatCompletionContentPartImageParam(
......@@ -298,7 +314,8 @@ class EmbedIOProcessor(PoolingIOProcessor):
image_url=ImageURL(url=item.image_url["url"]),
)
)
return [CustomChatCompletionMessageParam(role="user", content=parts)]
messages.append(CustomChatCompletionMessageParam(role="user", content=parts))
return messages
@staticmethod
def _check_cohere_max_tokens(
......@@ -346,9 +363,11 @@ class EmbedIOProcessor(PoolingIOProcessor):
def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None:
"""Convert a ``CohereEmbedRequest`` into engine prompts.
For texts, a single batched completion request path is used.
For images and mixed inputs, conversations are batch-rendered
through the chat template in one ``render_chat`` call.
If a model has a chat template the task instruction are rendered
as a system prompt. Otherwise they are just prepended to the input text.
Images and mixed inputs are always batch-rendered through the chat
template in one ``render_chat`` call.
"""
request = ctx.request
assert isinstance(request, CohereEmbedRequest)
......@@ -363,42 +382,91 @@ class EmbedIOProcessor(PoolingIOProcessor):
self._validate_input_type(input_type)
if request.images is not None:
all_messages: list[list[ChatCompletionMessageParam]] = [
[
CustomChatCompletionMessageParam(
role="user",
content=[{"type": "image_url", "image_url": {"url": uri}}],
)
]
input: list[CohereEmbedInput] = [
CohereEmbedInput(
content=[
CohereEmbedContent(type="image_url", image_url={"url": uri})
]
)
for uri in request.images
]
ctx.engine_inputs = self._batch_render_chat(
request, all_messages, truncate_prompt_tokens, truncation_side
)
elif request.inputs is not None:
input = request.inputs
else:
texts = request.texts or []
task_prefix = self._get_task_instruction_prefix(input_type)
if task_prefix is None:
ctx.engine_inputs = self._preprocess_cohere_text_completion(
request,
texts,
truncate_prompt_tokens,
truncation_side,
)
return
all_messages = [
self._mixed_input_to_messages(inp, task_prefix=task_prefix)
for inp in request.inputs
self._mixed_input_to_messages(
CohereEmbedInput(
content=[CohereEmbedContent(type="text", text=text)]
),
task_prefix=task_prefix,
)
for text in texts
]
ctx.engine_inputs = self._batch_render_chat(
request, all_messages, truncate_prompt_tokens, truncation_side
)
if self._has_chat_template():
ctx.engine_inputs = self._batch_render_chat(
request,
all_messages,
truncate_prompt_tokens,
truncation_side,
)
else:
ctx.engine_inputs = self._preprocess_cohere_text_completion(
request,
self._apply_task_instruction(texts, input_type),
truncate_prompt_tokens,
truncation_side,
)
return
else:
prefixed = self._apply_task_instruction(request.texts or [], input_type)
proxy = EmbeddingCompletionRequest(
model=request.model,
input=prefixed,
dimensions=request.output_dimension,
encoding_format="float",
truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side,
)
ctx.engine_inputs = self._preprocess_completion_online(
proxy, prompt_input=proxy.input, prompt_embeds=None
task_prefix = self._get_task_instruction_prefix(input_type)
all_messages = [
self._mixed_input_to_messages(inp, task_prefix=task_prefix) for inp in input
]
ctx.engine_inputs = self._batch_render_chat(
request, all_messages, truncate_prompt_tokens, truncation_side
)
def _has_chat_template(self) -> bool:
return (
resolve_chat_template(
self.renderer.tokenizer,
chat_template=self.chat_template,
tools=None,
model_config=self.model_config,
)
is not None
)
def _preprocess_cohere_text_completion(
self,
request: CohereEmbedRequest,
texts: list[str],
truncate_prompt_tokens: int | None,
truncation_side: Literal["left", "right"] | None,
) -> list[EngineInput]:
proxy = EmbeddingCompletionRequest(
model=request.model,
input=texts,
dimensions=request.output_dimension,
encoding_format="float",
truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side,
)
return self._preprocess_completion_online(
proxy, prompt_input=proxy.input, prompt_embeds=None
)
def _batch_render_chat(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment