Unverified Commit 53bd00d9 authored by Sundara Raman Ramachandran, committed by GitHub

[Generative Score API] Multi-Item scoring with custom attention mask. (#10979)

parent e22b13c5
......@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""
import logging
import os
from dataclasses import dataclass
from enum import Enum, auto
......@@ -16,11 +17,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
import logging
torch._logging.set_logs(dynamo=logging.ERROR)
torch._dynamo.config.suppress_errors = True
logger = logging.getLogger(__name__)
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
......@@ -58,6 +59,36 @@ class WrapperDispatch(Enum):
CROSS_ATTENTION = auto()
@dataclass
class MultiItemScoringParams:
"""Parameters for multi-item scoring in attention computation.
Used when processing sequences with multiple items separated by delimiters,
where each item needs specific attention patterns that respect item boundaries.
Attributes:
prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
The tensor size is equal to the batch size.
token_pos_in_items_ptr: A uint16 1D tensor giving each token's position within its item,
                        restarting from 0 at every delimiter. For batch size > 1, the
                        per-sequence tensors are concatenated with zero padding so that
                        all sequences have the same length.
token_pos_in_items_len: The padded per-sequence length of token_pos_in_items_ptr, used to
                        handle the batch_size > 1 case (each sequence is zero-padded to
                        this length).
max_item_len_ptr: A uint16 tensor containing the max token length of all items
for each prompt in the batch.
"""
prefix_len_ptr: Optional[torch.Tensor] = None
token_pos_in_items_ptr: Optional[torch.Tensor] = None
token_pos_in_items_len: int = 0
max_item_len_ptr: Optional[torch.Tensor] = None
def is_enabled(self) -> bool:
"""Check if multi-item scoring is enabled."""
return self.prefix_len_ptr is not None
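For orientation, here is a standalone sketch (not part of the diff; it only assumes a PyTorch build with uint16/uint32 support, as the code in this file does) that builds the four tensors for the FlashInfer reference example quoted later in `_process_multi_item_scoring`: a single prompt with a 7-token query followed by items of length 3, 2, and 4, each preceded by a delimiter and with one trailing delimiter.

import torch

item_lens = [3, 2, 4]
prefix_len_ptr = torch.tensor([7], dtype=torch.uint32)        # query length for the single prompt

token_pos = []
for item_len in item_lens:
    token_pos.extend(range(item_len + 1))                     # 0 for the delimiter, 1..item_len for the item
token_pos.append(0)                                           # trailing delimiter
token_pos_in_items_ptr = torch.tensor(token_pos, dtype=torch.uint16)
# -> [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]

params = MultiItemScoringParams(
    prefix_len_ptr=prefix_len_ptr,
    token_pos_in_items_ptr=token_pos_in_items_ptr,
    token_pos_in_items_len=token_pos_in_items_ptr.numel(),    # 13; no padding needed for batch size 1
    max_item_len_ptr=torch.tensor([max(item_lens)], dtype=torch.uint16),  # [4]
)
assert params.is_enabled()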
@dataclass
class DecodeMetadata:
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
......@@ -68,6 +99,7 @@ class PrefillMetadata:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
use_ragged: bool
extend_no_prefix: bool
multi_item_params: Optional[MultiItemScoringParams] = None
# Reuse this workspace buffer across all flashinfer wrappers
......@@ -90,6 +122,11 @@ class FlashInferAttnBackend(AttentionBackend):
):
super().__init__()
# Store multi-item scoring delimiter for efficient access
self.multi_item_scoring_delimiter = (
model_runner.server_args.multi_item_scoring_delimiter
)
# Parse constants
self.decode_use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=model_runner.kv_cache_dtype,
......@@ -229,10 +266,133 @@ class FlashInferAttnBackend(AttentionBackend):
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {} # For verify
self.draft_extend_cuda_graph_metadata = {} # For draft extend
def _process_multi_item_scoring(
self, forward_batch: ForwardBatch
) -> MultiItemScoringParams:
"""Process multi-item scoring tensors for FlashInfer attention.
This method handles sequences containing multiple "items" separated by delimiter tokens,
where each item needs specific attention patterns that respect item boundaries.
The method produces four key tensors for FlashInfer:
- prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
- token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
- token_pos_in_items_len: padding length for batch processing
- max_item_len_ptr: uint16 tensor with max item length for each prompt
Args:
forward_batch: The forward batch containing input sequences and delimiter info
Returns:
MultiItemScoringParams: The processed multi-item scoring parameters
Examples:
Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
Case 1: Single sequence
Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
Indices: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
- prefix_len_ptr: [7] (query length before first delimiter)
- token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
- token_pos_in_items_len: 7 (actual length)
- max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
Case 2: Batch processing (batch_size=2)
Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
After padding both to length 10:
- token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
- token_pos_in_items_len: 10 (padded length for batch processing)
- max_item_len_ptr: [2, 3] (max lengths per sequence)
"""
delimiter = self.multi_item_scoring_delimiter
if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
return MultiItemScoringParams()
delimiter_mask = forward_batch.input_ids == delimiter
prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
prefix_len_ptr, token_pos_in_items_ptr = [], []
token_pos_in_items_len = 0
# If no extend_seq_lens, treat whole batch as one sequence
if extend_seq_lens is None or len(extend_seq_lens) <= 1:
extend_seq_lens = [forward_batch.input_ids.size(0)]
seq_start = 0
for i, seq_len in enumerate(extend_seq_lens):
seq_end = seq_start + seq_len
mask = delimiter_mask[seq_start:seq_end]
pos = forward_batch.positions[seq_start:seq_end]
delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
if len(delimiter_indices) > 0:
first_delim = delimiter_indices[0]
# Prefix length: store as scalar
prefix_len = first_delim + (
prefix_cache_lens[i] if prefix_cache_lens is not None else 0
)
prefix_len_ptr.append(
prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
)
# Compute relative positions within items after delimiters
diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
token_pos = (diff - pos[first_delim]).to(torch.uint16)
token_pos_in_items_ptr.append(token_pos)
# Update forward_batch positions in-place
pos[first_delim:] = diff - 1
forward_batch.positions[seq_start:seq_end] = pos
seq_start = seq_end
# Pad token_pos_in_items_ptr for batch processing
if token_pos_in_items_ptr:
token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
device = forward_batch.input_ids.device
token_pos_in_items_ptr = [
torch.cat(
[
t,
torch.zeros(
token_pos_in_items_len - t.numel(),
dtype=torch.uint16,
device=device,
),
]
)
for t in token_pos_in_items_ptr
]
if not prefix_len_ptr or not token_pos_in_items_ptr:
return MultiItemScoringParams()
# Build final params
device = forward_batch.input_ids.device
return MultiItemScoringParams(
prefix_len_ptr=torch.tensor(
prefix_len_ptr, dtype=torch.uint32, device=device
),
token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
max_item_len_ptr=torch.stack(
[
t.to(torch.int32).max().to(torch.uint16)
for t in token_pos_in_items_ptr
],
dim=0,
),
)
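A toy replay of the position arithmetic above (plain torch only; delimiter id 99 is hypothetical), using the Case 1 sequence from the docstring. The cummax indices over the boolean delimiter mask track the most recent delimiter seen so far, which yields each token's offset within its item.

import torch

delimiter = 99                                   # hypothetical delimiter token id
input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 99, 10, 99, 11, 99, 12, 99])
positions = torch.arange(input_ids.numel())

mask = input_ids == delimiter
first_delim = int(torch.nonzero(mask, as_tuple=True)[0][0])           # 7

# cummax over the boolean mask gives, at every position, the index of the
# most recent delimiter in the slice starting at first_delim.
last_delim_idx = torch.cummax(mask[first_delim:], 0)[1]
token_pos = positions[first_delim:] - last_delim_idx - positions[first_delim]
print(first_delim, token_pos.tolist())           # 7  [0, 1, 0, 1, 0, 1, 0]  (matches Case 1)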
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
......@@ -280,13 +440,26 @@ class FlashInferAttnBackend(AttentionBackend):
else:
prefix_lens = forward_batch.extend_prefix_lens
if self.is_multimodal:
# Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
# use_ragged = False: Multi-item scoring requires the paged wrapper because:
# 1. Ragged wrapper doesn't support the specialized multi-item parameters
# (prefix_len_ptr, token_pos_in_items_ptr, etc.)
# 2. Paged wrapper provides better control over attention masking needed
# for respecting item boundaries in multi-item sequences
# 3. Custom masking logic conflicts with ragged wrapper's assumptions
use_ragged = False
extend_no_prefix = False
else:
use_ragged = not self.enable_deterministic
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
# Process multi-item scoring in attention backend instead of ForwardBatch
multi_item_params = MultiItemScoringParams()
if self.multi_item_scoring_delimiter is not None:
# Use new backend-specific implementation
multi_item_params = self._process_multi_item_scoring(forward_batch)
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
......@@ -298,9 +471,13 @@ class FlashInferAttnBackend(AttentionBackend):
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
fixed_split_size=self.prefill_split_tile_size,
multi_item_params=multi_item_params,
)
self.forward_metadata = PrefillMetadata(
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
self.prefill_wrappers_paged,
use_ragged,
extend_no_prefix,
multi_item_params,
)
def init_cuda_graph_state(
......@@ -531,7 +708,20 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
# Disable sliding window attention for multi-item scoring:
# - Sliding window could cut across item boundaries, breaking semantic coherence
# - Multi-item sequences need full attention to properly handle delimiter tokens
# - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
# provide more precise attention control than simple sliding windows
# - Item-aware masking takes precedence over window-based masking
window_left=(
layer.sliding_window_size
if not (
self.forward_metadata.multi_item_params
and self.forward_metadata.multi_item_params.is_enabled()
)
else -1
),
logits_soft_cap=logits_soft_cap,
# Must use _float to avoid device-to-host copy that breaks cuda graph capture.
k_scale=layer.k_scale_float,
......@@ -952,6 +1142,7 @@ class FlashInferIndicesUpdaterPrefill:
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
multi_item_params: Optional[MultiItemScoringParams] = None,
):
if use_ragged:
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
......@@ -976,6 +1167,7 @@ class FlashInferIndicesUpdaterPrefill:
use_ragged,
spec_info,
fixed_split_size=fixed_split_size,
multi_item_params=multi_item_params,
)
def update_sliding_window(
......@@ -990,6 +1182,7 @@ class FlashInferIndicesUpdaterPrefill:
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
multi_item_params: Optional[MultiItemScoringParams] = None,
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -1023,6 +1216,7 @@ class FlashInferIndicesUpdaterPrefill:
use_ragged,
spec_info,
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
multi_item_params=multi_item_params,
)
def update_cross_attention(
......@@ -1037,6 +1231,7 @@ class FlashInferIndicesUpdaterPrefill:
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
multi_item_params: Optional[MultiItemScoringParams] = None,
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -1063,6 +1258,7 @@ class FlashInferIndicesUpdaterPrefill:
self.qo_indptr[wrapper_id],
use_ragged,
spec_info,
multi_item_params=multi_item_params,
)
def call_begin_forward(
......@@ -1081,6 +1277,7 @@ class FlashInferIndicesUpdaterPrefill:
spec_info: Optional[SpecInput],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
multi_item_params: Optional[MultiItemScoringParams] = None,
):
bs = len(seq_lens)
if spec_info is None:
......@@ -1136,6 +1333,22 @@ class FlashInferIndicesUpdaterPrefill:
)
# cached part
# Conditionally set multi-item parameters
if multi_item_params is not None and multi_item_params.is_enabled():
# Multi-item scoring is active - use specialized parameters and disable generic custom_mask
use_custom_mask = None
prefix_len_ptr = multi_item_params.prefix_len_ptr
token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
token_pos_in_items_len = multi_item_params.token_pos_in_items_len
max_item_len_ptr = multi_item_params.max_item_len_ptr
else:
# No multi-item scoring - use standard parameters
use_custom_mask = custom_mask
prefix_len_ptr = None
token_pos_in_items_ptr = None
token_pos_in_items_len = 0
max_item_len_ptr = None
wrapper_paged.begin_forward(
qo_indptr,
kv_indptr,
......@@ -1147,9 +1360,13 @@ class FlashInferIndicesUpdaterPrefill:
1,
q_data_type=self.q_data_type,
kv_data_type=self.data_type,
custom_mask=custom_mask,
custom_mask=use_custom_mask,
non_blocking=True,
fixed_split_size=fixed_split_size,
prefix_len_ptr=prefix_len_ptr,
token_pos_in_items_ptr=token_pos_in_items_ptr,
token_pos_in_items_len=token_pos_in_items_len,
max_item_len_ptr=max_item_len_ptr,
)
......
......@@ -60,7 +60,8 @@ _is_npu = is_npu()
class LogitsProcessorOutput:
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
# The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits: torch.Tensor
# Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
next_token_logits: Optional[torch.Tensor]
# Used by speculative decoding (EAGLE)
# The last hidden layers
hidden_states: Optional[torch.Tensor] = None
......@@ -85,7 +86,10 @@ class LogitsProcessorOutput:
input_top_logprobs_val: List = None
input_top_logprobs_idx: List = None
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
input_token_ids_logprobs_val: Optional[List] = None
# Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
None
)
input_token_ids_logprobs_idx: Optional[List] = None
......@@ -127,6 +131,9 @@ class LogitsMetadata:
# for padding
padded_static_len: int = -1
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False
@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
if (
......@@ -169,6 +176,7 @@ class LogitsMetadata:
token_ids_logprobs=forward_batch.token_ids_logprobs,
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
padded_static_len=forward_batch.padded_static_len,
is_prefill_only=forward_batch.is_prefill_only,
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
dp_local_start_pos=forward_batch.dp_local_start_pos,
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
......@@ -247,6 +255,108 @@ class LogitsProcessor(nn.Module):
"debug_tensor_dump_output_folder", None
)
def compute_logprobs_for_multi_item_scoring(
self,
input_ids,
hidden_states,
lm_head: VocabParallelEmbedding,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
delimiter_token: int,
):
"""
Compute logprobs for multi-item scoring using delimiter-based token extraction.
This method is designed for scenarios where you want to score multiple items/candidates
against a single query by combining them into one sequence separated by delimiters.
Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
Scoring positions: Extracts logprobs at positions before each <delimiter>
Args:
input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
hidden_states (torch.Tensor): Hidden states from the model.
Shape: [sequence_length, hidden_dim].
lm_head (VocabParallelEmbedding): Language model head for computing logits.
logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
and token ID specifications for logprob extraction.
delimiter_token (int): Token ID used as delimiter between query and items.
Returns:
LogitsProcessorOutput: Contains:
- next_token_logits: None (not needed for scoring-only requests)
- input_token_logprobs: Logprobs of delimiter tokens at scoring positions
- input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
- input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
- input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
- input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
"""
multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
0
] - 1
# Extract hidden states at delimiter positions for multi-item scoring
sliced_hidden = hidden_states[multi_item_indices]
sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
# Initialize return values
input_token_ids_logprobs_val = []
input_token_ids_logprobs_idx = []
input_top_logprobs_val = None
input_top_logprobs_idx = None
# Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
# Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
if (
logits_metadata.token_ids_logprobs
or logits_metadata.extend_return_top_logprob
):
logits_metadata.extend_logprob_pruned_lens_cpu = []
if logits_metadata.extend_seq_lens_cpu is not None:
# Multi-request batch: count delimiters per request
input_pt = 0
for req_seq_len in logits_metadata.extend_seq_lens_cpu:
req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
delimiter_count = (req_input_ids == delimiter_token).sum().item()
logits_metadata.extend_logprob_pruned_lens_cpu.append(
delimiter_count
)
input_pt += req_seq_len
else:
# Single request case: one request gets all delimiters
total_delimiters = (input_ids == delimiter_token).sum().item()
logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
# Get the logprobs of specified token ids
if logits_metadata.extend_token_ids_logprob:
(
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
) = self.get_token_ids_logprobs(
sliced_logprobs, logits_metadata, delay_cpu_copy=True
)
# Get the logprob of top-k tokens
if logits_metadata.extend_return_top_logprob:
(
input_top_logprobs_val,
input_top_logprobs_idx,
) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
# For input_token_logprobs, use delimiter token logprobs
input_token_logprobs = sliced_logprobs[:, delimiter_token]
return LogitsProcessorOutput(
next_token_logits=None, # Multi-item scoring doesn't need next token logits
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
)
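A toy-tensor sketch of the slicing step above (random weights and a hypothetical delimiter id 99; not the production path through `_get_logits`): only the positions immediately before each delimiter are pushed through the vocabulary projection.

import torch

delimiter_token = 99
input_ids = torch.tensor([5, 6, 7, 99, 11, 12, 99, 13, 99])       # query, <d>, item1, <d>, item2, <d>
hidden_states = torch.randn(input_ids.numel(), 16)                # [seq_len, hidden_dim]
lm_head_weight = torch.randn(32, 16)                              # stand-in vocab projection

multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[0] - 1
sliced_hidden = hidden_states[multi_item_indices]                 # one row per delimiter
sliced_logprobs = torch.log_softmax(sliced_hidden @ lm_head_weight.T, dim=-1)
print(multi_item_indices.tolist(), sliced_logprobs.shape)         # [2, 5, 7] torch.Size([3, 32])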
def forward(
self,
input_ids,
......@@ -257,6 +367,16 @@ class LogitsProcessor(nn.Module):
) -> LogitsProcessorOutput:
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
multi_item_delimiter = global_server_args_dict.get(
"multi_item_scoring_delimiter"
)
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
return self.compute_logprobs_for_multi_item_scoring(
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
)
# Get the last hidden states and last logits for the next token prediction
if (
logits_metadata.forward_mode.is_decode_or_idle()
......@@ -584,7 +704,9 @@ class LogitsProcessor(nn.Module):
@staticmethod
def get_token_ids_logprobs(
all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
all_logprobs: torch.Tensor,
logits_metadata: LogitsMetadata,
delay_cpu_copy: bool = False,
):
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
pt = 0
......@@ -597,9 +719,17 @@ class LogitsProcessor(nn.Module):
input_token_ids_logprobs_idx.append([])
continue
input_token_ids_logprobs_val.append(
[all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
)
position_logprobs = all_logprobs[
pt : pt + pruned_len, token_ids
] # Shape: [pruned_len, num_tokens]
if delay_cpu_copy:
# Keep as tensor to delay GPU-to-CPU transfer
input_token_ids_logprobs_val.append(position_logprobs)
else:
# Convert to list immediately (default behavior)
input_token_ids_logprobs_val.append(position_logprobs.tolist())
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
pt += pruned_len
......
......@@ -114,6 +114,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
"multi_item_scoring_delimiter",
]
# Put some global args for easy access
......@@ -666,9 +667,11 @@ class Req:
def is_prefill_only(self) -> bool:
"""Check if this request is prefill-only (no token generation needed)."""
# NOTE: when spec is enabled, prefill_only optimizations are disabled
return (
self.sampling_params.max_new_tokens == 0
and global_server_args_dict["speculative_algorithm"] is None
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
spec_alg = global_server_args_dict["speculative_algorithm"]
return self.sampling_params.max_new_tokens == 0 and (
spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
)
def add_latency(self, stage: RequestStage):
......
......@@ -104,7 +104,10 @@ class SchedulerOutputProcessorMixin:
assert extend_input_len_per_req is not None
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
num_input_logprobs = extend_input_len - extend_logprob_start_len
num_input_logprobs = self._calculate_num_input_logprobs(
req, extend_input_len, extend_logprob_start_len
)
if req.return_logprob:
self.add_logprob_return_values(
......@@ -159,8 +162,8 @@ class SchedulerOutputProcessorMixin:
extend_input_len = extend_input_len_per_req[i]
if extend_logprob_start_len < extend_input_len:
# Update input logprobs.
num_input_logprobs = (
extend_input_len - extend_logprob_start_len
num_input_logprobs = self._calculate_num_input_logprobs(
req, extend_input_len, extend_logprob_start_len
)
if req.return_logprob:
self.add_input_logprob_return_values(
......@@ -303,6 +306,153 @@ class SchedulerOutputProcessorMixin:
):
self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
def _process_input_token_logprobs(
self, req: Req, input_token_logprobs: List
) -> None:
"""Process input token logprobs values and indices."""
is_multi_item_scoring = self._is_multi_item_scoring(req)
# Process logprob values - handle multi-item scoring vs regular requests
if is_multi_item_scoring:
# Multi-item scoring: use all logprobs as-is
req.input_token_logprobs_val = input_token_logprobs
else:
# Regular request: add None at start, remove last (sampling token)
req.input_token_logprobs_val = [None] + input_token_logprobs[:-1]
# Process logprob indices based on scoring type
if is_multi_item_scoring:
# Multi-item scoring: only include delimiter token positions
relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
input_token_logprobs_idx = [
token_id
for token_id in relevant_tokens
if token_id == self.server_args.multi_item_scoring_delimiter
]
else:
# Regular request: include all tokens from logprob_start_len onwards
input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
# Clip padded hash values from image tokens to prevent detokenization errors
req.input_token_logprobs_idx = [
x if x < self.model_config.vocab_size - 1 else 0
for x in input_token_logprobs_idx
]
def _process_input_top_logprobs(self, req: Req) -> None:
"""Process input top logprobs."""
if req.top_logprobs_num <= 0:
return
is_multi_item_scoring = self._is_multi_item_scoring(req)
# Initialize arrays - multi-item scoring starts empty, others start with None
req.input_top_logprobs_val = [] if is_multi_item_scoring else [None]
req.input_top_logprobs_idx = [] if is_multi_item_scoring else [None]
# Extend arrays with temp values
for val, idx in zip(
req.temp_input_top_logprobs_val,
req.temp_input_top_logprobs_idx,
strict=True,
):
req.input_top_logprobs_val.extend(val)
req.input_top_logprobs_idx.extend(idx)
# Remove last token (sampling token) for non multi-item scoring requests
if not is_multi_item_scoring:
req.input_top_logprobs_val.pop()
req.input_top_logprobs_idx.pop()
# Clean up temp storage
req.temp_input_top_logprobs_idx = None
req.temp_input_top_logprobs_val = None
def _process_input_token_ids_logprobs(self, req: Req) -> None:
"""Process input token IDs logprobs."""
if req.token_ids_logprob is None:
return
is_multi_item_scoring = self._is_multi_item_scoring(req)
# Initialize arrays - multi-item scoring starts empty, others start with None
req.input_token_ids_logprobs_val = [] if is_multi_item_scoring else [None]
req.input_token_ids_logprobs_idx = [] if is_multi_item_scoring else [None]
# Process temp values - convert tensors to lists and extend arrays
for val, idx in zip(
req.temp_input_token_ids_logprobs_val,
req.temp_input_token_ids_logprobs_idx,
strict=True,
):
val_list = val.tolist() if isinstance(val, torch.Tensor) else val
req.input_token_ids_logprobs_val.extend(
val_list if isinstance(val_list, list) else [val_list]
)
req.input_token_ids_logprobs_idx.extend(idx)
# Remove last token (sampling token) for non multi-item scoring requests
if not is_multi_item_scoring:
req.input_token_ids_logprobs_val.pop()
req.input_token_ids_logprobs_idx.pop()
# Clean up temp storage
req.temp_input_token_ids_logprobs_idx = None
req.temp_input_token_ids_logprobs_val = None
def _calculate_relevant_tokens_len(self, req: Req) -> int:
"""Calculate the expected length of logprob arrays based on whether multi-item scoring is enabled.
For multi-item scoring, only delimiter positions have logprobs.
For regular requests, all positions from logprob_start_len onwards have logprobs.
"""
is_multi_item_scoring = self._is_multi_item_scoring(req)
if is_multi_item_scoring:
# Multi-item scoring: count delimiter tokens from logprob_start_len onwards
relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
return sum(
1
for token_id in relevant_tokens
if token_id == self.server_args.multi_item_scoring_delimiter
)
else:
# Regular request: all tokens from logprob_start_len onwards
return len(req.origin_input_ids) - req.logprob_start_len
def _calculate_num_input_logprobs(
self, req: Req, extend_input_len: int, extend_logprob_start_len: int
) -> int:
"""Calculate the number of input logprobs based on whether multi-item scoring is enabled.
For multi-item scoring, only delimiter positions have logprobs.
For regular requests, all positions in the range have logprobs.
"""
is_multi_item_scoring = self._is_multi_item_scoring(req)
if is_multi_item_scoring:
# Multi-item scoring: count delimiter tokens in the relevant portion
relevant_tokens = req.origin_input_ids[
extend_logprob_start_len:extend_input_len
]
return sum(
1
for token_id in relevant_tokens
if token_id == self.server_args.multi_item_scoring_delimiter
)
else:
# Regular request: all tokens in the range
return extend_input_len - extend_logprob_start_len
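A worked example of the two counting modes above, with hypothetical token ids and delimiter id 99:

delimiter = 99
origin_input_ids = [5, 6, 99, 11, 99, 12, 13, 99]   # query, <d>, item1, <d>, item2, <d>
extend_logprob_start_len, extend_input_len = 0, len(origin_input_ids)

relevant = origin_input_ids[extend_logprob_start_len:extend_input_len]
multi_item_count = sum(1 for t in relevant if t == delimiter)      # 3 (one per delimiter)
regular_count = extend_input_len - extend_logprob_start_len        # 8 (one per token)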
def _is_multi_item_scoring(self, req: Req) -> bool:
"""Check if request uses multi-item scoring.
Multi-item scoring applies to prefill-only requests when a delimiter
token is configured. In this mode, only positions containing the
delimiter token receive logprobs.
"""
return req.is_prefill_only and (
    self.server_args.multi_item_scoring_delimiter is not None
)
def add_input_logprob_return_values(
self: Scheduler,
i: int,
......@@ -371,63 +521,14 @@ class SchedulerOutputProcessorMixin:
assert req.input_top_logprobs_val is None
assert req.input_top_logprobs_idx is None
# Compute input_token_logprobs_val
# Always pad the first one with None.
req.input_token_logprobs_val = [None]
req.input_token_logprobs_val.extend(input_token_logprobs)
# The last input logprob is for sampling, so just pop it out.
req.input_token_logprobs_val.pop()
# Process all input logprob types using helper functions
self._process_input_token_logprobs(req, input_token_logprobs)
self._process_input_top_logprobs(req)
# Compute input_token_logprobs_idx
input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
# Clip the padded hash values from image tokens.
# Otherwise, it will lead to detokenization errors.
input_token_logprobs_idx = [
x if x < self.model_config.vocab_size - 1 else 0
for x in input_token_logprobs_idx
]
req.input_token_logprobs_idx = input_token_logprobs_idx
if req.top_logprobs_num > 0:
req.input_top_logprobs_val = [None]
req.input_top_logprobs_idx = [None]
assert len(req.temp_input_token_ids_logprobs_val) == len(
req.temp_input_token_ids_logprobs_idx
)
for val, idx in zip(
req.temp_input_top_logprobs_val,
req.temp_input_top_logprobs_idx,
strict=True,
):
req.input_top_logprobs_val.extend(val)
req.input_top_logprobs_idx.extend(idx)
# Last token is a sample token.
req.input_top_logprobs_val.pop()
req.input_top_logprobs_idx.pop()
req.temp_input_top_logprobs_idx = None
req.temp_input_top_logprobs_val = None
if req.token_ids_logprob is not None:
req.input_token_ids_logprobs_val = [None]
req.input_token_ids_logprobs_idx = [None]
for val, idx in zip(
req.temp_input_token_ids_logprobs_val,
req.temp_input_token_ids_logprobs_idx,
strict=True,
):
req.input_token_ids_logprobs_val.extend(val)
req.input_token_ids_logprobs_idx.extend(idx)
# Last token is a sample token.
req.input_token_ids_logprobs_val.pop()
req.input_token_ids_logprobs_idx.pop()
req.temp_input_token_ids_logprobs_idx = None
req.temp_input_token_ids_logprobs_val = None
self._process_input_token_ids_logprobs(req)
if req.return_logprob:
relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
relevant_tokens_len = self._calculate_relevant_tokens_len(req)
assert len(req.input_token_logprobs_val) == relevant_tokens_len
assert len(req.input_token_logprobs_idx) == relevant_tokens_len
if req.top_logprobs_num > 0:
......
......@@ -182,6 +182,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if speculative_algorithm.is_none()
else server_args.speculative_num_draft_tokens
)
# Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
self.multi_item_delimiter_text = None
if self.model_config.is_multimodal:
import_processors("sglang.srt.multimodal.processors")
......@@ -223,6 +225,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.processor = _processor
self.tokenizer = get_tokenizer_from_processor(self.processor)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self._initialize_multi_item_delimiter_text()
else:
self.mm_processor = self.processor = None
......@@ -235,6 +238,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
self._initialize_multi_item_delimiter_text()
# Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
if (
server_args.enable_dynamic_batch_tokenizer
......@@ -1678,6 +1682,201 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
def _initialize_multi_item_delimiter_text(self):
"""Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
if (
hasattr(self.server_args, "multi_item_scoring_delimiter")
and self.server_args.multi_item_scoring_delimiter is not None
and self.tokenizer is not None
):
try:
self.multi_item_delimiter_text = self.tokenizer.decode(
[self.server_args.multi_item_scoring_delimiter],
skip_special_tokens=False,
)
except Exception as e:
logger.warning(
f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
)
self.multi_item_delimiter_text = None
def _build_multi_item_token_sequence(
self, query: List[int], items: List[List[int]], delimiter_token_id: int
) -> List[int]:
"""
Build a single token sequence for multi-item scoring.
Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
Args:
query: Query token IDs
items: List of item token ID sequences
delimiter_token_id: Token ID to use as delimiter
Returns:
Combined token sequence
"""
combined_sequence = query[:] # Start with query
for item in items:
combined_sequence.append(delimiter_token_id) # Add delimiter
combined_sequence.extend(item) # Add item tokens
# Add final delimiter after the last item for logprob extraction
combined_sequence.append(delimiter_token_id)
return combined_sequence
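For example, with hypothetical ids, the helper above turns query [1, 2, 3], items [[10, 11], [20]], and delimiter 99 into a single sequence; the equivalent plain-Python construction is:

query, items, delim = [1, 2, 3], [[10, 11], [20]], 99
seq = query[:]
for item in items:
    seq += [delim] + item
seq.append(delim)                                  # trailing delimiter for logprob extraction
assert seq == [1, 2, 3, 99, 10, 11, 99, 20, 99]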
def _extract_logprobs_for_tokens(
self, logprobs_data: List, label_token_ids: List[int]
) -> Dict[int, float]:
"""
Extract logprobs for specified token IDs from logprobs data.
Args:
logprobs_data: List of (logprob, token_id, text) tuples
label_token_ids: Token IDs to extract logprobs for
Returns:
Dictionary mapping token_id to logprob
"""
logprobs = {}
if logprobs_data:
for logprob, token_id, _ in logprobs_data:
if token_id in label_token_ids:
logprobs[token_id] = logprob
return logprobs
def _convert_logprobs_to_scores(
self,
logprobs: Dict[int, float],
label_token_ids: List[int],
apply_softmax: bool,
) -> List[float]:
"""
Convert logprobs dictionary to ordered score list.
Args:
logprobs: Dictionary mapping token_id to logprob
label_token_ids: Token IDs in desired order
apply_softmax: Whether to apply softmax normalization
Returns:
List of scores in the same order as label_token_ids
"""
score_list = [
logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
]
if apply_softmax:
score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
else:
# Convert logprobs to probabilities if not using softmax
score_list = [
math.exp(x) if x != float("-inf") else 0.0 for x in score_list
]
return score_list
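A worked example of the conversion above, with made-up logprobs for two hypothetical label token ids:

import math
import torch

label_token_ids = [9454, 2753]
logprobs = {9454: -0.2, 2753: -1.8}                # as returned by _extract_logprobs_for_tokens
raw = [logprobs.get(t, float("-inf")) for t in label_token_ids]

softmaxed = torch.softmax(torch.tensor(raw), dim=0).tolist()               # renormalized over the labels only
plain_probs = [math.exp(x) if x != float("-inf") else 0.0 for x in raw]    # apply_softmax=False path
print(softmaxed, plain_probs)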
def _process_multi_item_scoring_results(
self,
results: Any,
items: List,
label_token_ids: List[int],
apply_softmax: bool,
batch_request=None,
) -> List[List[float]]:
"""
Process results from multi-item scoring request.
Extracts logprobs at delimiter positions from input_token_ids_logprobs.
Args:
results: Results from generate_request
items: List of items being scored
label_token_ids: Token IDs to extract scores for
apply_softmax: Whether to apply softmax normalization
batch_request: The original batch request containing input sequence
Returns:
List of score lists, one for each item
"""
single_result = results[0] if isinstance(results, list) else results
# For multi-item scoring, logprobs are in input_token_ids_logprobs
input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
if not input_logprobs:
raise RuntimeError(
f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
"This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
)
scores = []
num_items = len(items) if isinstance(items, list) else 1
# Check if we have the expected number of logprobs
expected_logprobs_count = num_items + 1
if len(input_logprobs) != expected_logprobs_count:
raise RuntimeError(
f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
f"with {num_items} items, but got {len(input_logprobs)}. "
f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
)
# Skip the first delimiter position: the token before it is the last query token,
# so its logprobs describe the query alone rather than any item. Each remaining
# delimiter position follows one item and is used to score that item.
start_idx = 1 if len(input_logprobs) > 1 else 0
# Process logprobs for each item position (excluding first delimiter)
for item_idx in range(num_items):
logprob_idx = start_idx + item_idx
item_logprobs_data = input_logprobs[logprob_idx]
logprobs = self._extract_logprobs_for_tokens(
item_logprobs_data, label_token_ids
)
score_list = self._convert_logprobs_to_scores(
logprobs, label_token_ids, apply_softmax
)
scores.append(score_list)
return scores
def _process_single_item_scoring_results(
self, results: Any, label_token_ids: List[int], apply_softmax: bool
) -> List[List[float]]:
"""
Process results from single-item scoring request.
Single-item scoring results are stored in output_token_ids_logprobs.
Args:
results: Results from generate_request
label_token_ids: Token IDs to extract scores for
apply_softmax: Whether to apply softmax normalization
Returns:
List of score lists, one for each result
"""
scores = []
for result in results:
# For single-item scoring, logprobs are in output_token_ids_logprobs
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
if not output_logprobs:
raise RuntimeError(
f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
)
# Extract logprobs for the first (and only) position
logprobs = self._extract_logprobs_for_tokens(
output_logprobs[0], label_token_ids
)
score_list = self._convert_logprobs_to_scores(
logprobs, label_token_ids, apply_softmax
)
scores.append(score_list)
return scores
async def score_request(
self,
query: Optional[Union[str, List[int]]] = None,
......@@ -1688,7 +1887,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
request: Optional[Any] = None,
) -> List[List[float]]:
"""
See Engine.score() for more details.
Score the probability of specified token IDs appearing after the given (query + item) pair.
This method supports two scoring approaches:
1. Single-Item scoring (default): Process each query+item pair independently
2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
multiple items into a single sequence using delimiter for efficient processing.
Note: item_first parameter is ignored in multi-item scoring mode since it uses
a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
Multi-item scoring works with both text and pre-tokenized inputs:
- Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
- Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
Args:
query: The query text or pre-tokenized query token IDs
items: The item text(s) or pre-tokenized item token IDs
label_token_ids: List of token IDs to compute probabilities for
apply_softmax: Whether to normalize probabilities using softmax
item_first: If True, prepend items to query. Ignored for multi-item scoring.
request: Optional FastAPI request object
Returns:
List of lists containing probabilities for each item and each label token
"""
if label_token_ids is None:
raise ValueError("label_token_ids must be provided")
......@@ -1701,9 +1922,17 @@ class TokenizerManager(TokenizerCommunicatorMixin):
f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
)
# Check if multi-item scoring is enabled by presence of delimiter
use_multi_item_scoring = (
self.server_args.multi_item_scoring_delimiter is not None
and self.multi_item_delimiter_text is not None
)
batch_request = GenerateReqInput(
token_ids_logprob=label_token_ids,
return_logprob=True,
# Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
logprob_start_len=0 if use_multi_item_scoring else -1,
stream=False,
sampling_params={"max_new_tokens": 0},
)
......@@ -1715,12 +1944,23 @@ class TokenizerManager(TokenizerCommunicatorMixin):
):
# Both query and items are text
items_list = [items] if isinstance(items, str) else items
if item_first:
prompts = [f"{item}{query}" for item in items_list]
else:
prompts = [f"{query}{item}" for item in items_list]
batch_request.text = prompts
if use_multi_item_scoring:
# Multi-item scoring: create single prompt with delimiter text
# Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
# (item_first is ignored for multi-item scoring)
delimiter = self.multi_item_delimiter_text
combined_items = delimiter.join(items_list)
# Add final delimiter after the last item for logprob extraction
single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
batch_request.text = [single_prompt]
else:
# Single-item scoring: create separate prompts for each item
if item_first:
prompts = [f"{item}{query}" for item in items_list]
else:
prompts = [f"{query}{item}" for item in items_list]
batch_request.text = prompts
elif (
isinstance(query, list)
......@@ -1729,61 +1969,38 @@ class TokenizerManager(TokenizerCommunicatorMixin):
and isinstance(items[0], list)
):
# Both query and items are token IDs
if item_first:
input_ids_list = [item + query for item in items]
if use_multi_item_scoring:
# Multi-item scoring: concatenate with delimiter token ID
# Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
delimiter_token_id = self.server_args.multi_item_scoring_delimiter
combined_input_ids = self._build_multi_item_token_sequence(
query, items, delimiter_token_id
)
batch_request.input_ids = [combined_input_ids]
else:
input_ids_list = [query + item for item in items]
batch_request.input_ids = input_ids_list
# Single-item scoring: process each item separately
if item_first:
input_ids_list = [item + query for item in items]
else:
input_ids_list = [query + item for item in items]
batch_request.input_ids = input_ids_list
else:
raise ValueError(
"Invalid combination of query/items types for score_request."
)
results = await self.generate_request(batch_request, request).__anext__()
scores = []
for result in results:
# Get logprobs for each token
logprobs = {}
# For scoring requests, we read from output_token_ids_logprobs since we want
# the logprobs for specific tokens mentioned in the label_token_ids at
# the next position after the last token in the prompt
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
# Check if output_logprobs is properly populated
if (
output_logprobs is None
or not output_logprobs
or len(output_logprobs) == 0
):
raise RuntimeError(
f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
"This indicates token_ids_logprobs were not computed properly for the scoring request."
)
for logprob, token_id, _ in output_logprobs[0]:
if token_id in label_token_ids:
logprobs[token_id] = logprob
# Get scores in order of label_token_ids
score_list = [
logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
]
# Apply softmax to logprobs if needed
if apply_softmax:
score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
else:
# Convert logprobs to probabilities if not using softmax
score_list = [
math.exp(x) if x != float("-inf") else 0.0 for x in score_list
]
scores.append(score_list)
return scores
if use_multi_item_scoring:
# Multi-item scoring: extract scores from input_token_ids_logprobs
return self._process_multi_item_scoring_results(
results, items, label_token_ids, apply_softmax, batch_request
)
else:
# Single-item scoring: process each result separately
return self._process_single_item_scoring_results(
results, label_token_ids, apply_softmax
)
async def watch_load_thread(self):
# Only for dp_controller when dp_size > 1
......
......@@ -266,10 +266,16 @@ class TpModelWorker:
if model_worker_batch.is_prefill_only:
# For prefill-only requests, create dummy token IDs on CPU
batch_result.next_token_ids = torch.zeros_like(
model_worker_batch.input_ids, dtype=torch.long
# The size should match the batch size (number of sequences), not total tokens
batch_result.next_token_ids = torch.zeros(
len(model_worker_batch.seq_lens),
dtype=torch.long,
device=model_worker_batch.input_ids.device,
)
if model_worker_batch.return_logprob:
if (
model_worker_batch.return_logprob
and logits_output.next_token_logits is not None
):
# NOTE: Compute logprobs without full sampling
self.model_runner.compute_logprobs_only(
logits_output, model_worker_batch
......
......@@ -278,6 +278,9 @@ class ForwardBatch:
can_run_dp_cuda_graph: bool = False
global_forward_mode: Optional[ForwardMode] = None
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False
# Speculative decoding
spec_info: Optional[SpecInput] = None
spec_algorithm: SpeculativeAlgorithm = None
......@@ -325,6 +328,7 @@ class ForwardBatch:
is_extend_in_batch=batch.is_extend_in_batch,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
global_forward_mode=batch.global_forward_mode,
is_prefill_only=batch.is_prefill_only,
lora_ids=batch.lora_ids,
sampling_info=batch.sampling_info,
req_to_token_pool=model_runner.req_to_token_pool,
......
......@@ -382,6 +382,12 @@ class ServerArgs:
offload_prefetch_step: int = 1
offload_mode: str = "cpu"
# Scoring configuration
# Delimiter token ID used to combine Query and Items into a single sequence for multi-item scoring.
# Format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
# This enables efficient batch processing of multiple items against a single query.
multi_item_scoring_delimiter: Optional[int] = None
# Optimization/debug options
disable_radix_cache: bool = False
cuda_graph_max_bs: Optional[int] = None
......@@ -2334,7 +2340,13 @@ class ServerArgs:
choices=["float32", "bfloat16"],
help="The data type of the SSM states in mamba cache.",
)
# Args for multi-item-scoring
parser.add_argument(
"--multi-item-scoring-delimiter",
type=int,
default=ServerArgs.multi_item_scoring_delimiter,
help="Delimiter token ID for multi-item scoring. Used to combine Query and Items into a single sequence: Query<delimiter>Item1<delimiter>Item2<delimiter>... This enables efficient batch processing of multiple items against a single query.",
)
# Hierarchical cache
parser.add_argument(
"--enable-hierarchical-cache",
......@@ -3004,6 +3016,17 @@ class ServerArgs:
"lof",
], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."
# Check multi-item scoring
if self.multi_item_scoring_delimiter is not None:
assert self.disable_radix_cache, (
"Multi-item scoring requires radix cache to be disabled. "
"Please set --disable-radix-cache when using --multi-item-scoring-delimiter."
)
assert self.chunked_prefill_size == -1, (
"Multi-item scoring requires chunked prefill to be disabled. "
"Please set --chunked-prefill-size -1 when using --multi-item-scoring-delimiter."
)
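A hedged end-to-end sketch of using the flag (Engine keyword arguments are assumed to mirror ServerArgs; the model path and token ids below are placeholders, not values from this PR):

import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",     # placeholder model
    multi_item_scoring_delimiter=128009,               # hypothetical delimiter token id
    disable_radix_cache=True,                          # required by the check above
    chunked_prefill_size=-1,                           # required by the check above
)
scores = engine.score(
    query="Answer Yes or No for each option: ",
    items=["Option A", "Option B", "Option C"],
    label_token_ids=[9454, 2753],                      # hypothetical "Yes"/"No" ids
    apply_softmax=True,
)
print(scores)                                          # one [p_yes, p_no] pair per item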
def check_lora_server_args(self):
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
......
......@@ -667,6 +667,7 @@ class TboForwardBatchPreparer:
"can_run_dp_cuda_graph",
"dp_padding_mode",
"global_forward_mode",
"is_prefill_only",
"spec_algorithm",
"capture_hidden_mode",
"padded_static_len",
......
......@@ -295,6 +295,296 @@ class TestScoreAPI(CustomTestCase):
)
self.assertFalse(request.stream, "Scoring requests should not stream")
def test_multi_item_scoring_basic(self):
"""Test basic multi-item scoring functionality."""
# Test with a simple query and items
query = "What is the capital of California? Answer Yes or No for each of the following options:"
items = ["Sacramento", "San Jose", "San Francisco"]
label_token_ids = [9454, 2753] # "Yes" and "No" tokens
# Get scores using SGLang
scores = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=True,
)
# Verify we get the expected number of scores
self.assertEqual(len(scores), len(items), "Should get one score list per item")
# Verify each score list has the correct length
for i, score_list in enumerate(scores):
self.assertEqual(
len(score_list),
len(label_token_ids),
f"Item {i} should have {len(label_token_ids)} scores",
)
# Verify scores are probabilities (sum to 1)
self.assertAlmostEqual(
sum(score_list),
1.0,
places=6,
msg=f"Scores for item {i} should sum to 1",
)
# Verify all scores are non-negative
for j, score in enumerate(score_list):
self.assertGreaterEqual(
score, 0, f"Score {j} for item {i} should be non-negative"
)
def test_multi_item_scoring_consistency(self):
"""Test that multi-item scoring gives consistent results."""
query = "Choose the best option:"
items = ["Option A", "Option B", "Option C"]
label_token_ids = [1, 2, 3]
# Run the same test multiple times
scores1 = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=True,
)
scores2 = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=True,
)
# Results should be identical (deterministic)
self.assertEqual(len(scores1), len(scores2), "Should get same number of items")
for i, (s1, s2) in enumerate(zip(scores1, scores2)):
self.assertEqual(
len(s1), len(s2), f"Item {i} should have same number of scores"
)
for j, (score1, score2) in enumerate(zip(s1, s2)):
self.assertAlmostEqual(
score1,
score2,
places=6,
msg=f"Score {j} for item {i} should be identical",
)
def test_multi_item_scoring_different_sizes(self):
"""Test multi-item scoring with different numbers of items."""
query = "Rate each option:"
label_token_ids = [1, 2, 3, 4, 5]
# Test with different numbers of items
test_cases = [
["Single item"],
["Item 1", "Item 2"],
["A", "B", "C", "D"],
["X", "Y", "Z", "W", "V", "U"],
]
for items in test_cases:
with self.subTest(items=items):
scores = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=True,
)
self.assertEqual(
len(scores), len(items), f"Should get {len(items)} score lists"
)
for i, score_list in enumerate(scores):
self.assertEqual(
len(score_list),
len(label_token_ids),
f"Item {i} should have {len(label_token_ids)} scores",
)
self.assertAlmostEqual(sum(score_list), 1.0, places=6)
def test_multi_item_scoring_empty_items(self):
"""Test multi-item scoring with empty items list."""
query = "Test query"
items = []
label_token_ids = [1, 2]
scores = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=True,
)
self.assertEqual(len(scores), 0, "Should return empty list for empty items")
def test_multi_item_scoring_single_item(self):
"""Test multi-item scoring with single item (should work like regular scoring)."""
query = "Complete this sentence: The capital of France is"
items = ["Paris"]
label_token_ids = [1, 2, 3]
scores = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=True,
)
self.assertEqual(len(scores), 1, "Should get one score list")
self.assertEqual(
len(scores[0]), len(label_token_ids), "Should have correct number of scores"
)
self.assertAlmostEqual(sum(scores[0]), 1.0, places=6)
def test_multi_item_scoring_different_queries(self):
"""Test multi-item scoring with different types of queries."""
items = ["Yes", "No"]
label_token_ids = [1, 2]
test_queries = [
"Is this true?",
"Choose the correct answer:",
"What is the best option?",
"Select all that apply:",
"", # Empty query
]
for query in test_queries:
with self.subTest(query=query):
scores = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=True,
)
self.assertEqual(
len(scores),
len(items),
f"Should get {len(items)} score lists for query: '{query}'",
)
for i, score_list in enumerate(scores):
self.assertEqual(len(score_list), len(label_token_ids))
self.assertAlmostEqual(sum(score_list), 1.0, places=6)
def test_multi_item_scoring_different_label_tokens(self):
"""Test multi-item scoring with different label token sets."""
query = "Choose the best option:"
items = ["Option A", "Option B"]
test_label_tokens = [
[1, 2], # Two tokens
[1, 2, 3, 4], # Four tokens
[1], # Single token
[1, 2, 3, 4, 5, 6, 7, 8], # Many tokens
]
for label_token_ids in test_label_tokens:
with self.subTest(label_tokens=label_token_ids):
scores = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=True,
)
self.assertEqual(len(scores), len(items))
for i, score_list in enumerate(scores):
self.assertEqual(
len(score_list),
len(label_token_ids),
f"Item {i} should have {len(label_token_ids)} scores",
)
self.assertAlmostEqual(sum(score_list), 1.0, places=6)
def test_multi_item_scoring_without_softmax(self):
"""Test multi-item scoring without softmax normalization."""
query = "Rate each option:"
items = ["Good", "Bad", "Neutral"]
label_token_ids = [1, 2, 3]
scores = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=False, # No softmax
)
self.assertEqual(len(scores), len(items))
for i, score_list in enumerate(scores):
self.assertEqual(len(score_list), len(label_token_ids))
# Without softmax, scores don't need to sum to 1,
# but they should still be valid probabilities in [0, 1]
for j, score in enumerate(score_list):
self.assertIsInstance(
score, (int, float), f"Score {j} for item {i} should be numeric"
)
def test_multi_item_scoring_large_batch(self):
"""Test multi-item scoring with a large number of items."""
query = "Classify each item:"
items = [f"Item {i}" for i in range(20)] # 20 items
label_token_ids = [1, 2, 3]
scores = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=True,
)
self.assertEqual(len(scores), len(items), "Should handle large batches")
for i, score_list in enumerate(scores):
self.assertEqual(len(score_list), len(label_token_ids))
self.assertAlmostEqual(sum(score_list), 1.0, places=6)
def test_multi_item_scoring_unicode(self):
"""Test multi-item scoring with unicode characters."""
query = "选择最佳选项:"
items = ["选项A", "选项B", "选项C"]
label_token_ids = [1, 2, 3]
scores = self.engine.score(
query=query,
items=items,
label_token_ids=label_token_ids,
apply_softmax=True,
)
self.assertEqual(len(scores), len(items))
for i, score_list in enumerate(scores):
self.assertEqual(len(score_list), len(label_token_ids))
self.assertAlmostEqual(sum(score_list), 1.0, places=6)
def test_multi_item_scoring_error_handling(self):
"""Test multi-item scoring error handling."""
query = "Test query"
items = ["Item 1", "Item 2"]
label_token_ids = [1, 2]
# Test with invalid label_token_ids
with self.assertRaises((ValueError, TypeError)):
self.engine.score(
query=query,
items=items,
label_token_ids="invalid", # Should be list of ints
apply_softmax=True,
)
# Test with None items
with self.assertRaises((ValueError, TypeError)):
self.engine.score(
query=query,
items=None,
label_token_ids=label_token_ids,
apply_softmax=True,
)
if __name__ == "__main__":
unittest.main()