Unverified Commit edbc1abd authored by Jesus Federico's avatar Jesus Federico Committed by GitHub
Browse files

feat: add max_tokens_per_doc in rerank request. (#38827)


Signed-off-by: Jesus Federico <jefp@amazon.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 0e39202c
...@@ -112,6 +112,35 @@ def test_classify(llm): ...@@ -112,6 +112,35 @@ def test_classify(llm):
assert len(outputs[0].outputs.data) == 1 assert len(outputs[0].outputs.data) == 1
@pytest.mark.skip_global_cleanup
def test_max_tokens_per_doc(llm: LLM):
    """Offline check that ``max_tokens_per_doc`` shortens the scored prompt.

    The limit is passed through ``PoolingParams.extra_kwargs``; the same
    query/document pair is scored with and without it and the lengths of the
    resulting ``prompt_token_ids`` are compared.
    """
    repeated_doc = "The capital of France is Paris. " * 20
    query = TEXTS_1[0]

    # Baseline: no per-document limit.
    baseline = llm.score(query, repeated_doc, use_tqdm=False)

    # Same pair, document capped at 10 tokens via extra_kwargs.
    capped_params = PoolingParams(extra_kwargs={"max_tokens_per_doc": 10})
    capped = llm.score(
        query,
        repeated_doc,
        pooling_params=capped_params,
        use_tqdm=False,
    )

    assert len(baseline) == 1
    assert len(capped) == 1

    # The capped request must produce a strictly shorter prompt.
    baseline_len = len(baseline[0].prompt_token_ids)
    capped_len = len(capped[0].prompt_token_ids)
    assert capped_len < baseline_len
def test_pooling_params(llm: LLM): def test_pooling_params(llm: LLM):
def get_outputs(use_activation): def get_outputs(use_activation):
outputs = llm.score( outputs = llm.score(
......
...@@ -471,6 +471,78 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer): ...@@ -471,6 +471,78 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer):
assert len(poolings.data[0].data[0]) == 1 assert len(poolings.data[0].data[0]) == 1
@pytest.mark.asyncio
async def test_rerank_max_tokens_per_doc(
    server: RemoteOpenAIServer,
):
    """Rerank with ``max_tokens_per_doc`` must report fewer prompt tokens."""
    query = "What is the capital of France?"
    # Use a doc that fits within max_model_len=100 (query ~8 tokens + 4 special)
    long_doc = "The capital of France is Paris. " * 10  # ~70 tokens

    def post_rerank(**extra_fields):
        # Send one rerank request for the fixed query/document pair and
        # return the parsed response; fails the test on a non-2xx status.
        payload = {
            "model": MODEL_NAME,
            "query": query,
            "documents": [long_doc],
            **extra_fields,
        }
        reply = requests.post(server.url_for("rerank"), json=payload)
        reply.raise_for_status()
        return RerankResponse.model_validate(reply.json())

    # Reference request without a per-document limit.
    unlimited = post_rerank(truncate_prompt_tokens=99)
    # Same request with the document capped at 10 tokens.
    limited = post_rerank(max_tokens_per_doc=10)

    assert limited.usage.prompt_tokens < unlimited.usage.prompt_tokens
@pytest.mark.asyncio
async def test_rerank_max_tokens_per_doc_validation(
    server: RemoteOpenAIServer,
):
    """Check which ``max_tokens_per_doc`` values the rerank endpoint accepts."""
    base_payload = {
        "model": MODEL_NAME,
        "query": "What is the capital of France?",
        "documents": ["The capital of France is Paris."],
    }

    # max_tokens_per_doc=0 should succeed: it means no truncation.
    accepted = requests.post(
        server.url_for("rerank"),
        json={**base_payload, "max_tokens_per_doc": 0},
    )
    accepted.raise_for_status()

    # A negative value is invalid and must be rejected with HTTP 400.
    rejected = requests.post(
        server.url_for("rerank"),
        json={**base_payload, "max_tokens_per_doc": -5},
    )
    assert rejected.status_code == 400
    assert "max_tokens_per_doc must be a non-negative integer" in rejected.text
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"]) @pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str): async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for max_tokens_per_doc and max_tokens_per_query.
"""
import json
import os
from dataclasses import dataclass
import pytest
import requests
from tests.utils import VLLM_PATH, RemoteOpenAIServer
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse
# Set at import time, before any server is started by the fixtures below.
# NOTE(review): presumably meant to quiet the spawned server's logs — confirm.
os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"

# Directory holding the score templates; the Qwen3 reranker config below
# loads "qwen3_reranker.jinja" from here.
TEMPLATE_DIR = str(VLLM_PATH / "examples/pooling/score/template")

# Long inputs shared by every test. Note that the query repetitions are
# concatenated without a separator (the literal has no trailing space),
# whereas each document repetition ends with a space. The expected token
# counts in RERANK_CONFIGS are tied to these exact strings.
long_query = "What is the capital of France?" * 20
long_doc = "The capital of France is Paris. " * 20
@dataclass
class TestConfig:
    """One server configuration plus the prompt-token counts expected from it.

    The ``*_prompt_tokens`` fields are compared for exact equality against
    ``usage.prompt_tokens`` of the rerank response in the tests of this module.
    """

    # The class name starts with "Test", so pytest would try to collect it
    # and emit a PytestCollectionWarning because of the generated __init__.
    # An un-annotated class attribute is not turned into a dataclass field.
    __test__ = False

    # Model name passed to the server and sent in each request.
    model: str
    # Extra command-line arguments used when launching the server.
    args: list[str]
    # Expected prompt tokens with no per-query / per-doc limit.
    without_truncated_prompt_tokens: int
    # Expected prompt tokens with only max_tokens_per_query=10.
    with_max_tokens_per_query_prompt_tokens: int
    # Expected prompt tokens with only max_tokens_per_doc=10.
    with_max_tokens_per_doc_prompt_tokens: int
    # Expected prompt tokens with both limits set to 10.
    with_max_tokens_per_query_and_doc_prompt_tokens: int
# Server configurations exercised by this module. The `server` fixture is
# parametrized over this list, so every test runs once per entry.
#
# The four *_prompt_tokens values of each entry are asserted with `==`
# against `usage.prompt_tokens` for requests built from `long_query` and
# `long_doc` with limits of 10 tokens.
# NOTE(review): the counts look empirically measured per tokenizer/template;
# they will need updating if a model's tokenizer or template changes — confirm.
RERANK_CONFIGS = [
    # 1. cross-encoder
    TestConfig(
        model="jinaai/jina-reranker-v2-base-multilingual",
        args=[
            "--enforce-eager",
            "--max-model-len",
            "1024",
            "--trust-remote-code",
        ],
        without_truncated_prompt_tokens=284,
        with_max_tokens_per_query_prompt_tokens=154,
        with_max_tokens_per_doc_prompt_tokens=154,
        with_max_tokens_per_query_and_doc_prompt_tokens=24,
    ),
    # 2. cross-encoder + score template
    TestConfig(
        model="Qwen/Qwen3-Reranker-0.6B",
        args=[
            "--enforce-eager",
            "--max-model-len",
            "1024",
            "--hf-overrides",
            json.dumps(
                {
                    "architectures": ["Qwen3ForSequenceClassification"],
                    "classifier_from_token": ["no", "yes"],
                    "is_original_qwen3_reranker": True,
                }
            ),
            "--chat-template",
            os.path.join(TEMPLATE_DIR, "qwen3_reranker.jinja"),
        ],
        without_truncated_prompt_tokens=352,
        with_max_tokens_per_query_prompt_tokens=223,
        with_max_tokens_per_doc_prompt_tokens=221,
        with_max_tokens_per_query_and_doc_prompt_tokens=92,
    ),
    # 3. bi-encoder
    TestConfig(
        model="intfloat/multilingual-e5-small",
        args=[
            "--enforce-eager",
            "--max-model-len",
            "512",
            "--trust-remote-code",
        ],
        without_truncated_prompt_tokens=286,
        with_max_tokens_per_query_prompt_tokens=156,
        with_max_tokens_per_doc_prompt_tokens=155,
        with_max_tokens_per_query_and_doc_prompt_tokens=25,
    ),
    # 4. late-interaction
    TestConfig(
        model="answerdotai/answerai-colbert-small-v1",
        args=[
            "--enforce-eager",
            "--max-model-len",
            "512",
            "--trust-remote-code",
        ],
        without_truncated_prompt_tokens=285,
        with_max_tokens_per_query_prompt_tokens=155,
        with_max_tokens_per_doc_prompt_tokens=155,
        with_max_tokens_per_query_and_doc_prompt_tokens=25,
    ),
    # 5. jinaai/jina-reranker-v3
    TestConfig(
        model="jinaai/jina-reranker-v3",
        args=[
            "--enforce-eager",
            "--max-model-len",
            "1024",
            "--trust-remote-code",
        ],
        without_truncated_prompt_tokens=567,
        with_max_tokens_per_query_prompt_tokens=308,
        with_max_tokens_per_doc_prompt_tokens=436,
        with_max_tokens_per_query_and_doc_prompt_tokens=177,
    ),
]
@pytest.fixture(scope="module", params=RERANK_CONFIGS, ids=lambda c: c.model)
def server(request):
    """Launch a server for the current config and yield ``(config, server)``.

    Module-scoped and parametrized over ``RERANK_CONFIGS``; the server is
    shut down when the ``with`` block exits after the last dependent test.
    """
    active_config: TestConfig = request.param
    launcher = RemoteOpenAIServer(active_config.model, active_config.args)
    with launcher as running_server:
        yield active_config, running_server
def test_without_truncated(server):
    """Baseline: rerank with no per-query or per-doc limit.

    Checks the response shape and that the reported prompt token count equals
    the config's ``without_truncated_prompt_tokens``.
    """
    config, remote_server = server

    payload = {"model": config.model, "query": long_query, "documents": [long_doc]}
    reply = requests.post(remote_server.url_for("rerank"), json=payload)
    reply.raise_for_status()

    parsed = RerankResponse.model_validate(reply.json())
    assert parsed.id is not None
    assert parsed.results is not None
    assert len(parsed.results) == 1

    expected_tokens = config.without_truncated_prompt_tokens
    assert parsed.usage.prompt_tokens == expected_tokens
def test_max_tokens_per_query(server):
    """Rerank with ``max_tokens_per_query=10`` and check the token count.

    The reported prompt token count must equal the config's
    ``with_max_tokens_per_query_prompt_tokens``.
    """
    config, remote_server = server

    payload = {
        "model": config.model,
        "query": long_query,
        "documents": [long_doc],
        "max_tokens_per_query": 10,
    }
    reply = requests.post(remote_server.url_for("rerank"), json=payload)
    reply.raise_for_status()

    parsed = RerankResponse.model_validate(reply.json())
    assert parsed.id is not None
    assert parsed.results is not None
    assert len(parsed.results) == 1

    expected_tokens = config.with_max_tokens_per_query_prompt_tokens
    assert parsed.usage.prompt_tokens == expected_tokens
def test_max_tokens_per_doc(server):
    """Rerank with ``max_tokens_per_doc=10`` and check the token count.

    The reported prompt token count must equal the config's
    ``with_max_tokens_per_doc_prompt_tokens``.
    """
    config, remote_server = server

    payload = {
        "model": config.model,
        "query": long_query,
        "documents": [long_doc],
        "max_tokens_per_doc": 10,
    }
    reply = requests.post(remote_server.url_for("rerank"), json=payload)
    reply.raise_for_status()

    parsed = RerankResponse.model_validate(reply.json())
    assert parsed.id is not None
    assert parsed.results is not None
    assert len(parsed.results) == 1

    expected_tokens = config.with_max_tokens_per_doc_prompt_tokens
    assert parsed.usage.prompt_tokens == expected_tokens
def test_max_tokens_per_query_and_doc(server):
    """Rerank with both per-query and per-doc limits set to 10.

    The reported prompt token count must equal the config's
    ``with_max_tokens_per_query_and_doc_prompt_tokens``.
    """
    config, remote_server = server

    payload = {
        "model": config.model,
        "query": long_query,
        "documents": [long_doc],
        "max_tokens_per_query": 10,
        "max_tokens_per_doc": 10,
    }
    reply = requests.post(remote_server.url_for("rerank"), json=payload)
    reply.raise_for_status()

    parsed = RerankResponse.model_validate(reply.json())
    assert parsed.id is not None
    assert parsed.results is not None
    assert len(parsed.results) == 1

    expected_tokens = config.with_max_tokens_per_query_and_doc_prompt_tokens
    assert parsed.usage.prompt_tokens == expected_tokens
...@@ -232,6 +232,7 @@ class PoolingServingBase(ABC): ...@@ -232,6 +232,7 @@ class PoolingServingBase(ABC):
"greater than max_model_len." "greater than max_model_len."
" Please request a smaller truncation size." " Please request a smaller truncation size."
) )
return None return None
async def _get_trace_headers( async def _get_trace_headers(
......
...@@ -25,8 +25,10 @@ from .typing import ScoreData, ScoreInput, ScoringData ...@@ -25,8 +25,10 @@ from .typing import ScoreData, ScoreInput, ScoringData
from .utils import ( from .utils import (
compress_token_type_ids, compress_token_type_ids,
compute_maxsim_score, compute_maxsim_score,
get_num_special_tokens_for_pair,
parse_score_data, parse_score_data,
score_data_to_prompts, score_data_to_prompts,
truncate_text_to_tokens,
validate_score_input, validate_score_input,
) )
...@@ -48,6 +50,64 @@ class ScoringIOProcessor(PoolingIOProcessor): ...@@ -48,6 +50,64 @@ class ScoringIOProcessor(PoolingIOProcessor):
def create_pooling_params(self, request): def create_pooling_params(self, request):
return request.to_pooling_params(self.pooling_task) return request.to_pooling_params(self.pooling_task)
def _validate_token_limit(self, value: int, name: str) -> None:
if value < 0:
raise ValueError(f"{name} must be a non-negative integer")
if value >= self.model_config.max_model_len:
raise ValueError(
f"{name} ({value}) must be less "
f"than max_model_len ({self.model_config.max_model_len})."
)
def _get_token_limits(
self,
request: ScoringRequest | None = None,
pooling_params: PoolingParams | None = None,
) -> tuple[int, int]:
"""Extract and validate token limits from request or pooling_params."""
if request is not None:
max_tokens_per_query = getattr(request, "max_tokens_per_query", 0)
max_tokens_per_doc = getattr(request, "max_tokens_per_doc", 0)
else:
extra = (
(pooling_params.extra_kwargs or {})
if pooling_params is not None
else {}
)
max_tokens_per_query = extra.get("max_tokens_per_query", 0)
max_tokens_per_doc = extra.get("max_tokens_per_doc", 0)
if max_tokens_per_query != 0:
self._validate_token_limit(max_tokens_per_query, "max_tokens_per_query")
if max_tokens_per_doc != 0:
self._validate_token_limit(max_tokens_per_doc, "max_tokens_per_doc")
return max_tokens_per_query, max_tokens_per_doc
def _truncate_scoring_data(
    self,
    scoring_data: ScoringData,
    max_tokens_per_query: int = 0,
    max_tokens_per_doc: int = 0,
) -> ScoringData:
    """Return a new ``ScoringData`` with string entries cut to the given limits.

    ``data_1`` entries are limited by ``max_tokens_per_query`` and ``data_2``
    entries by ``max_tokens_per_doc``. A limit that is not positive leaves
    that side untouched, and non-string entries are always passed through
    unchanged.
    """

    def clip_side(entries, token_limit):
        # No positive limit: hand back the original sequence as-is.
        if token_limit > 0:
            return [
                truncate_text_to_tokens(entry, self.tokenizer, token_limit)
                if isinstance(entry, str)
                else entry
                for entry in entries
            ]
        return entries

    queries = clip_side(scoring_data.data_1, max_tokens_per_query)
    documents = clip_side(scoring_data.data_2, max_tokens_per_doc)
    return ScoringData(data_1=queries, data_2=documents)
def valid_inputs( def valid_inputs(
self, self,
data_1: ScoreInput | list[ScoreInput], data_1: ScoreInput | list[ScoreInput],
...@@ -82,6 +142,15 @@ class BiEncoderIOProcessor(ScoringIOProcessor): ...@@ -82,6 +142,15 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
raise ValueError(f"Invalid {self.name} request type") raise ValueError(f"Invalid {self.name} request type")
scoring_data = self.valid_inputs(data_1, data_2) scoring_data = self.valid_inputs(data_1, data_2)
max_tokens_per_query, max_tokens_per_doc = self._get_token_limits(
request=request
)
if max_tokens_per_query > 0 or max_tokens_per_doc > 0:
scoring_data = self._truncate_scoring_data(
scoring_data, max_tokens_per_query, max_tokens_per_doc
)
tok_params = request.build_tok_params(self.model_config) tok_params = request.build_tok_params(self.model_config)
engine_inputs = self._pre_process( engine_inputs = self._pre_process(
scoring_data, scoring_data,
...@@ -112,10 +181,23 @@ class BiEncoderIOProcessor(ScoringIOProcessor): ...@@ -112,10 +181,23 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]: def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
assert isinstance(ctx.prompts, ScoringData) assert isinstance(ctx.prompts, ScoringData)
assert not isinstance(ctx.pooling_params, Sequence)
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs( tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {}) **(ctx.tokenization_kwargs or {})
) )
return self._pre_process(ctx.prompts, tok_params)
max_tokens_per_query, max_tokens_per_doc = self._get_token_limits(
pooling_params=ctx.pooling_params
)
scoring_data = ctx.prompts
if max_tokens_per_query > 0 or max_tokens_per_doc > 0:
scoring_data = self._truncate_scoring_data(
scoring_data, max_tokens_per_query, max_tokens_per_doc
)
return self._pre_process(scoring_data, tok_params)
def post_process_offline( def post_process_offline(
self, self,
...@@ -217,8 +299,38 @@ class LateInteractionIOProcessor(BiEncoderIOProcessor): ...@@ -217,8 +299,38 @@ class LateInteractionIOProcessor(BiEncoderIOProcessor):
class FlashLateInteractionIOProcessor(LateInteractionIOProcessor): class FlashLateInteractionIOProcessor(LateInteractionIOProcessor):
name = "flash-late-interaction" name = "flash-late-interaction"
def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int): def post_process_online(
return outputs self,
ctx: ScoringServeContext,
):
assert ctx.query_final_res_batch is not None
assert ctx.final_res_batch is not None
assert isinstance(ctx.n_queries, int)
# Expand queries if 1:N scoring
if len(ctx.query_final_res_batch) == 1:
ctx.query_final_res_batch = ctx.query_final_res_batch * len(
ctx.final_res_batch
)
final_res_batch: list[PoolingRequestOutput] = []
for d1, d2 in zip(ctx.query_final_res_batch, ctx.final_res_batch):
padding: list[int] = []
if (pad_token_id := self.pad_token_id) is not None:
padding = [pad_token_id]
tokens = d1.prompt_token_ids + padding + d2.prompt_token_ids
final_res_batch.append(
PoolingRequestOutput(
request_id=f"{d1.request_id}_{d2.request_id}",
outputs=d2.outputs,
prompt_token_ids=tokens,
num_cached_tokens=d1.num_cached_tokens + d2.num_cached_tokens,
finished=True,
)
)
ctx.final_res_batch = final_res_batch
class CrossEncoderIOProcessor(ScoringIOProcessor): class CrossEncoderIOProcessor(ScoringIOProcessor):
...@@ -255,6 +367,11 @@ class CrossEncoderIOProcessor(ScoringIOProcessor): ...@@ -255,6 +367,11 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
raise ValueError(f"Invalid {self.name} request type") raise ValueError(f"Invalid {self.name} request type")
scoring_data = self.valid_inputs(data_1, data_2) scoring_data = self.valid_inputs(data_1, data_2)
max_tokens_per_query, max_tokens_per_doc = self._get_token_limits(
request=request
)
tok_params = request.build_tok_params(self.model_config) tok_params = request.build_tok_params(self.model_config)
pooling_params = self.create_pooling_params(request) pooling_params = self.create_pooling_params(request)
...@@ -263,6 +380,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor): ...@@ -263,6 +380,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
tok_params, tok_params,
pooling_params, pooling_params,
chat_template=self.chat_template, chat_template=self.chat_template,
max_tokens_per_query=max_tokens_per_query,
max_tokens_per_doc=max_tokens_per_doc,
prompt_extras={ prompt_extras={
k: v k: v
for k in ("mm_processor_kwargs", "cache_salt") for k in ("mm_processor_kwargs", "cache_salt")
...@@ -283,8 +402,18 @@ class CrossEncoderIOProcessor(ScoringIOProcessor): ...@@ -283,8 +402,18 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs( tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {}) **(ctx.tokenization_kwargs or {})
) )
max_tokens_per_query, max_tokens_per_doc = self._get_token_limits(
pooling_params=ctx.pooling_params
)
engine_inputs, pooling_params_list = self._pre_process( engine_inputs, pooling_params_list = self._pre_process(
ctx.prompts, tok_params, ctx.pooling_params, ctx.chat_template ctx.prompts,
tok_params,
ctx.pooling_params,
ctx.chat_template,
max_tokens_per_query=max_tokens_per_query,
max_tokens_per_doc=max_tokens_per_doc,
) )
ctx.pooling_params = pooling_params_list ctx.pooling_params = pooling_params_list
return engine_inputs return engine_inputs
...@@ -298,6 +427,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor): ...@@ -298,6 +427,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
tok_params: TokenizeParams, tok_params: TokenizeParams,
pooling_params: PoolingParams | None, pooling_params: PoolingParams | None,
chat_template: str | None = None, chat_template: str | None = None,
max_tokens_per_query: int = 0,
max_tokens_per_doc: int = 0,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
) -> tuple[Sequence[EngineInput], list[PoolingParams]]: ) -> tuple[Sequence[EngineInput], list[PoolingParams]]:
# todo: support prompt_extras # todo: support prompt_extras
...@@ -320,6 +451,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor): ...@@ -320,6 +451,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
data_2=d, data_2=d,
encode_kwargs=tok_params.get_encode_kwargs(), encode_kwargs=tok_params.get_encode_kwargs(),
chat_template=chat_template, chat_template=chat_template,
max_tokens_per_query=max_tokens_per_query,
max_tokens_per_doc=max_tokens_per_doc,
) )
if token_type_ids := engine_prompt.pop("token_type_ids", None): if token_type_ids := engine_prompt.pop("token_type_ids", None):
...@@ -342,6 +475,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor): ...@@ -342,6 +475,8 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
data_2: ScoreData, data_2: ScoreData,
encode_kwargs: dict[str, Any], encode_kwargs: dict[str, Any],
chat_template: str | None = None, chat_template: str | None = None,
max_tokens_per_query: int = 0,
max_tokens_per_doc: int = 0,
): ):
model_config = self.model_config model_config = self.model_config
tokenizer = self.tokenizer tokenizer = self.tokenizer
...@@ -352,25 +487,61 @@ class CrossEncoderIOProcessor(ScoringIOProcessor): ...@@ -352,25 +487,61 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
model_config, model_config,
) )
# Apply truncation before defining closures
if max_tokens_per_query > 0 and isinstance(prompt_1, str):
prompt_1 = truncate_text_to_tokens(
prompt_1, tokenizer, max_tokens_per_query
)
if max_tokens_per_doc > 0 and isinstance(prompt_2, str):
prompt_2 = truncate_text_to_tokens(prompt_2, tokenizer, max_tokens_per_doc)
def default_tokenizer_encode(): def default_tokenizer_encode():
local_kwargs = encode_kwargs.copy()
if self.supports_score_template: if self.supports_score_template:
assert self.model is not None assert self.model is not None
full_prompt = self.model.get_score_template(prompt_1, prompt_2) full_prompt = self.model.get_score_template(prompt_1, prompt_2)
if full_prompt is None: if full_prompt is None:
raise ValueError("Get empty score template from model") raise ValueError("Get empty score template from model")
prompt_inputs = tokenizer(full_prompt, **encode_kwargs) prompt_inputs = tokenizer(full_prompt, **local_kwargs)
else: else:
if self.use_sep_token: if self.use_sep_token:
# cross_encoder models defaults to using separating token. # cross_encoder models defaults to using separating token.
if max_tokens_per_doc > 0 and isinstance(prompt_2, str):
query_tokens = tokenizer.encode(
prompt_1, add_special_tokens=False
)
num_special = get_num_special_tokens_for_pair(tokenizer)
doc_limit_max_length = (
len(query_tokens) + max_tokens_per_doc + num_special
)
existing_max_length = local_kwargs.get("max_length")
if existing_max_length is not None:
effective_max_length = min(
doc_limit_max_length, existing_max_length
)
else:
effective_max_length = doc_limit_max_length
local_kwargs["truncation"] = "only_second"
local_kwargs["max_length"] = effective_max_length
prompt_inputs = tokenizer( prompt_inputs = tokenizer(
text=prompt_1, text_pair=prompt_2, **encode_kwargs text=prompt_1, text_pair=prompt_2, **local_kwargs
) )
full_prompt = tokenizer.decode(prompt_inputs["input_ids"]) full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
else: else:
# `llm as reranker` defaults to not using separating token. # `llm as reranker` defaults to not using separating token.
full_prompt = prompt_1 + prompt_2 if max_tokens_per_doc > 0 and isinstance(prompt_2, str):
prompt_inputs = tokenizer(text=full_prompt, **encode_kwargs) query_ids = tokenizer.encode(prompt_1, add_special_tokens=False)
doc_ids = tokenizer.encode(prompt_2, add_special_tokens=False)
doc_ids = doc_ids[:max_tokens_per_doc]
input_ids = query_ids + doc_ids
full_prompt = tokenizer.decode(input_ids)
prompt_inputs = {"input_ids": input_ids}
else:
full_prompt = prompt_1 + prompt_2
prompt_inputs = tokenizer(text=full_prompt, **local_kwargs)
return full_prompt, prompt_inputs return full_prompt, prompt_inputs
# FIXME: For now, we only apply a template when one is explicitly provided. # FIXME: For now, we only apply a template when one is explicitly provided.
......
...@@ -20,6 +20,24 @@ from .typing import ScoreContentPartParam, ScoreInput ...@@ -20,6 +20,24 @@ from .typing import ScoreContentPartParam, ScoreInput
class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_tokens_per_query: int = Field(
default=0,
description=(
"Maximum number of tokens per query. Queries longer than "
"this will be truncated to this length. 0 means no "
"query-level truncation is applied."
),
)
max_tokens_per_doc: int = Field(
default=0,
description=(
"Maximum number of tokens per document. Documents longer than "
"this will be truncated to this length. 0 means no "
"document-level truncation is applied (only truncate_prompt_tokens "
"applies to the combined query+document)."
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {} encoder_config = model_config.encoder_config or {}
...@@ -91,29 +109,11 @@ ScoreRequest: TypeAlias = ( ...@@ -91,29 +109,11 @@ ScoreRequest: TypeAlias = (
) )
class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin): class RerankRequest(ScoreRequestMixin):
query: ScoreInput query: ScoreInput
documents: ScoreInput | list[ScoreInput] documents: ScoreInput | list[ScoreInput]
top_n: int = Field(default_factory=lambda: 0) top_n: int = Field(default_factory=lambda: 0)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=encoder_config.get("do_lower_case", False),
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self, task: PoolingTask = "classify"):
return PoolingParams(
task=task,
use_activation=self.use_activation,
)
ScoringRequest: TypeAlias = ScoreRequest | RerankRequest ScoringRequest: TypeAlias = ScoreRequest | RerankRequest
......
...@@ -237,6 +237,7 @@ class ServingScores(PoolingServing): ...@@ -237,6 +237,7 @@ class ServingScores(PoolingServing):
await self._prepare_generators(query_ctx) await self._prepare_generators(query_ctx)
await self._collect_batch(query_ctx) await self._collect_batch(query_ctx)
ctx.query_final_res_batch = query_ctx.final_res_batch
async def _flash_late_interaction_encode_docs(self, ctx: ScoringServeContext): async def _flash_late_interaction_encode_docs(self, ctx: ScoringServeContext):
assert ctx.n_queries is not None assert ctx.n_queries is not None
......
...@@ -25,6 +25,37 @@ from .typing import ( ...@@ -25,6 +25,37 @@ from .typing import (
) )
def get_num_special_tokens_for_pair(tokenizer) -> int:
    """Return how many special tokens the tokenizer adds to a text pair.

    Prefers the tokenizer's own ``num_special_tokens_to_add(pair=True)``.
    If that method is missing, or does not accept the ``pair`` keyword
    (signalled by a ``TypeError``), the count is obtained by encoding a pair
    of empty strings and measuring the resulting ``input_ids``.
    """
    count_special = getattr(tokenizer, "num_special_tokens_to_add", None)
    if count_special is not None:
        try:
            pair_count = count_special(pair=True)
        except TypeError:
            # Signature mismatch: fall through to the empty-pair probe.
            pass
        else:
            return pair_count

    # With no content, every produced id is a special token.
    probe = tokenizer("", text_pair="", add_special_tokens=True)
    return len(probe["input_ids"])
def truncate_text_to_tokens(
    text: str,
    tokenizer,
    max_tokens: int,
) -> str:
    """Truncate text to a maximum number of content tokens.

    Uses offset_mapping to slice the original text at the exact character
    boundary, avoiding lossy encode→decode round-trips that can shift
    the token count by 1-3 tokens due to BPE merge boundary changes.

    Args:
        text: The text to truncate.
        tokenizer: Callable tokenizer that accepts ``add_special_tokens`` and
            ``return_offsets_mapping`` and returns a mapping with
            ``input_ids`` and ``offset_mapping``.
        max_tokens: Maximum number of tokens to keep. A non-positive value
            means "no limit", matching the request fields where 0 disables
            truncation.

    Returns:
        ``text`` unchanged if it already fits (or no limit applies),
        otherwise the prefix of ``text`` ending at the last kept token.
    """
    # Without this guard, `max_tokens - 1` would index offset_mapping from
    # the end and silently cut the text at an unrelated position.
    if max_tokens <= 0:
        return text

    encoding = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
    if len(encoding["input_ids"]) <= max_tokens:
        return text

    # End offset (exclusive) of the last token that is kept.
    char_end = encoding["offset_mapping"][max_tokens - 1][1]
    return text[:char_end]
def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor: def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
""" """
Compute ColBERT MaxSim score. Compute ColBERT MaxSim score.
......
...@@ -89,6 +89,9 @@ class PoolingServeContext(Generic[PoolingRequestT]): ...@@ -89,6 +89,9 @@ class PoolingServeContext(Generic[PoolingRequestT]):
## for IOProcessorResponse ## for IOProcessorResponse
response: Any | None = None response: Any | None = None
## for flash-late-interaction
query_final_res_batch: list[PoolingRequestOutput] | None = None
@dataclass @dataclass
class OfflineInputsContext: class OfflineInputsContext:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment