Unverified Commit edbc1abd authored by Jesus Federico's avatar Jesus Federico Committed by GitHub
Browse files

feat: add max_tokens_per_doc in rerank request. (#38827)


Signed-off-by: Jesus Federico <jefp@amazon.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 0e39202c
......@@ -112,6 +112,35 @@ def test_classify(llm):
assert len(outputs[0].outputs.data) == 1
@pytest.mark.skip_global_cleanup
def test_max_tokens_per_doc(llm: LLM):
    """Check offline scoring honours ``max_tokens_per_doc`` from ``extra_kwargs``.

    The same query/document pair is scored twice, once without a limit and once
    with a 10-token document limit supplied through ``PoolingParams``. The
    limited run must report strictly fewer prompt tokens.
    """
    query = TEXTS_1[0]
    document = "The capital of France is Paris. " * 20
    doc_limit = PoolingParams(extra_kwargs={"max_tokens_per_doc": 10})

    # Baseline run: no per-document limit.
    baseline = llm.score(query, document, use_tqdm=False)

    # Limited run: the document limit travels in extra_kwargs.
    limited = llm.score(
        query,
        document,
        pooling_params=doc_limit,
        use_tqdm=False,
    )

    assert len(baseline) == 1
    assert len(limited) == 1

    # The limited run must see a shorter prompt than the baseline.
    baseline_len = len(baseline[0].prompt_token_ids)
    limited_len = len(limited[0].prompt_token_ids)
    assert limited_len < baseline_len
def test_pooling_params(llm: LLM):
def get_outputs(use_activation):
outputs = llm.score(
......
......@@ -471,6 +471,78 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer):
assert len(poolings.data[0].data[0]) == 1
@pytest.mark.asyncio
async def test_rerank_max_tokens_per_doc(
    server: RemoteOpenAIServer,
):
    """Check that ``max_tokens_per_doc`` lowers the reported prompt token count."""
    base_payload = {
        "model": MODEL_NAME,
        "query": "What is the capital of France?",
        # Use a doc that fits within max_model_len=100
        # (query ~8 tokens + 4 special); the doc is ~70 tokens.
        "documents": ["The capital of France is Paris. " * 10],
    }

    def prompt_tokens(**extra) -> int:
        # POST one rerank request and return the prompt token usage it reports.
        reply = requests.post(
            server.url_for("rerank"),
            json={**base_payload, **extra},
        )
        reply.raise_for_status()
        return RerankResponse.model_validate(reply.json()).usage.prompt_tokens

    # Reference request without max_tokens_per_doc.
    unbounded = prompt_tokens(truncate_prompt_tokens=99)
    # Same request with a 10-token document limit.
    bounded = prompt_tokens(max_tokens_per_doc=10)

    assert bounded < unbounded
@pytest.mark.asyncio
async def test_rerank_max_tokens_per_doc_validation(
    server: RemoteOpenAIServer,
):
    """Check how the rerank endpoint validates ``max_tokens_per_doc``."""

    def rerank_with_limit(limit: int):
        # Send the same small rerank request, varying only the document limit.
        return requests.post(
            server.url_for("rerank"),
            json={
                "model": MODEL_NAME,
                "query": "What is the capital of France?",
                "documents": ["The capital of France is Paris."],
                "max_tokens_per_doc": limit,
            },
        )

    # Zero is accepted: it means "no truncation".
    rerank_with_limit(0).raise_for_status()

    # A negative limit is rejected with a 400 and an explanatory message.
    rejected = rerank_with_limit(-5)
    assert rejected.status_code == 400
    assert "max_tokens_per_doc must be a non-negative integer" in rejected.text
@pytest.mark.asyncio
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for max_tokens_per_doc and max_tokens_per_query.
"""
import json
import os
from dataclasses import dataclass
import pytest
import requests
from tests.utils import VLLM_PATH, RemoteOpenAIServer
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse
# Set vLLM's logging level to WARNING through the environment of this module.
os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"
# Directory of score templates; the Qwen3 reranker config below loads
# qwen3_reranker.jinja from it.
TEMPLATE_DIR = str(VLLM_PATH / "examples/pooling/score/template")
# Shared long inputs for the truncation tests (each sentence repeated 20 times).
# The query repetitions are joined with no separator; every document
# repetition ends with a space.
long_query = "What is the capital of France?" * 20
long_doc = "The capital of France is Paris. " * 20
@dataclass
class TestConfig:
    """Launch settings and expected token usage for one reranker model.

    Each instance parametrizes the module-scoped ``server`` fixture. The four
    ``*_prompt_tokens`` values are compared exactly against the
    ``usage.prompt_tokens`` returned by the rerank endpoint.
    """

    # The "Test" prefix makes pytest try to collect this dataclass as a test
    # class and warn about its generated __init__; opt out explicitly.
    __test__ = False

    # Model name: used to launch the server and sent as "model" in requests.
    model: str
    # Extra command-line arguments passed to the server.
    args: list[str]
    # Expected prompt tokens with no per-query/per-doc limit.
    without_truncated_prompt_tokens: int
    # Expected prompt tokens with max_tokens_per_query=10 only.
    with_max_tokens_per_query_prompt_tokens: int
    # Expected prompt tokens with max_tokens_per_doc=10 only.
    with_max_tokens_per_doc_prompt_tokens: int
    # Expected prompt tokens with both limits set to 10.
    with_max_tokens_per_query_and_doc_prompt_tokens: int
# One entry per scoring setup exercised by the ``server`` fixture.
# The expected token counts are pinned values that the tests compare exactly;
# presumably they were measured against each model's tokenizer/template, so
# re-measure them if a model or template changes (TODO confirm).
RERANK_CONFIGS = [
    # 1. cross-encoder
    TestConfig(
        model="jinaai/jina-reranker-v2-base-multilingual",
        args=[
            "--enforce-eager",
            "--max-model-len",
            "1024",
            "--trust-remote-code",
        ],
        without_truncated_prompt_tokens=284,
        with_max_tokens_per_query_prompt_tokens=154,
        with_max_tokens_per_doc_prompt_tokens=154,
        with_max_tokens_per_query_and_doc_prompt_tokens=24,
    ),
    # 2. cross-encoder + score template
    TestConfig(
        model="Qwen/Qwen3-Reranker-0.6B",
        args=[
            "--enforce-eager",
            "--max-model-len",
            "1024",
            # Architecture/classifier overrides are passed as a JSON string.
            "--hf-overrides",
            json.dumps(
                {
                    "architectures": ["Qwen3ForSequenceClassification"],
                    "classifier_from_token": ["no", "yes"],
                    "is_original_qwen3_reranker": True,
                }
            ),
            # Score template shipped under TEMPLATE_DIR.
            "--chat-template",
            os.path.join(TEMPLATE_DIR, "qwen3_reranker.jinja"),
        ],
        without_truncated_prompt_tokens=352,
        with_max_tokens_per_query_prompt_tokens=223,
        with_max_tokens_per_doc_prompt_tokens=221,
        with_max_tokens_per_query_and_doc_prompt_tokens=92,
    ),
    # 3. bi-encoder
    TestConfig(
        model="intfloat/multilingual-e5-small",
        args=[
            "--enforce-eager",
            "--max-model-len",
            "512",
            "--trust-remote-code",
        ],
        without_truncated_prompt_tokens=286,
        with_max_tokens_per_query_prompt_tokens=156,
        with_max_tokens_per_doc_prompt_tokens=155,
        with_max_tokens_per_query_and_doc_prompt_tokens=25,
    ),
    # 4. late-interaction
    TestConfig(
        model="answerdotai/answerai-colbert-small-v1",
        args=[
            "--enforce-eager",
            "--max-model-len",
            "512",
            "--trust-remote-code",
        ],
        without_truncated_prompt_tokens=285,
        with_max_tokens_per_query_prompt_tokens=155,
        with_max_tokens_per_doc_prompt_tokens=155,
        with_max_tokens_per_query_and_doc_prompt_tokens=25,
    ),
    # 5. jinaai/jina-reranker-v3
    TestConfig(
        model="jinaai/jina-reranker-v3",
        args=[
            "--enforce-eager",
            "--max-model-len",
            "1024",
            "--trust-remote-code",
        ],
        without_truncated_prompt_tokens=567,
        with_max_tokens_per_query_prompt_tokens=308,
        with_max_tokens_per_doc_prompt_tokens=436,
        with_max_tokens_per_query_and_doc_prompt_tokens=177,
    ),
]
@pytest.fixture(scope="module", params=RERANK_CONFIGS, ids=lambda c: c.model)
def server(request):
    """Start one server per config and yield ``(config, running_server)``.

    The fixture is module-scoped and parametrized over ``RERANK_CONFIGS``; the
    server context manager is exited when the fixture is torn down.
    """
    cfg: TestConfig = request.param
    launcher = RemoteOpenAIServer(cfg.model, cfg.args)
    with launcher as running:
        yield cfg, running
def test_without_truncated(server):
    """Check the baseline prompt token count when no limit is requested.

    Sends the long query/document pair with neither ``max_tokens_per_query``
    nor ``max_tokens_per_doc`` and expects ``usage.prompt_tokens`` to equal the
    config's ``without_truncated_prompt_tokens``.
    """
    config, remote_server = server
    response = requests.post(
        remote_server.url_for("rerank"),
        json={"model": config.model, "query": long_query, "documents": [long_doc]},
    )
    response.raise_for_status()
    rerank = RerankResponse.model_validate(response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 1
    assert rerank.usage.prompt_tokens == config.without_truncated_prompt_tokens
def test_max_tokens_per_query(server):
    """Check that ``max_tokens_per_query`` alone truncates the query.

    Sends the long query/document pair with ``max_tokens_per_query=10`` and
    expects ``usage.prompt_tokens`` to equal the config's
    ``with_max_tokens_per_query_prompt_tokens``.
    """
    config, remote_server = server
    response = requests.post(
        remote_server.url_for("rerank"),
        json={
            "model": config.model,
            "query": long_query,
            "documents": [long_doc],
            "max_tokens_per_query": 10,
        },
    )
    response.raise_for_status()
    rerank = RerankResponse.model_validate(response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 1
    assert rerank.usage.prompt_tokens == config.with_max_tokens_per_query_prompt_tokens
def test_max_tokens_per_doc(server):
    """Check that ``max_tokens_per_doc`` alone truncates the document.

    Sends the long query/document pair with ``max_tokens_per_doc=10`` and
    expects ``usage.prompt_tokens`` to equal the config's
    ``with_max_tokens_per_doc_prompt_tokens``.
    """
    cfg, running = server
    payload = {
        "model": cfg.model,
        "query": long_query,
        "documents": [long_doc],
        "max_tokens_per_doc": 10,
    }
    reply = requests.post(running.url_for("rerank"), json=payload)
    reply.raise_for_status()
    parsed = RerankResponse.model_validate(reply.json())

    assert parsed.id is not None
    assert parsed.results is not None
    assert len(parsed.results) == 1
    assert parsed.usage.prompt_tokens == cfg.with_max_tokens_per_doc_prompt_tokens
def test_max_tokens_per_query_and_doc(server):
    """Check that query and document limits apply together.

    Sends the long query/document pair with both ``max_tokens_per_query=10``
    and ``max_tokens_per_doc=10`` and expects ``usage.prompt_tokens`` to equal
    the config's ``with_max_tokens_per_query_and_doc_prompt_tokens``.
    """
    config, remote_server = server
    response = requests.post(
        remote_server.url_for("rerank"),
        json={
            "model": config.model,
            "query": long_query,
            "documents": [long_doc],
            "max_tokens_per_query": 10,
            "max_tokens_per_doc": 10,
        },
    )
    response.raise_for_status()
    rerank = RerankResponse.model_validate(response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 1
    assert (
        rerank.usage.prompt_tokens
        == config.with_max_tokens_per_query_and_doc_prompt_tokens
    )
......@@ -232,6 +232,7 @@ class PoolingServingBase(ABC):
"greater than max_model_len."
" Please request a smaller truncation size."
)
return None
async def _get_trace_headers(
......
......@@ -25,8 +25,10 @@ from .typing import ScoreData, ScoreInput, ScoringData
from .utils import (
compress_token_type_ids,
compute_maxsim_score,
get_num_special_tokens_for_pair,
parse_score_data,
score_data_to_prompts,
truncate_text_to_tokens,
validate_score_input,
)
......@@ -48,6 +50,64 @@ class ScoringIOProcessor(PoolingIOProcessor):
def create_pooling_params(self, request):
    """Build pooling params for ``request`` using this processor's pooling task."""
    task = self.pooling_task
    return request.to_pooling_params(task)
def _validate_token_limit(self, value: int, name: str) -> None:
if value < 0:
raise ValueError(f"{name} must be a non-negative integer")
if value >= self.model_config.max_model_len:
raise ValueError(
f"{name} ({value}) must be less "
f"than max_model_len ({self.model_config.max_model_len})."
)
def _get_token_limits(
self,
request: ScoringRequest | None = None,
pooling_params: PoolingParams | None = None,
) -> tuple[int, int]:
"""Extract and validate token limits from request or pooling_params."""
if request is not None:
max_tokens_per_query = getattr(request, "max_tokens_per_query", 0)
max_tokens_per_doc = getattr(request, "max_tokens_per_doc", 0)
else:
extra = (
(pooling_params.extra_kwargs or {})
if pooling_params is not None
else {}
)
max_tokens_per_query = extra.get("max_tokens_per_query", 0)
max_tokens_per_doc = extra.get("max_tokens_per_doc", 0)
if max_tokens_per_query != 0:
self._validate_token_limit(max_tokens_per_query, "max_tokens_per_query")
if max_tokens_per_doc != 0:
self._validate_token_limit(max_tokens_per_doc, "max_tokens_per_doc")
return max_tokens_per_query, max_tokens_per_doc
def _truncate_scoring_data(
    self,
    scoring_data: ScoringData,
    max_tokens_per_query: int = 0,
    max_tokens_per_doc: int = 0,
) -> ScoringData:
    """Return scoring data with string items cut to the given token limits.

    ``data_1`` is limited by ``max_tokens_per_query`` and ``data_2`` by
    ``max_tokens_per_doc``. A limit that is not positive leaves that side
    untouched, and non-string items are always passed through unchanged.
    """

    def clip(items, limit):
        # No positive limit: hand back the original sequence as-is.
        if limit <= 0:
            return items
        return [
            truncate_text_to_tokens(item, self.tokenizer, limit)
            if isinstance(item, str)
            else item
            for item in items
        ]

    return ScoringData(
        data_1=clip(scoring_data.data_1, max_tokens_per_query),
        data_2=clip(scoring_data.data_2, max_tokens_per_doc),
    )
def valid_inputs(
self,
data_1: ScoreInput | list[ScoreInput],
......@@ -82,6 +142,15 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
raise ValueError(f"Invalid {self.name} request type")
scoring_data = self.valid_inputs(data_1, data_2)
max_tokens_per_query, max_tokens_per_doc = self._get_token_limits(
request=request
)
if max_tokens_per_query > 0 or max_tokens_per_doc > 0:
scoring_data = self._truncate_scoring_data(
scoring_data, max_tokens_per_query, max_tokens_per_doc
)
tok_params = request.build_tok_params(self.model_config)
engine_inputs = self._pre_process(
scoring_data,
......@@ -112,10 +181,23 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
    """Prepare engine inputs for an offline scoring call.

    Builds tokenization params from the renderer defaults plus
    ``ctx.tokenization_kwargs``, reads the per-query/per-document token limits
    from ``ctx.pooling_params`` and truncates the prompts when either limit is
    positive. The result is passed to ``_pre_process``.
    """
    assert isinstance(ctx.prompts, ScoringData)
    assert not isinstance(ctx.pooling_params, Sequence)

    tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
        **(ctx.tokenization_kwargs or {})
    )

    # The stale early `return self._pre_process(ctx.prompts, tok_params)` that
    # sat here made the truncation below unreachable; it is removed.
    max_tokens_per_query, max_tokens_per_doc = self._get_token_limits(
        pooling_params=ctx.pooling_params
    )

    scoring_data = ctx.prompts
    if max_tokens_per_query > 0 or max_tokens_per_doc > 0:
        scoring_data = self._truncate_scoring_data(
            scoring_data, max_tokens_per_query, max_tokens_per_doc
        )

    return self._pre_process(scoring_data, tok_params)
def post_process_offline(
self,
......@@ -217,8 +299,38 @@ class LateInteractionIOProcessor(BiEncoderIOProcessor):
class FlashLateInteractionIOProcessor(LateInteractionIOProcessor):
    """Late-interaction processor registered under ``"flash-late-interaction"``.

    ``_post_process`` returns the outputs untouched. ``post_process_online``
    pairs each query result with a document result and rebuilds
    ``ctx.final_res_batch`` from those pairs.
    """

    name = "flash-late-interaction"

    def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
        """Return ``outputs`` unchanged."""
        return outputs

    def post_process_online(
        self,
        ctx: ScoringServeContext,
    ):
        """Merge query and document results into one output per pair.

        Each merged output keeps the document's ``outputs``, concatenates the
        prompt token ids as query + optional pad token + document, and sums
        the cached-token counts. ``ctx.final_res_batch`` is replaced in place.
        """
        assert ctx.query_final_res_batch is not None
        assert ctx.final_res_batch is not None
        assert isinstance(ctx.n_queries, int)

        doc_results = ctx.final_res_batch

        # 1:N scoring: repeat the single query result once per document.
        if len(ctx.query_final_res_batch) == 1:
            ctx.query_final_res_batch = ctx.query_final_res_batch * len(doc_results)
        query_results = ctx.query_final_res_batch

        merged: list[PoolingRequestOutput] = []
        for query_out, doc_out in zip(query_results, doc_results):
            pad_id = self.pad_token_id
            separator: list[int] = [] if pad_id is None else [pad_id]
            merged.append(
                PoolingRequestOutput(
                    request_id=f"{query_out.request_id}_{doc_out.request_id}",
                    outputs=doc_out.outputs,
                    prompt_token_ids=(
                        query_out.prompt_token_ids
                        + separator
                        + doc_out.prompt_token_ids
                    ),
                    num_cached_tokens=(
                        query_out.num_cached_tokens + doc_out.num_cached_tokens
                    ),
                    finished=True,
                )
            )

        ctx.final_res_batch = merged
class CrossEncoderIOProcessor(ScoringIOProcessor):
......@@ -255,6 +367,11 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
raise ValueError(f"Invalid {self.name} request type")
scoring_data = self.valid_inputs(data_1, data_2)
max_tokens_per_query, max_tokens_per_doc = self._get_token_limits(
request=request
)
tok_params = request.build_tok_params(self.model_config)
pooling_params = self.create_pooling_params(request)
......@@ -263,6 +380,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
tok_params,
pooling_params,
chat_template=self.chat_template,
max_tokens_per_query=max_tokens_per_query,
max_tokens_per_doc=max_tokens_per_doc,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
......@@ -283,8 +402,18 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {})
)
max_tokens_per_query, max_tokens_per_doc = self._get_token_limits(
pooling_params=ctx.pooling_params
)
engine_inputs, pooling_params_list = self._pre_process(
ctx.prompts, tok_params, ctx.pooling_params, ctx.chat_template
ctx.prompts,
tok_params,
ctx.pooling_params,
ctx.chat_template,
max_tokens_per_query=max_tokens_per_query,
max_tokens_per_doc=max_tokens_per_doc,
)
ctx.pooling_params = pooling_params_list
return engine_inputs
......@@ -298,6 +427,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
tok_params: TokenizeParams,
pooling_params: PoolingParams | None,
chat_template: str | None = None,
max_tokens_per_query: int = 0,
max_tokens_per_doc: int = 0,
prompt_extras: dict[str, Any] | None = None,
) -> tuple[Sequence[EngineInput], list[PoolingParams]]:
# todo: support prompt_extras
......@@ -320,6 +451,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
data_2=d,
encode_kwargs=tok_params.get_encode_kwargs(),
chat_template=chat_template,
max_tokens_per_query=max_tokens_per_query,
max_tokens_per_doc=max_tokens_per_doc,
)
if token_type_ids := engine_prompt.pop("token_type_ids", None):
......@@ -342,6 +475,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
data_2: ScoreData,
encode_kwargs: dict[str, Any],
chat_template: str | None = None,
max_tokens_per_query: int = 0,
max_tokens_per_doc: int = 0,
):
model_config = self.model_config
tokenizer = self.tokenizer
......@@ -352,25 +487,61 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
model_config,
)
# Apply truncation before defining closures
if max_tokens_per_query > 0 and isinstance(prompt_1, str):
prompt_1 = truncate_text_to_tokens(
prompt_1, tokenizer, max_tokens_per_query
)
if max_tokens_per_doc > 0 and isinstance(prompt_2, str):
prompt_2 = truncate_text_to_tokens(prompt_2, tokenizer, max_tokens_per_doc)
def default_tokenizer_encode():
    """Tokenize the query/document pair without an explicit chat template.

    Closure over the enclosing method's ``self``, ``tokenizer``, ``prompt_1``,
    ``prompt_2``, ``encode_kwargs`` and ``max_tokens_per_doc``.

    Returns:
        ``(full_prompt, prompt_inputs)``: the prompt text and the tokenizer
        output (or a dict holding only ``input_ids``).
    """
    # Work on a copy so per-call overrides never leak into the caller's kwargs.
    local_kwargs = encode_kwargs.copy()

    if self.supports_score_template:
        assert self.model is not None
        full_prompt = self.model.get_score_template(prompt_1, prompt_2)
        if full_prompt is None:
            raise ValueError("Get empty score template from model")
        prompt_inputs = tokenizer(full_prompt, **local_kwargs)
    else:
        if self.use_sep_token:
            # cross_encoder models defaults to using separating token.
            if max_tokens_per_doc > 0 and isinstance(prompt_2, str):
                query_tokens = tokenizer.encode(prompt_1, add_special_tokens=False)
                num_special = get_num_special_tokens_for_pair(tokenizer)
                # Budget for the pair: query + capped document + special tokens.
                doc_limit_max_length = (
                    len(query_tokens) + max_tokens_per_doc + num_special
                )
                existing_max_length = local_kwargs.get("max_length")
                if existing_max_length is not None:
                    # Never loosen a max_length the caller already set.
                    effective_max_length = min(
                        doc_limit_max_length, existing_max_length
                    )
                else:
                    effective_max_length = doc_limit_max_length
                # Only the document (second sequence) may be shortened.
                local_kwargs["truncation"] = "only_second"
                local_kwargs["max_length"] = effective_max_length
            prompt_inputs = tokenizer(
                text=prompt_1, text_pair=prompt_2, **local_kwargs
            )
            full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
        else:
            # `llm as reranker` defaults to not using separating token.
            if max_tokens_per_doc > 0 and isinstance(prompt_2, str):
                # NOTE(review): this path builds input_ids by hand, so
                # local_kwargs (e.g. max_length/truncation) are not applied
                # here. Confirm that is intended.
                query_ids = tokenizer.encode(prompt_1, add_special_tokens=False)
                doc_ids = tokenizer.encode(prompt_2, add_special_tokens=False)
                doc_ids = doc_ids[:max_tokens_per_doc]
                input_ids = query_ids + doc_ids
                full_prompt = tokenizer.decode(input_ids)
                prompt_inputs = {"input_ids": input_ids}
            else:
                full_prompt = prompt_1 + prompt_2
                prompt_inputs = tokenizer(text=full_prompt, **local_kwargs)

    return full_prompt, prompt_inputs
# FIXME: For now, we only apply a template when one is explicitly provided.
......
......@@ -20,6 +20,24 @@ from .typing import ScoreContentPartParam, ScoreInput
class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_tokens_per_query: int = Field(
default=0,
description=(
"Maximum number of tokens per query. Queries longer than "
"this will be truncated to this length. 0 means no "
"query-level truncation is applied."
),
)
max_tokens_per_doc: int = Field(
default=0,
description=(
"Maximum number of tokens per document. Documents longer than "
"this will be truncated to this length. 0 means no "
"document-level truncation is applied (only truncate_prompt_tokens "
"applies to the combined query+document)."
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
......@@ -91,29 +109,11 @@ ScoreRequest: TypeAlias = (
)
class RerankRequest(ScoreRequestMixin):
    """Request body for reranking ``documents`` against a ``query``.

    Inherits from ``ScoreRequestMixin`` so that ``max_tokens_per_query`` and
    ``max_tokens_per_doc`` are real request fields. With the earlier base list
    (``PoolingBasicRequestMixin, ClassifyRequestMixin``) those fields were not
    declared on this class.
    """

    query: ScoreInput
    documents: ScoreInput | list[ScoreInput]
    top_n: int = Field(default_factory=lambda: 0)

    # NOTE(review): ScoreRequestMixin also declares build_tok_params; these
    # explicit overrides keep this class's previous behaviour. Drop them if
    # the inherited versions are confirmed to be identical.
    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build tokenization params from this request and ``model_config``."""
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            truncation_side=self.truncation_side,
            do_lower_case=encoder_config.get("do_lower_case", False),
            max_total_tokens_param="max_model_len",
        )

    def to_pooling_params(self, task: PoolingTask = "classify"):
        """Build pooling params for ``task`` carrying ``use_activation``."""
        return PoolingParams(
            task=task,
            use_activation=self.use_activation,
        )
ScoringRequest: TypeAlias = ScoreRequest | RerankRequest
......
......@@ -237,6 +237,7 @@ class ServingScores(PoolingServing):
await self._prepare_generators(query_ctx)
await self._collect_batch(query_ctx)
ctx.query_final_res_batch = query_ctx.final_res_batch
async def _flash_late_interaction_encode_docs(self, ctx: ScoringServeContext):
assert ctx.n_queries is not None
......
......@@ -25,6 +25,37 @@ from .typing import (
)
def get_num_special_tokens_for_pair(tokenizer) -> int:
    """Return how many special tokens the tokenizer adds to a text pair.

    Uses ``tokenizer.num_special_tokens_to_add(pair=True)`` when available. If
    the attribute is missing, or the call raises ``TypeError``, the count is
    taken from encoding an empty text pair with special tokens enabled.
    """
    counter = getattr(tokenizer, "num_special_tokens_to_add", None)
    try:
        if counter is not None:
            return counter(pair=True)
    except TypeError:
        # Fall through to the empty-pair probe below.
        pass

    # Encoding two empty strings leaves only the special tokens.
    probe = tokenizer("", text_pair="", add_special_tokens=True)
    return len(probe["input_ids"])
def truncate_text_to_tokens(
    text: str,
    tokenizer,
    max_tokens: int,
) -> str:
    """Truncate text to a maximum number of content tokens.

    Uses offset_mapping to slice the original text at the exact character
    boundary, avoiding lossy encode→decode round-trips that can shift
    the token count by 1-3 tokens due to BPE merge boundary changes.

    Args:
        text: The text to truncate.
        tokenizer: Callable tokenizer that supports
            ``return_offsets_mapping=True``.
        max_tokens: Maximum number of content tokens to keep. ``0`` means
            "no limit", matching the request-level convention, and returns
            ``text`` unchanged.

    Returns:
        ``text`` itself when it already fits, otherwise the prefix of ``text``
        ending at the last kept token.

    Raises:
        ValueError: If ``max_tokens`` is negative.
    """
    # Without these guards `max_tokens - 1` becomes a negative index and
    # silently selects an offset counted from the end of the text.
    if max_tokens < 0:
        raise ValueError(
            f"max_tokens must be a non-negative integer, got {max_tokens}"
        )
    if max_tokens == 0:
        return text

    encoding = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
    if len(encoding["input_ids"]) <= max_tokens:
        return text

    # End offset (exclusive) of the last token that is kept.
    char_end = encoding["offset_mapping"][max_tokens - 1][1]
    return text[:char_end]
def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
"""
Compute ColBERT MaxSim score.
......
......@@ -89,6 +89,9 @@ class PoolingServeContext(Generic[PoolingRequestT]):
## for IOProcessorResponse
response: Any | None = None
## for flash-late-interaction
query_final_res_batch: list[PoolingRequestOutput] | None = None
@dataclass
class OfflineInputsContext:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment