sglang commit a360511d (unverified)
Authored Sep 13, 2025 by Sundara Raman Ramachandran; committed by GitHub on Sep 14, 2025.
[Generative Score API] Scoring(Prefill-only) optimizations. (#9748)
Parent: 94d0f656
Showing 9 changed files with 326 additions and 49 deletions (+326, -49).
python/sglang/srt/layers/logits_processor.py                    (+4, -1)
python/sglang/srt/layers/sampler.py                             (+157, -14)
python/sglang/srt/managers/schedule_batch.py                    (+42, -9)
python/sglang/srt/managers/scheduler.py                         (+10, -2)
python/sglang/srt/managers/scheduler_output_processor_mixin.py  (+46, -10)
python/sglang/srt/managers/tokenizer_manager.py                 (+8, -4)
python/sglang/srt/managers/tp_worker.py                         (+9, -0)
python/sglang/srt/managers/tp_worker_overlap_thread.py          (+18, -9)
python/sglang/srt/model_executor/model_runner.py                (+32, -0)
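Taken together, these changes optimize prefill-only scoring requests: requests that only ask for log probabilities of candidate tokens at the position after the prompt, with max_new_tokens == 0, so no sampling or decoding is needed. As an illustration (not part of this commit), a client-side sketch of such a request might look like the following; the field names mirror the request attributes referenced in the diffs below (max_new_tokens, return_logprob, token_ids_logprob), but the exact /generate schema and token ids are assumptions and should be verified against the server.

```python
# Hypothetical client-side sketch of a prefill-only scoring request.
# Field names follow attributes seen in this commit; the endpoint shape
# and the candidate token ids are illustrative assumptions.
import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {"max_new_tokens": 0},  # prefill-only: no generation
    "return_logprob": True,
    # Candidate token ids to score at the next position (illustrative ids).
    "token_ids_logprob": [5123, 8881],
}

resp = requests.post("http://localhost:30000/generate", json=payload, timeout=30)
print(resp.json()["meta_info"].get("output_token_ids_logprobs"))
```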
python/sglang/srt/layers/logits_processor.py

```diff
@@ -72,7 +72,10 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None
 
     # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
-    next_token_token_ids_logprobs_val: Optional[List] = None
+    # Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
+    next_token_token_ids_logprobs_val: Optional[
+        List[Union[List[float], torch.Tensor]]
+    ] = None
     next_token_token_ids_logprobs_idx: Optional[List] = None
 
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
 ...
```
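Because of this change, consumers of next_token_token_ids_logprobs_val must accept either a plain Python list or a tensor whose device-to-host copy was deferred. A minimal sketch of that normalization (the same pattern the output-processor change below applies); the helper name is illustrative, not part of the commit:

```python
from typing import List, Union

import torch


def to_float_list(val: Union[List[float], torch.Tensor]) -> List[float]:
    # Tensors are only materialized on the host at the last moment
    # (delayed copy); lists pass through unchanged.
    if isinstance(val, torch.Tensor):
        return val.tolist()
    return val


print(to_float_list([-1.5, -0.2]))                # already a list
print(to_float_list(torch.tensor([-1.5, -0.2])))  # deferred tensor -> list
```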
python/sglang/srt/layers/sampler.py

```diff
 import logging
-from typing import List
+from typing import List, Tuple
 
 import torch
 import torch.distributed as dist
 ...
@@ -39,6 +39,25 @@ class Sampler(nn.Module):
         if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group
 
+    def _preprocess_logits(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ) -> torch.Tensor:
+        """Apply custom logit processors and handle NaN detection."""
+        # Apply the custom logit processors if registered in the sampling info
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(logits, sampling_info)
+
+        # Detect and handle NaN values in logits
+        if self.use_nan_detection and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
+            )
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+        return logits
+
     def forward(
         self,
         logits_output: LogitsProcessorOutput,
 ...
@@ -61,17 +80,8 @@ class Sampler(nn.Module):
         """
         logits = logits_output.next_token_logits
 
-        # Apply the custom logit processors if registered in the sampling info.
-        if sampling_info.has_custom_logit_processor:
-            apply_custom_logit_processor(logits, sampling_info)
-
-        if self.use_nan_detection and torch.any(torch.isnan(logits)):
-            logger.warning("Detected errors during sampling! NaN in the logits.")
-            logits = torch.where(
-                torch.isnan(logits), torch.full_like(logits, -1e5), logits
-            )
-            if crash_on_warnings():
-                raise ValueError("Detected errors during sampling! NaN in the logits.")
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits, sampling_info)
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
 ...
```
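The NaN handling that _preprocess_logits centralizes simply replaces any NaN logits with a large negative value so that downstream softmax or argmax effectively ignores those entries. A standalone illustration (not part of the commit):

```python
import torch

logits = torch.tensor([[1.0, float("nan"), -2.0]])
if torch.any(torch.isnan(logits)):
    # NaNs become -1e5, i.e. effectively zero probability after softmax
    logits = torch.where(torch.isnan(logits), torch.full_like(logits, -1e5), logits)

print(logits)                          # tensor([[ 1.0000e+00, -1.0000e+05, -2.0000e+00]])
print(torch.softmax(logits, dim=-1))   # NaN slot gets ~0 probability
```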
```diff
@@ -165,6 +175,54 @@ class Sampler(nn.Module):
         return batch_next_token_ids
 
+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[List[int]],
+    ) -> None:
+        """
+        Compute logprobs for requested token IDs without performing sampling.
+
+        Optimized for prefill-only scoring requests that need token probabilities
+        but don't require next token generation.
+        """
+        if logits_output.next_token_logits is None:
+            logger.warning("No logits available for logprob computation")
+            return
+
+        # Check if any requests actually need logprobs computation
+        needs_token_ids_logprobs = any(
+            token_ids is not None and len(token_ids) > 0
+            for token_ids in token_ids_logprobs
+        )
+        needs_top_logprobs = any(x > 0 for x in top_logprobs_nums)
+
+        if not (needs_token_ids_logprobs or needs_top_logprobs):
+            return
+
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(
+            logits_output.next_token_logits, sampling_info
+        )
+
+        # Compute logprobs
+        logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+        # Handle top logprobs if requested
+        if needs_top_logprobs:
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+        # Handle token_ids logprobs if requested
+        if needs_token_ids_logprobs:
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs_batch_optimized(logprobs, token_ids_logprobs)
+
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
 ...
```
```diff
@@ -234,10 +292,95 @@ def get_top_logprobs(
     )
 
-def get_token_ids_logprobs(
+def get_token_ids_logprobs_batch_optimized(
     logprobs: torch.Tensor,
     token_ids_logprobs: List[List[int]],
-):
+) -> Tuple[List, List]:
+    """
+    Vectorized batch processing for token ID logprobs extraction.
+
+    Uses a single GPU kernel call for the entire batch instead of multiple
+    separate calls, significantly improving performance for large batches.
+
+    Args:
+        logprobs: Log probabilities tensor [batch_size, vocab_size]
+        token_ids_logprobs: List of token IDs to extract logprobs for
+
+    Example:
+        # Input: batch_size=3, vocab_size=5
+        logprobs = torch.tensor([
+            [-1.2, -2.1, -0.8, -3.0, -1.5],  # batch 0
+            [-0.5, -1.8, -2.2, -1.1, -2.7],  # batch 1
+            [-2.0, -0.9, -1.4, -2.8, -1.6],  # batch 2
+        ])
+        token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]
+
+        # Output:
+        # values = [tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6])]
+        # indices = [[1, 3], [2], [0, 2, 4]]
+    """
+    batch_size = len(token_ids_logprobs)
+    device = logprobs.device
+
+    # Step 1: Calculate lengths for each request, treating None as empty list
+    # Example: [[1, 3], [2], [0, 2, 4]] -> token_lengths = tensor([2, 1, 3])
+    token_lengths = torch.tensor(
+        [len(token_ids or []) for token_ids in token_ids_logprobs], device=device
+    )
+    total_tokens = int(token_lengths.sum().item())  # 2 + 1 + 3 = 6
+
+    # Handle edge case where no tokens are requested
+    if total_tokens == 0:
+        return (
+            [logprobs.new_empty(0) for _ in token_ids_logprobs],
+            [[] for _ in token_ids_logprobs],
+        )
+
+    # Step 2: Build flattened indices using torch operations
+    # Example: row_indices = [0, 0, 1, 2, 2, 2] (batch indices repeated by their lengths)
+    row_indices = torch.repeat_interleave(
+        torch.arange(batch_size, device=device), token_lengths
+    )
+    # Example: col_indices = [1, 3, 2, 0, 2, 4] (flattened token IDs from all requests)
+    col_indices = torch.tensor(
+        [
+            token_id
+            for token_ids in token_ids_logprobs
+            for token_id in (token_ids or [])
+        ],
+        device=device,
+        dtype=torch.long,
+    )
+
+    # Step 3: Single vectorized gather operation
+    # Example: logprobs[row_indices, col_indices] -> [-2.1, -3.0, -2.2, -2.0, -1.4, -1.6]
+    gathered_logprobs = logprobs[row_indices, col_indices]
+
+    # Step 4: Split results back per request using torch operations
+    # Example: split tensor of 6 values into chunks of sizes [2, 1, 3]
+    split_logprobs = torch.split_with_sizes(
+        gathered_logprobs, token_lengths.tolist(), dim=0
+    )
+
+    # Step 5: Format output to match expected return structure
+    # Example: Convert split tensors back to list format with proper empty handling
+    #   i=0: [1,3]   -> append split_logprobs[0] and [1,3]
+    #   i=1: [2]     -> append split_logprobs[1] and [2]
+    #   i=2: [0,2,4] -> append split_logprobs[2] and [0,2,4]
+    output_token_ids_logprobs_val = []
+    output_token_ids_logprobs_idx = []
+    for i, token_ids in enumerate(token_ids_logprobs):
+        if token_ids is not None and len(token_ids) > 0:
+            output_token_ids_logprobs_val.append(split_logprobs[i])
+            output_token_ids_logprobs_idx.append(token_ids)
+        else:
+            output_token_ids_logprobs_val.append(logprobs.new_empty(0))
+            output_token_ids_logprobs_idx.append([])
+
+    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
 ...
```
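The example in the docstring above can be run end to end. This standalone sketch (not part of the commit) reproduces the batched gather that get_token_ids_logprobs_batch_optimized performs: repeat_interleave for row indices, a single fancy-indexing gather, then split_with_sizes back per request.

```python
import torch

logprobs = torch.tensor([
    [-1.2, -2.1, -0.8, -3.0, -1.5],  # request 0
    [-0.5, -1.8, -2.2, -1.1, -2.7],  # request 1
    [-2.0, -0.9, -1.4, -2.8, -1.6],  # request 2
])
token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]

lengths = torch.tensor([len(ids) for ids in token_ids_logprobs])
rows = torch.repeat_interleave(torch.arange(len(token_ids_logprobs)), lengths)
cols = torch.tensor([t for ids in token_ids_logprobs for t in ids])

gathered = logprobs[rows, cols]                        # one gather for the whole batch
per_request = torch.split_with_sizes(gathered, lengths.tolist())

print(per_request)
# (tensor([-2.1000, -3.0000]), tensor([-2.2000]), tensor([-2.0000, -1.4000, -1.6000]))
```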
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -561,7 +561,10 @@ class Req:
             # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
-            self.output_token_ids_logprobs_val = []
+            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
+            self.output_token_ids_logprobs_val: List[
+                Union[List[float], torch.Tensor]
+            ] = []
             self.output_token_ids_logprobs_idx = []
         else:
             self.output_token_logprobs_val = self.output_token_logprobs_idx = (
 ...
@@ -619,6 +622,11 @@ class Req:
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
 
+    @property
+    def is_prefill_only(self) -> bool:
+        """Check if this request is prefill-only (no token generation needed)."""
+        return self.sampling_params.max_new_tokens == 0
+
     def extend_image_inputs(self, image_inputs):
         if self.multimodal_inputs is None:
             self.multimodal_inputs = image_inputs
 ...
@@ -950,9 +958,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
-            is_prefill_only=all(
-                req.sampling_params.max_new_tokens == 0 for req in reqs
-            ),
+            is_prefill_only=all(req.is_prefill_only for req in reqs),
             chunked_req=chunked_req,
         )
 ...
```
```diff
@@ -1210,13 +1216,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req.is_retracted = False
 
             # Compute the relative logprob_start_len in an extend batch
+            #
+            # Key variables:
+            # - logprob_start_len: Absolute position in full sequence where logprob computation begins
+            # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
+            # - extend_input_len: Number of tokens that need to be processed in this extend batch
+            #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
+            #   and prefix_indices are the cached/shared prefix tokens)
+            #
             if req.logprob_start_len >= pre_len:
-                req.extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len,
-                    req.extend_input_len,
-                    req.seqlen - 1,
-                )
+                # Optimization for prefill-only requests: When we only need logprobs at
+                # positions beyond the input sequence (to score next-token likelihood), skip all
+                # input logprob computation during prefill since no generation will occur.
+                if self.is_prefill_only and req.logprob_start_len == len(
+                    req.origin_input_ids
+                ):
+                    # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
+                    req.extend_logprob_start_len = req.extend_input_len
+                else:
+                    # Convert absolute logprob_start_len to relative extend_logprob_start_len
+                    #
+                    # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
+                    # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
+                    # This means: "compute logprobs from position 3 onwards in extend batch"
+                    req.extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len,
+                        req.extend_input_len,
+                        req.seqlen - 1,
+                    )
             else:
+                # logprob_start_len is before the current extend batch, so start from beginning
                 req.extend_logprob_start_len = 0
 
         if self.return_logprob:
 ...
@@ -1763,6 +1792,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
             launch_done=self.launch_done,
+            is_prefill_only=self.is_prefill_only,
         )
 
     def copy(self):
 ...
@@ -1905,6 +1935,9 @@ class ModelWorkerBatch:
     # Overlap event
     launch_done: Optional[threading.Event] = None
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
 ...
```
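To make the two branches above concrete, here is the arithmetic with the numbers from the inline comments (5 input tokens, no cached prefix). The prefill-only shortcut leaves zero input positions to compute, while the regular path starts at the requested absolute position. This is a toy illustration, not code from the commit.

```python
# Toy numbers from the comments above; not tied to any real request object.
origin_input_len = 5      # origin_input_ids = [1, 2, 3, 4, 5]
pre_len = 0               # no cached prefix
extend_input_len = 5
seqlen = 5

# Prefill-only shortcut: logprob_start_len == len(origin_input_ids)
extend_logprob_start_len = extend_input_len           # = 5 -> skip all input logprobs
print(extend_input_len - extend_logprob_start_len)    # num_input_logprobs = 0

# Regular path: logprob_start_len = 3
logprob_start_len = 3
extend_logprob_start_len = min(logprob_start_len - pre_len, extend_input_len, seqlen - 1)
print(extend_logprob_start_len)                        # min(3, 5, 4) = 3
```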
python/sglang/srt/managers/scheduler.py

```diff
@@ -1261,11 +1261,19 @@ class Scheduler(
         # Copy more attributes
         if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
             # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
+            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len
+            # beyond input sequence to skip input logprob computation entirely
+            if req.is_prefill_only:
+                req.logprob_start_len = len(req.origin_input_ids)
+            else:
+                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
+                req.logprob_start_len = len(req.origin_input_ids) - 1
         else:
             req.logprob_start_len = recv_req.logprob_start_len
 
-        if req.logprob_start_len >= len(req.origin_input_ids):
+        if not req.is_prefill_only and req.logprob_start_len >= len(
+            req.origin_input_ids
+        ):
             error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
             req.logprob_start_len = len(req.origin_input_ids) - 1
             req.set_finish_with_abort(error_msg)
 ...
```
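The effect of the new default is easiest to see numerically: for a prefill-only request the start position lands one past the old default, which makes num_input_logprobs (computed downstream as extend_input_len - extend_logprob_start_len) collapse to zero. Illustrative values only:

```python
# Illustrative values, not part of the commit.
origin_input_len = 8                      # len(req.origin_input_ids)
extend_input_len = 8

old_default = origin_input_len - 1        # generation: also score the last input position
prefill_only_default = origin_input_len   # scoring: no input positions at all

print(extend_input_len - old_default)           # 1 input logprob to compute
print(extend_input_len - prefill_only_default)  # 0 -> input logprob work skipped entirely
```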
python/sglang/srt/managers/scheduler_output_processor_mixin.py

```diff
@@ -5,6 +5,8 @@ import threading
 import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+import torch
+
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
 ...
@@ -71,6 +73,7 @@ class SchedulerOutputProcessorMixin:
         # Check finish conditions
         logprob_pt = 0
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue
 ...
@@ -99,6 +102,7 @@ class SchedulerOutputProcessorMixin:
                 extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                 extend_input_len = extend_input_len_per_req[i]
                 num_input_logprobs = extend_input_len - extend_logprob_start_len
                 if req.return_logprob:
                     self.add_logprob_return_values(
                         i,
 ...
@@ -441,27 +445,59 @@ class SchedulerOutputProcessorMixin:
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
-        req.output_token_logprobs_idx.append(next_token_ids[i])
+        if output.next_token_logprobs is not None:
+            req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+            req.output_token_logprobs_idx.append(next_token_ids[i])
 
-        self.add_input_logprob_return_values(
-            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
-        )
+        # Only add input logprobs if there are input tokens to process
+        # Note: For prefill-only requests with default logprob_start_len, this will be 0,
+        # meaning we only compute output logprobs (which is the intended behavior)
+        if num_input_logprobs > 0:
+            self.add_input_logprob_return_values(
+                i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+            )
+        else:
+            self._initialize_empty_logprob_containers(req)
 
         if req.top_logprobs_num > 0:
             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
 
-        if req.token_ids_logprob is not None:
-            req.output_token_ids_logprobs_val.append(
-                output.next_token_token_ids_logprobs_val[i]
-            )
+        if (
+            req.token_ids_logprob is not None
+            and output.next_token_token_ids_logprobs_val is not None
+        ):
+            # Convert GPU tensor to list if needed
+            logprobs_val = output.next_token_token_ids_logprobs_val[i]
+            if isinstance(logprobs_val, torch.Tensor):
+                logprobs_val = logprobs_val.tolist()
+
+            req.output_token_ids_logprobs_val.append(logprobs_val)
             req.output_token_ids_logprobs_idx.append(
                 output.next_token_token_ids_logprobs_idx[i]
             )
 
         return num_input_logprobs
 
+    def _initialize_empty_logprob_containers(self, req: Req) -> None:
+        """
+        Initialize logprob fields to empty lists if unset.
+
+        This is needed for prefill-only requests where the normal initialization
+        flow might be bypassed, but downstream code expects these fields to be lists.
+        """
+        if req.input_token_logprobs_val is None:
+            req.input_token_logprobs_val = []
+        if req.input_token_logprobs_idx is None:
+            req.input_token_logprobs_idx = []
+        if req.input_top_logprobs_val is None:
+            req.input_top_logprobs_val = []
+        if req.input_top_logprobs_idx is None:
+            req.input_top_logprobs_idx = []
+        if req.input_token_ids_logprobs_val is None:
+            req.input_token_ids_logprobs_val = []
+        if req.input_token_ids_logprobs_idx is None:
+            req.input_token_ids_logprobs_idx = []
+
     def stream_output(
         self: Scheduler,
         reqs: List[Req],
 ...
```
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -1778,11 +1778,15 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             # the next position after the last token in the prompt
             output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
 
-            # Throw an error here if output_logprobs is None
-            if output_logprobs is None:
+            # Check if output_logprobs is properly populated
+            if (
+                output_logprobs is None
+                or not output_logprobs
+                or len(output_logprobs) == 0
+            ):
                 raise RuntimeError(
-                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
-                    "This usually indicates a problem with the scoring request or the backend output."
+                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
+                    "This indicates token_ids_logprobs were not computed properly for the scoring request."
                 )
 
             for logprob, token_id, _ in output_logprobs[0]:
 ...
```
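On the client-facing side, each entry of output_token_ids_logprobs is iterated as a (logprob, token_id, _) triple, as the loop above shows. A hedged sketch (not from the commit) of turning those triples into normalized scores for the candidate tokens of the first item; the concrete numbers and the None third element are placeholders:

```python
import math

# Shape assumed from the loop above: one list of (logprob, token_id, _) triples
# per request; the third element is ignored here.
output_logprobs = [[(-0.223, 5123, None), (-1.609, 8881, None)]]

scores = {token_id: math.exp(logprob) for logprob, token_id, _ in output_logprobs[0]}
total = sum(scores.values())
normalized = {tid: p / total for tid, p in scores.items()}
print(normalized)  # relative preference between the candidate token ids
```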
python/sglang/srt/managers/tp_worker.py

```diff
@@ -259,6 +259,15 @@ class TpModelWorker:
         if skip_sample:
             next_token_ids = None
+            # For prefill-only requests, we still need to compute logprobs even when sampling is skipped
+            if (
+                model_worker_batch.is_prefill_only
+                and model_worker_batch.return_logprob
+            ):
+                # Compute logprobs without full sampling
+                self.model_runner.compute_logprobs_only(
+                    logits_output, model_worker_batch
+                )
         else:
             next_token_ids = self.model_runner.sample(
                 logits_output, model_worker_batch
 ...
```
python/sglang/srt/managers/tp_worker_overlap_thread.py

```diff
@@ -174,21 +174,28 @@ class TpModelWorkerClient:
         # Run forward
         logits_output, next_token_ids, can_run_cuda_graph = (
             self.worker.forward_batch_generation(
-                model_worker_batch, model_worker_batch.launch_done
+                model_worker_batch,
+                model_worker_batch.launch_done,
+                # Skip sampling for prefill-only requests
+                skip_sample=model_worker_batch.is_prefill_only,
             )
         )
 
         # Update the future token ids map
         bs = len(model_worker_batch.seq_lens)
+        if model_worker_batch.is_prefill_only:
+            # For prefill-only requests, create dummy token IDs on CPU
+            next_token_ids = torch.zeros(bs, dtype=torch.long)
+
         self.future_token_ids_map[
             future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
         ] = next_token_ids
 
         # Copy results to the CPU
         if model_worker_batch.return_logprob:
-            logits_output.next_token_logprobs = (
-                logits_output.next_token_logprobs.to("cpu", non_blocking=True)
-            )
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.to("cpu", non_blocking=True)
+                )
             if logits_output.input_token_logprobs is not None:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.to("cpu", non_blocking=True)
 ...
@@ -197,7 +204,9 @@ class TpModelWorkerClient:
                 logits_output.hidden_states = logits_output.hidden_states.to(
                     "cpu", non_blocking=True
                 )
-        next_token_ids = next_token_ids.to("cpu", non_blocking=True)
+        # Only copy to CPU if not already on CPU
+        if next_token_ids.device.type != "cpu":
+            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
         copy_done.record()
 
         self.output_queue.put(
 ...
@@ -221,10 +230,10 @@ class TpModelWorkerClient:
             logits_output.next_token_logprobs = (
                 logits_output.next_token_logprobs.tolist()
             )
             if logits_output.input_token_logprobs is not None:
                 logits_output.input_token_logprobs = tuple(
                     logits_output.input_token_logprobs.tolist()
                 )
         next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids, can_run_cuda_graph
 ...
```
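The dummy-token path above works because torch.zeros(bs, dtype=torch.long) is created directly on the CPU, so the later device check skips the asynchronous device-to-host copy entirely. A tiny standalone sketch of that interaction (not part of the commit):

```python
import torch

bs = 4
# Prefill-only: placeholder ids, already on CPU
next_token_ids = torch.zeros(bs, dtype=torch.long)

# Mirrors the guarded copy in the diff above: only real GPU outputs get copied.
if next_token_ids.device.type != "cpu":
    next_token_ids = next_token_ids.to("cpu", non_blocking=True)

print(next_token_ids.device)  # cpu, no copy issued
```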
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -2158,6 +2158,38 @@ class ModelRunner:
         )
         return next_token_ids
 
+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> None:
+        """
+        Compute token_ids_logprobs without performing sampling.
+
+        Optimized path for prefill-only requests that need token_ids_logprobs but don't
+        require next token generation. Skips expensive sampling operations
+        while still providing requested probability information.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+        """
+        if not forward_batch.token_ids_logprobs:
+            return
+
+        # Preprocess logits (same as in sample method)
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
+        # Delegate to sampler for logprob-only computation
+        # This populates logits_output with requested token probabilities
+        self.sampler.compute_logprobs_only(
+            logits_output,
+            forward_batch.sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
+        )
+
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
 ...
```