"src/diffusers/models/controlnets/controlnet_flax.py" did not exist on "df91c44712381c021c0f4855a623b1a1c32f28b7"
Unverified Commit a360511d authored by Sundara Raman Ramachandran, committed by GitHub

[Generative Score API] Scoring(Prefill-only) optimizations. (#9748)

parent 94d0f656
@@ -72,7 +72,10 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None
     # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
-    next_token_token_ids_logprobs_val: Optional[List] = None
+    # Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
+    next_token_token_ids_logprobs_val: Optional[
+        List[Union[List[float], torch.Tensor]]
+    ] = None
     next_token_token_ids_logprobs_idx: Optional[List] = None
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
...
 import logging
-from typing import List
+from typing import List, Tuple
 import torch
 import torch.distributed as dist
@@ -39,6 +39,25 @@ class Sampler(nn.Module):
         if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group

+    def _preprocess_logits(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ) -> torch.Tensor:
+        """Apply custom logit processors and handle NaN detection."""
+        # Apply the custom logit processors if registered in the sampling info
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(logits, sampling_info)
+
+        # Detect and handle NaN values in logits
+        if self.use_nan_detection and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
+            )
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+        return logits
+
     def forward(
         self,
         logits_output: LogitsProcessorOutput,
@@ -61,17 +80,8 @@ class Sampler(nn.Module):
         """
         logits = logits_output.next_token_logits

-        # Apply the custom logit processors if registered in the sampling info.
-        if sampling_info.has_custom_logit_processor:
-            apply_custom_logit_processor(logits, sampling_info)
-
-        if self.use_nan_detection and torch.any(torch.isnan(logits)):
-            logger.warning("Detected errors during sampling! NaN in the logits.")
-            logits = torch.where(
-                torch.isnan(logits), torch.full_like(logits, -1e5), logits
-            )
-            if crash_on_warnings():
-                raise ValueError("Detected errors during sampling! NaN in the logits.")
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits, sampling_info)

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
@@ -165,6 +175,54 @@ class Sampler(nn.Module):

         return batch_next_token_ids

+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[List[int]],
+    ) -> None:
+        """
+        Compute logprobs for requested token IDs without performing sampling.
+
+        Optimized for prefill-only scoring requests that need token probabilities
+        but don't require next token generation.
+        """
+        if logits_output.next_token_logits is None:
+            logger.warning("No logits available for logprob computation")
+            return
+
+        # Check if any requests actually need logprobs computation
+        needs_token_ids_logprobs = any(
+            token_ids is not None and len(token_ids) > 0
+            for token_ids in token_ids_logprobs
+        )
+        needs_top_logprobs = any(x > 0 for x in top_logprobs_nums)
+
+        if not (needs_token_ids_logprobs or needs_top_logprobs):
+            return
+
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits_output.next_token_logits, sampling_info)
+
+        # Compute logprobs
+        logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+        # Handle top logprobs if requested
+        if needs_top_logprobs:
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+        # Handle token_ids logprobs if requested
+        if needs_token_ids_logprobs:
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs_batch_optimized(logprobs, token_ids_logprobs)
+

 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
@@ -234,10 +292,95 @@ def get_top_logprobs(
     )


-def get_token_ids_logprobs(
+def get_token_ids_logprobs_batch_optimized(
     logprobs: torch.Tensor,
     token_ids_logprobs: List[List[int]],
-):
+) -> Tuple[List, List]:
+    """
+    Vectorized batch processing for token ID logprobs extraction.
+
+    Uses a single GPU kernel call for the entire batch instead of multiple
+    separate calls, significantly improving performance for large batches.
+
+    Args:
+        logprobs: Log probabilities tensor [batch_size, vocab_size]
+        token_ids_logprobs: List of token IDs to extract logprobs for
+
+    Example:
+        # Input: batch_size=3, vocab_size=5
+        logprobs = torch.tensor([
+            [-1.2, -2.1, -0.8, -3.0, -1.5],  # batch 0
+            [-0.5, -1.8, -2.2, -1.1, -2.7],  # batch 1
+            [-2.0, -0.9, -1.4, -2.8, -1.6],  # batch 2
+        ])
+        token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]
+
+        # Output:
+        # values = [tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6])]
+        # indices = [[1, 3], [2], [0, 2, 4]]
+    """
+    batch_size = len(token_ids_logprobs)
+    device = logprobs.device
+
+    # Step 1: Calculate lengths for each request, treating None as empty list
+    # Example: [[1, 3], [2], [0, 2, 4]] -> token_lengths = tensor([2, 1, 3])
+    token_lengths = torch.tensor(
+        [len(token_ids or []) for token_ids in token_ids_logprobs], device=device
+    )
+    total_tokens = int(token_lengths.sum().item())  # 2 + 1 + 3 = 6
+
+    # Handle edge case where no tokens are requested
+    if total_tokens == 0:
+        return [logprobs.new_empty(0) for _ in token_ids_logprobs], [
+            [] for _ in token_ids_logprobs
+        ]
+
+    # Step 2: Build flattened indices using torch operations
+    # Example: row_indices = [0, 0, 1, 2, 2, 2] (batch indices repeated by their lengths)
+    row_indices = torch.repeat_interleave(
+        torch.arange(batch_size, device=device), token_lengths
+    )
+    # Example: col_indices = [1, 3, 2, 0, 2, 4] (flattened token IDs from all requests)
+    col_indices = torch.tensor(
+        [
+            token_id
+            for token_ids in token_ids_logprobs
+            for token_id in (token_ids or [])
+        ],
+        device=device,
+        dtype=torch.long,
+    )
+
+    # Step 3: Single vectorized gather operation
+    # Example: logprobs[row_indices, col_indices] -> [-2.1, -3.0, -2.2, -2.0, -1.4, -1.6]
+    gathered_logprobs = logprobs[row_indices, col_indices]
+
+    # Step 4: Split results back per request using torch operations
+    # Example: split tensor of 6 values into chunks of sizes [2, 1, 3]
+    split_logprobs = torch.split_with_sizes(
+        gathered_logprobs, token_lengths.tolist(), dim=0
+    )
+
+    # Step 5: Format output to match expected return structure
+    # i=0: [1,3]   -> append split_logprobs[0] and [1, 3]
+    # i=1: [2]     -> append split_logprobs[1] and [2]
+    # i=2: [0,2,4] -> append split_logprobs[2] and [0, 2, 4]
+    output_token_ids_logprobs_val = []
+    output_token_ids_logprobs_idx = []
+    for i, token_ids in enumerate(token_ids_logprobs):
+        if token_ids is not None and len(token_ids) > 0:
+            output_token_ids_logprobs_val.append(split_logprobs[i])
+            output_token_ids_logprobs_idx.append(token_ids)
+        else:
+            output_token_ids_logprobs_val.append(logprobs.new_empty(0))
+            output_token_ids_logprobs_idx.append([])
+
+    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
...
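Reviewer note: the sketch below is illustrative only and not part of the commit (the batched_gather_logprobs helper name is made up). It demonstrates the flattened-index gather that get_token_ids_logprobs_batch_optimized relies on, and checks it against the straightforward one-gather-per-request loop that the original get_token_ids_logprobs performs.

# Illustrative sketch (not from the commit): batched token-id logprob gather vs. per-request loop.
import torch

def batched_gather_logprobs(logprobs, token_ids_per_req):  # hypothetical helper name
    lengths = torch.tensor([len(ids) for ids in token_ids_per_req])
    rows = torch.repeat_interleave(torch.arange(len(token_ids_per_req)), lengths)
    cols = torch.tensor([t for ids in token_ids_per_req for t in ids], dtype=torch.long)
    gathered = logprobs[rows, cols]  # one advanced-indexing call for the whole batch
    return list(torch.split_with_sizes(gathered, lengths.tolist()))

logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
requested = [[1, 3], [2], [0, 2, 4]]

batched = batched_gather_logprobs(logprobs, requested)
looped = [logprobs[i, ids] for i, ids in enumerate(requested)]  # one gather per request
assert all(torch.allclose(a, b) for a, b in zip(batched, looped))

The results are identical; the batched version simply replaces a Python-level loop of small kernel launches with a single gather plus a cheap split.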
@@ -561,7 +561,10 @@ class Req:
             # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
-            self.output_token_ids_logprobs_val = []
+            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
+            self.output_token_ids_logprobs_val: List[
+                Union[List[float], torch.Tensor]
+            ] = []
             self.output_token_ids_logprobs_idx = []
         else:
             self.output_token_logprobs_val = self.output_token_logprobs_idx = (
@@ -619,6 +622,11 @@ class Req:
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)

+    @property
+    def is_prefill_only(self) -> bool:
+        """Check if this request is prefill-only (no token generation needed)."""
+        return self.sampling_params.max_new_tokens == 0
+
     def extend_image_inputs(self, image_inputs):
         if self.multimodal_inputs is None:
             self.multimodal_inputs = image_inputs
@@ -950,9 +958,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
-            is_prefill_only=all(
-                req.sampling_params.max_new_tokens == 0 for req in reqs
-            ),
+            is_prefill_only=all(req.is_prefill_only for req in reqs),
             chunked_req=chunked_req,
         )
@@ -1210,13 +1216,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req.is_retracted = False

             # Compute the relative logprob_start_len in an extend batch
+            #
+            # Key variables:
+            # - logprob_start_len: Absolute position in full sequence where logprob computation begins
+            # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
+            # - extend_input_len: Number of tokens that need to be processed in this extend batch
+            #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
+            #    and prefix_indices are the cached/shared prefix tokens)
+            #
             if req.logprob_start_len >= pre_len:
-                req.extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len,
-                    req.extend_input_len,
-                    req.seqlen - 1,
-                )
+                # Optimization for prefill-only requests: When we only need logprobs at
+                # positions beyond the input sequence (to score next-token likelihood), skip all
+                # input logprob computation during prefill since no generation will occur.
+                if self.is_prefill_only and req.logprob_start_len == len(
+                    req.origin_input_ids
+                ):
+                    # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
+                    req.extend_logprob_start_len = req.extend_input_len
+                else:
+                    # Convert absolute logprob_start_len to relative extend_logprob_start_len
+                    #
+                    # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
+                    # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
+                    # This means: "compute logprobs from position 3 onwards in extend batch"
+                    req.extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len,
+                        req.extend_input_len,
+                        req.seqlen - 1,
+                    )
             else:
+                # logprob_start_len is before the current extend batch, so start from beginning
                 req.extend_logprob_start_len = 0

         if self.return_logprob:
@@ -1763,6 +1792,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
             launch_done=self.launch_done,
+            is_prefill_only=self.is_prefill_only,
         )

     def copy(self):
@@ -1905,6 +1935,9 @@ class ModelWorkerBatch:
     # Overlap event
     launch_done: Optional[threading.Event] = None

+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+

 @triton.jit
 def write_req_to_token_pool_triton(
...
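Reviewer note: a small, self-contained sketch of the branch added above (illustrative only; compute_extend_logprob_start_len is a made-up helper, not part of the commit). It shows how a prefill-only scoring request skips every input-logprob position, while a regular request falls back to the existing min(...) clamp.

# Illustrative sketch of the extend_logprob_start_len logic (helper name is made up).
def compute_extend_logprob_start_len(
    logprob_start_len, pre_len, extend_input_len, seqlen, num_input_ids, is_prefill_only
):
    if logprob_start_len >= pre_len:
        if is_prefill_only and logprob_start_len == num_input_ids:
            return extend_input_len  # skip all input-logprob positions
        return min(logprob_start_len - pre_len, extend_input_len, seqlen - 1)
    return 0  # logprob_start_len precedes this extend batch

# Prefill-only scoring request: 5 prompt tokens, no cached prefix, max_new_tokens=0.
assert compute_extend_logprob_start_len(5, 0, 5, 5, 5, True) == 5   # nothing to compute
# Regular generation request asking for logprobs from position 3 onwards.
assert compute_extend_logprob_start_len(3, 0, 5, 5, 5, False) == 3  # min(3, 5, 4)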
@@ -1261,11 +1261,19 @@ class Scheduler(
         # Copy more attributes
         if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
             # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
+            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
+            # to skip input logprob computation entirely
+            if req.is_prefill_only:
+                req.logprob_start_len = len(req.origin_input_ids)
+            else:
+                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
+                req.logprob_start_len = len(req.origin_input_ids) - 1
         else:
             req.logprob_start_len = recv_req.logprob_start_len

-        if req.logprob_start_len >= len(req.origin_input_ids):
+        if not req.is_prefill_only and req.logprob_start_len >= len(
+            req.origin_input_ids
+        ):
             error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
             req.logprob_start_len = len(req.origin_input_ids) - 1
             req.set_finish_with_abort(error_msg)
...
@@ -5,6 +5,8 @@ import threading
 import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

+import torch
+
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
@@ -71,6 +73,7 @@ class SchedulerOutputProcessorMixin:
         # Check finish conditions
         logprob_pt = 0
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+
             if req.is_retracted:
                 continue
@@ -99,6 +102,7 @@ class SchedulerOutputProcessorMixin:
                 extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                 extend_input_len = extend_input_len_per_req[i]
                 num_input_logprobs = extend_input_len - extend_logprob_start_len
+
                 if req.return_logprob:
                     self.add_logprob_return_values(
                         i,
@@ -441,27 +445,59 @@ class SchedulerOutputProcessorMixin:
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
-        req.output_token_logprobs_idx.append(next_token_ids[i])
+        if output.next_token_logprobs is not None:
+            req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+            req.output_token_logprobs_idx.append(next_token_ids[i])

-        self.add_input_logprob_return_values(
-            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
-        )
+        # Only add input logprobs if there are input tokens to process
+        # Note: For prefill-only requests with default logprob_start_len, this will be 0,
+        # meaning we only compute output logprobs (which is the intended behavior)
+        if num_input_logprobs > 0:
+            self.add_input_logprob_return_values(
+                i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+            )
+        else:
+            self._initialize_empty_logprob_containers(req)

         if req.top_logprobs_num > 0:
             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

-        if req.token_ids_logprob is not None:
-            req.output_token_ids_logprobs_val.append(
-                output.next_token_token_ids_logprobs_val[i]
-            )
+        if (
+            req.token_ids_logprob is not None
+            and output.next_token_token_ids_logprobs_val is not None
+        ):
+            # Convert GPU tensor to list if needed
+            logprobs_val = output.next_token_token_ids_logprobs_val[i]
+            if isinstance(logprobs_val, torch.Tensor):
+                logprobs_val = logprobs_val.tolist()
+            req.output_token_ids_logprobs_val.append(logprobs_val)
             req.output_token_ids_logprobs_idx.append(
                 output.next_token_token_ids_logprobs_idx[i]
            )

         return num_input_logprobs

+    def _initialize_empty_logprob_containers(self, req: Req) -> None:
+        """
+        Initialize logprob fields to empty lists if unset.
+
+        This is needed for prefill-only requests where the normal initialization
+        flow might be bypassed, but downstream code expects these fields to be lists.
+        """
+        if req.input_token_logprobs_val is None:
+            req.input_token_logprobs_val = []
+        if req.input_token_logprobs_idx is None:
+            req.input_token_logprobs_idx = []
+        if req.input_top_logprobs_val is None:
+            req.input_top_logprobs_val = []
+        if req.input_top_logprobs_idx is None:
+            req.input_top_logprobs_idx = []
+        if req.input_token_ids_logprobs_val is None:
+            req.input_token_ids_logprobs_val = []
+        if req.input_token_ids_logprobs_idx is None:
+            req.input_token_ids_logprobs_idx = []
+
     def stream_output(
         self: Scheduler,
         reqs: List[Req],
...
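Reviewer note: the Union[List[float], torch.Tensor] typing introduced earlier exists because, for prefill-only batches, the gathered logprobs stay on the GPU and are only converted to Python lists when the result is consumed, avoiding a blocking device-to-host copy per request. A minimal sketch of that delayed-copy pattern (illustrative; the helper names are made up):

# Illustrative sketch of the delayed-copy pattern (not part of the commit; names are made up).
from typing import List, Union
import torch

def produce_logprobs(logprobs: torch.Tensor, ids: List[int]) -> torch.Tensor:
    # Stay on the producing device; no .tolist() / .cpu() here.
    return logprobs[0, ids]

def consume_logprobs(val: Union[List[float], torch.Tensor]) -> List[float]:
    # Convert only at the consumption boundary, mirroring add_logprob_return_values above.
    return val.tolist() if isinstance(val, torch.Tensor) else val

logprobs = torch.log_softmax(torch.randn(1, 8), dim=-1)
gpu_or_list = produce_logprobs(logprobs, [2, 5])
print(consume_logprobs(gpu_or_list))  # plain Python floats, ready for serialization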
@@ -1778,11 +1778,15 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         # the next position after the last token in the prompt
         output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])

-        # Throw an error here if output_logprobs is None
-        if output_logprobs is None:
+        # Check if output_logprobs is properly populated
+        if (
+            output_logprobs is None
+            or not output_logprobs
+            or len(output_logprobs) == 0
+        ):
             raise RuntimeError(
-                f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
-                "This usually indicates a problem with the scoring request or the backend output."
+                f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for the scoring request."
             )

         for logprob, token_id, _ in output_logprobs[0]:
...
@@ -259,6 +259,15 @@ class TpModelWorker:
             if skip_sample:
                 next_token_ids = None

+                # For prefill-only requests, we still need to compute logprobs even when sampling is skipped
+                if (
+                    model_worker_batch.is_prefill_only
+                    and model_worker_batch.return_logprob
+                ):
+                    # Compute logprobs without full sampling
+                    self.model_runner.compute_logprobs_only(
+                        logits_output, model_worker_batch
+                    )
             else:
                 next_token_ids = self.model_runner.sample(
                     logits_output, model_worker_batch
...
@@ -174,21 +174,28 @@ class TpModelWorkerClient:
         # Run forward
         logits_output, next_token_ids, can_run_cuda_graph = (
             self.worker.forward_batch_generation(
-                model_worker_batch, model_worker_batch.launch_done
+                model_worker_batch,
+                model_worker_batch.launch_done,
+                # Skip sampling for prefill-only requests
+                skip_sample=model_worker_batch.is_prefill_only,
             )
         )

         # Update the future token ids map
         bs = len(model_worker_batch.seq_lens)
+
+        if model_worker_batch.is_prefill_only:
+            # For prefill-only requests, create dummy token IDs on CPU
+            next_token_ids = torch.zeros(bs, dtype=torch.long)
+
         self.future_token_ids_map[
             future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
         ] = next_token_ids

         # Copy results to the CPU
         if model_worker_batch.return_logprob:
-            logits_output.next_token_logprobs = (
-                logits_output.next_token_logprobs.to("cpu", non_blocking=True)
-            )
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.to("cpu", non_blocking=True)
+                )
             if logits_output.input_token_logprobs is not None:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.to("cpu", non_blocking=True)
@@ -197,7 +204,9 @@ class TpModelWorkerClient:
                 logits_output.hidden_states = logits_output.hidden_states.to(
                     "cpu", non_blocking=True
                 )
-        next_token_ids = next_token_ids.to("cpu", non_blocking=True)
+        # Only copy to CPU if not already on CPU
+        if next_token_ids.device.type != "cpu":
+            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
         copy_done.record()

         self.output_queue.put(
@@ -221,10 +230,10 @@ class TpModelWorkerClient:
                 logits_output.next_token_logprobs = (
                     logits_output.next_token_logprobs.tolist()
                 )
                 if logits_output.input_token_logprobs is not None:
                     logits_output.input_token_logprobs = tuple(
                         logits_output.input_token_logprobs.tolist()
                     )
             next_token_ids = next_token_ids.tolist()

         return logits_output, next_token_ids, can_run_cuda_graph
...
@@ -2158,6 +2158,38 @@ class ModelRunner:
         )
         return next_token_ids

+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> None:
+        """
+        Compute token_ids_logprobs without performing sampling.
+
+        Optimized path for prefill-only requests that need token_ids_logprobs but don't
+        require next token generation. Skips expensive sampling operations
+        while still providing requested probability information.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+        """
+        if not forward_batch.token_ids_logprobs:
+            return
+
+        # Preprocess logits (same as in sample method)
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
+        # Delegate to sampler for logprob-only computation
+        # This populates logits_output with requested token probabilities
+        self.sampler.compute_logprobs_only(
+            logits_output,
+            forward_batch.sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
+        )
+
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
...
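Reviewer note: to make the "skips expensive sampling operations" claim concrete, here is a rough, self-contained comparison (illustrative only, not from the commit; both paths are simplified stand-ins). The scoring path needs only a log-softmax plus a gather at the requested token IDs, while a generation step additionally runs softmax, top-k/top-p filtering, and a multinomial draw.

# Rough, illustrative comparison of the two paths (not part of the commit).
import time
import torch

def score_path(logits, requested_ids):
    # Prefill-only scoring: log-softmax + gather, no sampling.
    logprobs = torch.log_softmax(logits, dim=-1)
    return logprobs.gather(1, requested_ids)

def generate_path(logits, top_k=50, top_p=0.9):
    # Simplified generation step: softmax, top-k/top-p filtering, multinomial draw.
    probs = torch.softmax(logits, dim=-1)
    topk_vals, topk_idx = probs.topk(top_k, dim=-1)
    cumsum = topk_vals.cumsum(dim=-1)
    topk_vals = torch.where(cumsum - topk_vals > top_p, torch.zeros_like(topk_vals), topk_vals)
    sampled = torch.multinomial(topk_vals, 1)
    return topk_idx.gather(1, sampled)

logits = torch.randn(64, 32000)
requested = torch.randint(0, 32000, (64, 4))
for fn, args in [(score_path, (logits, requested)), (generate_path, (logits,))]:
    start = time.perf_counter()
    for _ in range(20):
        fn(*args)
    print(fn.__name__, f"{(time.perf_counter() - start) / 20 * 1e3:.2f} ms")

The absolute numbers depend on hardware; the point is that prefill-only scoring avoids the sampling kernels entirely, which is exactly the branch TpModelWorker takes when is_prefill_only and return_logprob are both set.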