Unverified Commit a360511d authored by Sundara Raman Ramachandran, committed by GitHub

[Generative Score API] Scoring(Prefill-only) optimizations. (#9748)

parent 94d0f656
......@@ -72,7 +72,10 @@ class LogitsProcessorOutput:
next_token_top_logprobs_val: Optional[List] = None
next_token_top_logprobs_idx: Optional[List] = None
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
next_token_token_ids_logprobs_val: Optional[List] = None
# Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
next_token_token_ids_logprobs_val: Optional[
List[Union[List[float], torch.Tensor]]
] = None
next_token_token_ids_logprobs_idx: Optional[List] = None
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
......
import logging
from typing import List
from typing import List, Tuple
import torch
import torch.distributed as dist
......@@ -39,6 +39,25 @@ class Sampler(nn.Module):
if is_dp_attention_enabled():
self.tp_sync_group = get_attention_tp_group().device_group
def _preprocess_logits(
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
) -> torch.Tensor:
"""Apply custom logit processors and handle NaN detection."""
# Apply the custom logit processors if registered in the sampling info
if sampling_info.has_custom_logit_processor:
apply_custom_logit_processor(logits, sampling_info)
# Detect and handle NaN values in logits
if self.use_nan_detection and torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits
)
if crash_on_warnings():
raise ValueError("Detected errors during sampling! NaN in the logits.")
return logits
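The NaN masking above can be illustrated with a tiny standalone snippet (a minimal sketch using only stock PyTorch; the tensor values are made up and it is not part of this patch):

import torch

# Illustrative only: mask NaN logits the same way _preprocess_logits does.
logits = torch.tensor([1.0, float("nan"), 2.0])
cleaned = torch.where(torch.isnan(logits), torch.full_like(logits, -1e5), logits)
# cleaned -> tensor([ 1.0000e+00, -1.0000e+05,  2.0000e+00])
# After softmax, the masked position receives effectively zero probability.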
def forward(
self,
logits_output: LogitsProcessorOutput,
......@@ -61,17 +80,8 @@ class Sampler(nn.Module):
"""
logits = logits_output.next_token_logits
# Apply the custom logit processors if registered in the sampling info.
if sampling_info.has_custom_logit_processor:
apply_custom_logit_processor(logits, sampling_info)
if self.use_nan_detection and torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits
)
if crash_on_warnings():
raise ValueError("Detected errors during sampling! NaN in the logits.")
# Preprocess logits (custom processors and NaN handling)
logits = self._preprocess_logits(logits, sampling_info)
if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
......@@ -165,6 +175,54 @@ class Sampler(nn.Module):
return batch_next_token_ids
def compute_logprobs_only(
self,
logits_output: LogitsProcessorOutput,
sampling_info: SamplingBatchInfo,
return_logprob: bool,
top_logprobs_nums: List[int],
token_ids_logprobs: List[List[int]],
) -> None:
"""
Compute logprobs for requested token IDs without performing sampling.
Optimized for prefill-only scoring requests that need token probabilities
but don't require next token generation.
"""
if logits_output.next_token_logits is None:
logger.warning("No logits available for logprob computation")
return
# Check if any requests actually need logprobs computation
needs_token_ids_logprobs = any(
token_ids is not None and len(token_ids) > 0
for token_ids in token_ids_logprobs
)
needs_top_logprobs = any(x > 0 for x in top_logprobs_nums)
if not (needs_token_ids_logprobs or needs_top_logprobs):
return
# Preprocess logits (custom processors and NaN handling)
logits = self._preprocess_logits(logits_output.next_token_logits, sampling_info)
# Compute logprobs
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
# Handle top logprobs if requested
if needs_top_logprobs:
(
logits_output.next_token_top_logprobs_val,
logits_output.next_token_top_logprobs_idx,
) = get_top_logprobs(logprobs, top_logprobs_nums)
# Handle token_ids logprobs if requested
if needs_token_ids_logprobs:
(
logits_output.next_token_token_ids_logprobs_val,
logits_output.next_token_token_ids_logprobs_idx,
) = get_token_ids_logprobs_batch_optimized(logprobs, token_ids_logprobs)
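At its core, this fast path boils down to a log-softmax followed by a direct lookup, with no sampling kernel launched at all. A minimal standalone sketch (illustrative only, with made-up shapes and token ids, not part of this patch):

import torch

# Illustrative only: score token id 3 for the first request without sampling.
logits = torch.randn(2, 8)                    # [batch_size, vocab_size]
logprobs = torch.log_softmax(logits, dim=-1)  # same normalization sampling would use
score = logprobs[0, 3].item()                 # logprob of token id 3 for request 0
# No argmax or multinomial call is needed when only logprobs are returned.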
def top_k_top_p_min_p_sampling_from_probs_torch(
probs: torch.Tensor,
......@@ -234,10 +292,95 @@ def get_top_logprobs(
)
def get_token_ids_logprobs(
def get_token_ids_logprobs_batch_optimized(
logprobs: torch.Tensor,
token_ids_logprobs: List[List[int]],
):
) -> Tuple[List, List]:
"""
Vectorized batch processing for token ID logprobs extraction.
Uses a single GPU kernel call for the entire batch instead of multiple
separate calls, significantly improving performance for large batches.
Args:
logprobs: Log probabilities tensor [batch_size, vocab_size]
token_ids_logprobs: Per-request lists of token IDs to extract logprobs for
Example:
# Input: batch_size=3, vocab_size=5
logprobs = torch.tensor([
[-1.2, -2.1, -0.8, -3.0, -1.5], # batch 0
[-0.5, -1.8, -2.2, -1.1, -2.7], # batch 1
[-2.0, -0.9, -1.4, -2.8, -1.6], # batch 2
])
token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]
# Output:
# values = [tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6])]
# indices = [[1, 3], [2], [0, 2, 4]]
"""
batch_size = len(token_ids_logprobs)
device = logprobs.device
# Step 1: Calculate lengths for each request, treating None as empty list
# Example: [[1, 3], [2], [0, 2, 4]] -> token_lengths = tensor([2, 1, 3])
token_lengths = torch.tensor(
[len(token_ids or []) for token_ids in token_ids_logprobs], device=device
)
total_tokens = int(token_lengths.sum().item()) # 2 + 1 + 3 = 6
# Handle edge case where no tokens are requested
if total_tokens == 0:
return [logprobs.new_empty(0) for _ in token_ids_logprobs], [
[] for _ in token_ids_logprobs
]
# Step 2: Build flattened indices using torch operations
# Example: row_indices = [0, 0, 1, 2, 2, 2] (batch indices repeated by their lengths)
row_indices = torch.repeat_interleave(
torch.arange(batch_size, device=device), token_lengths
)
# Example: col_indices = [1, 3, 2, 0, 2, 4] (flattened token IDs from all requests)
col_indices = torch.tensor(
[
token_id
for token_ids in token_ids_logprobs
for token_id in (token_ids or [])
],
device=device,
dtype=torch.long,
)
# Step 3: Single vectorized gather operation
# Example: logprobs[row_indices, col_indices] -> [-2.1, -3.0, -2.2, -2.0, -1.4, -1.6]
gathered_logprobs = logprobs[row_indices, col_indices]
# Step 4: Split results back per request using torch operations
# Example: split the length-6 tensor into chunks of sizes [2, 1, 3] -> tensors of lengths 2, 1, and 3
split_logprobs = torch.split_with_sizes(
gathered_logprobs, token_lengths.tolist(), dim=0
)
# Step 5: Format output to match expected return structure
# Example: Convert split tensors back to list format with proper empty handling
# i=0: [1,3] -> append split_logprobs[0] and [1,3]
# i=1: [2] -> append split_logprobs[1] and [2]
# i=2: [0,2,4] -> append split_logprobs[2] and [0,2,4]
output_token_ids_logprobs_val = []
output_token_ids_logprobs_idx = []
for i, token_ids in enumerate(token_ids_logprobs):
if token_ids is not None and len(token_ids) > 0:
output_token_ids_logprobs_val.append(split_logprobs[i])
output_token_ids_logprobs_idx.append(token_ids)
else:
output_token_ids_logprobs_val.append(logprobs.new_empty(0))
output_token_ids_logprobs_idx.append([])
return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
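The docstring example can be reproduced end to end with a short standalone snippet that follows the same repeat_interleave + advanced-indexing + split_with_sizes pattern (a minimal sketch, not the function above):

import torch

logprobs = torch.tensor([
    [-1.2, -2.1, -0.8, -3.0, -1.5],  # batch 0
    [-0.5, -1.8, -2.2, -1.1, -2.7],  # batch 1
    [-2.0, -0.9, -1.4, -2.8, -1.6],  # batch 2
])
token_ids = [[1, 3], [2], [0, 2, 4]]

lengths = torch.tensor([len(t) for t in token_ids])
rows = torch.repeat_interleave(torch.arange(len(token_ids)), lengths)  # [0, 0, 1, 2, 2, 2]
cols = torch.tensor([tid for t in token_ids for tid in t])             # [1, 3, 2, 0, 2, 4]
gathered = logprobs[rows, cols]                       # one gather for the whole batch
values = torch.split_with_sizes(gathered, lengths.tolist())
# values -> (tensor([-2.1000, -3.0000]), tensor([-2.2000]), tensor([-2.0000, -1.4000, -1.6000]))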
def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
output_token_ids_logprobs_val = []
output_token_ids_logprobs_idx = []
for i, token_ids in enumerate(token_ids_logprobs):
......
......@@ -561,7 +561,10 @@ class Req:
# shape: (bs, k)
self.output_top_logprobs_val = []
self.output_top_logprobs_idx = []
self.output_token_ids_logprobs_val = []
# Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
self.output_token_ids_logprobs_val: List[
Union[List[float], torch.Tensor]
] = []
self.output_token_ids_logprobs_idx = []
else:
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
......@@ -619,6 +622,11 @@ class Req:
def seqlen(self):
return len(self.origin_input_ids) + len(self.output_ids)
@property
def is_prefill_only(self) -> bool:
"""Check if this request is prefill-only (no token generation needed)."""
return self.sampling_params.max_new_tokens == 0
def extend_image_inputs(self, image_inputs):
if self.multimodal_inputs is None:
self.multimodal_inputs = image_inputs
......@@ -950,9 +958,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
return_hidden_states=any(req.return_hidden_states for req in reqs),
is_prefill_only=all(
req.sampling_params.max_new_tokens == 0 for req in reqs
),
is_prefill_only=all(req.is_prefill_only for req in reqs),
chunked_req=chunked_req,
)
......@@ -1210,13 +1216,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
req.is_retracted = False
# Compute the relative logprob_start_len in an extend batch
#
# Key variables:
# - logprob_start_len: Absolute position in full sequence where logprob computation begins
# - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
# - extend_input_len: Number of tokens that need to be processed in this extend batch
# (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
# and prefix_indices are the cached/shared prefix tokens)
#
if req.logprob_start_len >= pre_len:
req.extend_logprob_start_len = min(
req.logprob_start_len - pre_len,
req.extend_input_len,
req.seqlen - 1,
)
# Optimization for prefill-only requests: When we only need logprobs at
# positions beyond the input sequence (to score next-token likelihood), skip all
# input logprob computation during prefill since no generation will occur.
if self.is_prefill_only and req.logprob_start_len == len(
req.origin_input_ids
):
# Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
req.extend_logprob_start_len = req.extend_input_len
else:
# Convert absolute logprob_start_len to relative extend_logprob_start_len
#
# Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
# Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
# This means: "compute logprobs from position 3 onwards in extend batch"
req.extend_logprob_start_len = min(
req.logprob_start_len - pre_len,
req.extend_input_len,
req.seqlen - 1,
)
else:
# logprob_start_len is before the current extend batch, so start from beginning
req.extend_logprob_start_len = 0
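For concreteness, here is the arithmetic for a hypothetical uncached request (the numbers are assumptions used for illustration, not taken from the patch):

# Hypothetical: origin_input_ids has 5 tokens and nothing is cached, so
# pre_len = 0 and extend_input_len = 5.
#
# Prefill-only scoring with logprob_start_len = 5 (== len(origin_input_ids)):
#   extend_logprob_start_len = extend_input_len = 5  -> no input logprobs are computed
#
# Regular generation with logprob_start_len = 3 takes the min() path above:
#   extend_logprob_start_len = min(3 - 0, 5, 5 - 1) = 3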
if self.return_logprob:
......@@ -1763,6 +1792,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
),
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
launch_done=self.launch_done,
is_prefill_only=self.is_prefill_only,
)
def copy(self):
......@@ -1905,6 +1935,9 @@ class ModelWorkerBatch:
# Overlap event
launch_done: Optional[threading.Event] = None
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False
@triton.jit
def write_req_to_token_pool_triton(
......
......@@ -1261,11 +1261,19 @@ class Scheduler(
# Copy more attributes
if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(req.origin_input_ids) - 1
# For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
# to skip input logprob computation entirely
if req.is_prefill_only:
req.logprob_start_len = len(req.origin_input_ids)
else:
# TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
req.logprob_start_len = len(req.origin_input_ids) - 1
else:
req.logprob_start_len = recv_req.logprob_start_len
if req.logprob_start_len >= len(req.origin_input_ids):
if not req.is_prefill_only and req.logprob_start_len >= len(
req.origin_input_ids
):
error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
req.logprob_start_len = len(req.origin_input_ids) - 1
req.set_finish_with_abort(error_msg)
......
......@@ -5,6 +5,8 @@ import threading
import time
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
......@@ -71,6 +73,7 @@ class SchedulerOutputProcessorMixin:
# Check finish conditions
logprob_pt = 0
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if req.is_retracted:
continue
......@@ -99,6 +102,7 @@ class SchedulerOutputProcessorMixin:
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
num_input_logprobs = extend_input_len - extend_logprob_start_len
if req.return_logprob:
self.add_logprob_return_values(
i,
......@@ -441,27 +445,59 @@ class SchedulerOutputProcessorMixin:
output: LogitsProcessorOutput,
):
"""Attach logprobs to the return values."""
req.output_token_logprobs_val.append(output.next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_ids[i])
self.add_input_logprob_return_values(
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
)
if output.next_token_logprobs is not None:
req.output_token_logprobs_val.append(output.next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_ids[i])
# Only add input logprobs if there are input tokens to process
# Note: For prefill-only requests with default logprob_start_len, this will be 0,
# meaning we only compute output logprobs (which is the intended behavior)
if num_input_logprobs > 0:
self.add_input_logprob_return_values(
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
)
else:
self._initialize_empty_logprob_containers(req)
if req.top_logprobs_num > 0:
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
if req.token_ids_logprob is not None:
req.output_token_ids_logprobs_val.append(
output.next_token_token_ids_logprobs_val[i]
)
if (
req.token_ids_logprob is not None
and output.next_token_token_ids_logprobs_val is not None
):
# Convert GPU tensor to list if needed
logprobs_val = output.next_token_token_ids_logprobs_val[i]
if isinstance(logprobs_val, torch.Tensor):
logprobs_val = logprobs_val.tolist()
req.output_token_ids_logprobs_val.append(logprobs_val)
req.output_token_ids_logprobs_idx.append(
output.next_token_token_ids_logprobs_idx[i]
)
return num_input_logprobs
def _initialize_empty_logprob_containers(self, req: Req) -> None:
"""
Initialize logprob fields to empty lists if unset.
This is needed for prefill-only requests where the normal initialization
flow might be bypassed, but downstream code expects these fields to be lists.
"""
if req.input_token_logprobs_val is None:
req.input_token_logprobs_val = []
if req.input_token_logprobs_idx is None:
req.input_token_logprobs_idx = []
if req.input_top_logprobs_val is None:
req.input_top_logprobs_val = []
if req.input_top_logprobs_idx is None:
req.input_top_logprobs_idx = []
if req.input_token_ids_logprobs_val is None:
req.input_token_ids_logprobs_val = []
if req.input_token_ids_logprobs_idx is None:
req.input_token_ids_logprobs_idx = []
def stream_output(
self: Scheduler,
reqs: List[Req],
......
......@@ -1778,11 +1778,15 @@ class TokenizerManager(TokenizerCommunicatorMixin):
# the next position after the last token in the prompt
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
# Throw an error here if output_logprobs is None
if output_logprobs is None:
# Check if output_logprobs is properly populated
if not output_logprobs:
raise RuntimeError(
f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
"This usually indicates a problem with the scoring request or the backend output."
f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
"This indicates token_ids_logprobs were not computed properly for the scoring request."
)
for logprob, token_id, _ in output_logprobs[0]:
......
......@@ -259,6 +259,15 @@ class TpModelWorker:
if skip_sample:
next_token_ids = None
# For prefill-only requests, we still need to compute logprobs even when sampling is skipped
if (
model_worker_batch.is_prefill_only
and model_worker_batch.return_logprob
):
# Compute logprobs without full sampling
self.model_runner.compute_logprobs_only(
logits_output, model_worker_batch
)
else:
next_token_ids = self.model_runner.sample(
logits_output, model_worker_batch
......
......@@ -174,21 +174,28 @@ class TpModelWorkerClient:
# Run forward
logits_output, next_token_ids, can_run_cuda_graph = (
self.worker.forward_batch_generation(
model_worker_batch, model_worker_batch.launch_done
model_worker_batch,
model_worker_batch.launch_done,
# Skip sampling for prefill-only requests
skip_sample=model_worker_batch.is_prefill_only,
)
)
# Update the future token ids map
bs = len(model_worker_batch.seq_lens)
if model_worker_batch.is_prefill_only:
# For prefill-only requests, create dummy token IDs on CPU
next_token_ids = torch.zeros(bs, dtype=torch.long)
self.future_token_ids_map[
future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
] = next_token_ids
# Copy results to the CPU
if model_worker_batch.return_logprob:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.to("cpu", non_blocking=True)
)
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.to("cpu", non_blocking=True)
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
......@@ -197,7 +204,9 @@ class TpModelWorkerClient:
logits_output.hidden_states = logits_output.hidden_states.to(
"cpu", non_blocking=True
)
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
# Only copy to CPU if not already on CPU
if next_token_ids.device.type != "cpu":
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_done.record()
self.output_queue.put(
......@@ -221,10 +230,10 @@ class TpModelWorkerClient:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
return logits_output, next_token_ids, can_run_cuda_graph
......
......@@ -2158,6 +2158,38 @@ class ModelRunner:
)
return next_token_ids
def compute_logprobs_only(
self,
logits_output: LogitsProcessorOutput,
forward_batch: ForwardBatch,
) -> None:
"""
Compute token_ids_logprobs without performing sampling.
Optimized path for prefill-only requests that need token_ids_logprobs but don't
require next token generation. Skips expensive sampling operations
while still providing requested probability information.
Args:
logits_output: The logits output from the model forward
forward_batch: The forward batch that generates logits_output
"""
if not forward_batch.token_ids_logprobs:
return
# Preprocess logits (same as in sample method)
self._preprocess_logits(logits_output, forward_batch.sampling_info)
# Delegate to sampler for logprob-only computation
# This populates logits_output with requested token probabilities
self.sampler.compute_logprobs_only(
logits_output,
forward_batch.sampling_info,
forward_batch.return_logprob,
forward_batch.top_logprobs_nums,
forward_batch.token_ids_logprobs,
)
@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
......