sglang / Commits / a027a9b4
Unverified commit a027a9b4
Authored Aug 13, 2025 by Sundara Raman Ramachandran; committed by GitHub on Aug 14, 2025

[Generative Score API] Optimization to Remove Decode. (#8840)
parent 9e426466

Showing 6 changed files with 843 additions and 20 deletions (+843, −20)
benchmark/score/bench_score.py (+603, −0)
python/sglang/srt/managers/schedule_batch.py (+7, −0)
python/sglang/srt/managers/scheduler.py (+3, −2)
python/sglang/srt/managers/tokenizer_manager.py (+27, −18)
test/srt/test_score_api.py (+82, −0)
test/srt/test_tokenizer_batch_encode.py (+121, −0)
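
For context, a minimal sketch of how this optimized scoring path is driven from the engine side, based on the Engine.score(...) call exercised in test/srt/test_score_api.py below. The model path and the printed output shape are illustrative assumptions, not part of this commit.

import sglang as sgl

# Hedged sketch: score several items against a shared query without a decode step.
# The model path below is a hypothetical choice for illustration only.
engine = sgl.Engine(model_path="Qwen/Qwen2.5-1.5B-Instruct")

# Each query/item pair is prefilled once; the logprobs of label_token_ids at the
# position after the prompt are returned, so no new tokens need to be decoded.
scores = engine.score(
    query="What is the capital of",
    items=[" France", " Germany"],
    label_token_ids=[1, 2, 3],  # token ids whose probabilities are requested
    apply_softmax=True,
)
print(scores)  # one list of label probabilities per item

engine.shutdown()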
benchmark/score/bench_score.py (new file, 0 → 100644; diff collapsed, not shown)
python/sglang/srt/managers/schedule_batch.py

@@ -913,6 +913,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False

+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     # hicache pointer for synchronizing data loading from CPU to GPU
     hicache_consumer_index: int = 0

@@ -953,6 +956,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            is_prefill_only=all(
+                req.sampling_params.max_new_tokens == 0 for req in reqs
+            ),
             chunked_req=chunked_req,
         )

@@ -1796,6 +1802,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             is_extend_in_batch=self.is_extend_in_batch,
+            is_prefill_only=self.is_prefill_only,
         )

     def _evict_tree_cache_if_needed(self, num_tokens: int):
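
One detail worth noting about the is_prefill_only predicate above (an observation, not a claim made in the commit): Python's all() is vacuously True for an empty sequence, so a batch built with no requests would also evaluate as prefill-only. A minimal standalone illustration, using hypothetical stand-in classes rather than the real sglang request types:

# Hedged illustration of the predicate; DummySamplingParams/DummyReq are
# hypothetical stand-ins, not sglang classes.
from dataclasses import dataclass


@dataclass
class DummySamplingParams:
    max_new_tokens: int


@dataclass
class DummyReq:
    sampling_params: DummySamplingParams


def is_prefill_only(reqs):
    # Mirrors the expression used in ScheduleBatch.init_new in this commit.
    return all(req.sampling_params.max_new_tokens == 0 for req in reqs)


print(is_prefill_only([DummyReq(DummySamplingParams(0))]))  # True: scoring-style request
print(is_prefill_only([DummyReq(DummySamplingParams(0)), DummyReq(DummySamplingParams(8))]))  # False
print(is_prefill_only([]))  # True: all() over an empty sequence is vacuously True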
python/sglang/srt/managers/scheduler.py

@@ -1466,8 +1466,9 @@ class Scheduler(
             if self.last_batch.batch_size() < last_bs:
                 self.running_batch.batch_is_full = False

-            # Merge the new batch into the running batch
-            if not self.last_batch.is_empty():
+            # Merge the new batch into the running batch.
+            # For prefill-only batch, we can avoid going through decoding step.
+            if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
                 if self.running_batch.is_empty():
                     self.running_batch = self.last_batch
                 else:
python/sglang/srt/managers/tokenizer_manager.py

@@ -699,7 +699,7 @@ class TokenizerManager:
         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
-            self._validate_token_len(obj[i], input_ids_list[i])
+            self._validate_one_request(obj[i], input_ids_list[i])
             tokenized_objs.append(
                 self._create_tokenized_object(
                     req, req.text, input_ids_list[i], None, None

@@ -1892,6 +1892,13 @@ class TokenizerManager:
                     f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                 )

+        batch_request = GenerateReqInput(
+            token_ids_logprob=label_token_ids,
+            return_logprob=True,
+            stream=False,
+            sampling_params={"max_new_tokens": 0},
+        )
+
         # Handle string or tokenized query/items
         if isinstance(query, str) and (
             isinstance(items, str)

@@ -1903,13 +1910,9 @@ class TokenizerManager:
                 prompts = [f"{item}{query}" for item in items_list]
             else:
                 prompts = [f"{query}{item}" for item in items_list]
-            batch_request = GenerateReqInput(
-                text=prompts,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+            batch_request.text = prompts
         elif (
             isinstance(query, list)
             and isinstance(items, list)

@@ -1921,13 +1924,8 @@ class TokenizerManager:
                 input_ids_list = [item + query for item in items]
             else:
                 input_ids_list = [query + item for item in items]
-            batch_request = GenerateReqInput(
-                input_ids=input_ids_list,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+            batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."

@@ -1939,9 +1937,20 @@ class TokenizerManager:
         for result in results:
             # Get logprobs for each token
             logprobs = {}
-            for logprob, token_id, _ in result["meta_info"].get(
-                "output_token_ids_logprobs", []
-            )[0]:
+
+            # For scoring requests, we read from output_token_ids_logprobs since we want
+            # the logprobs for specific tokens mentioned in the label_token_ids at
+            # the next position after the last token in the prompt
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            # Throw an error here if output_logprobs is None
+            if output_logprobs is None:
+                raise RuntimeError(
+                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
+                    "This usually indicates a problem with the scoring request or the backend output."
+                )
+
+            for logprob, token_id, _ in output_logprobs[0]:
                 if token_id in label_token_ids:
                     logprobs[token_id] = logprob
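
For reference, a brief sketch of how the per-label logprobs gathered above can be turned into the scores returned when apply_softmax=True. This is a plausible post-processing step written for illustration; the actual conversion done elsewhere in TokenizerManager may differ.

import math
from typing import Dict, List


def logprobs_to_scores(
    logprobs: Dict[int, float],
    label_token_ids: List[int],
    apply_softmax: bool = True,
) -> List[float]:
    """Hedged sketch: turn {token_id: logprob} into an ordered score list.

    Assumes every id in label_token_ids is present in logprobs; the real
    code path may treat missing ids differently.
    """
    values = [logprobs[t] for t in label_token_ids]
    if not apply_softmax:
        # Raw probability of each label token at the scored position.
        return [math.exp(v) for v in values]
    # Softmax restricted to the label tokens, computed in a numerically stable way.
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


print(logprobs_to_scores({1: -0.1, 2: -2.3, 3: -4.0}, [1, 2, 3]))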
test/srt/test_score_api.py

@@ -213,6 +213,88 @@ class TestScoreAPI(CustomTestCase):
             1.0, sum(score_list), 6, "Scores should sum to 1"
         )

+    def test_score_request_construction(self):
+        """Test that scoring requests are constructed to avoid decode phase."""
+        from unittest.mock import patch
+
+        # Capture the internal request to verify optimization
+        captured_requests = []
+        original_gen = self.engine.tokenizer_manager.generate_request
+
+        async def mock_generate_request(req, request=None):
+            captured_requests.append(req)
+            async for result in original_gen(req, request):
+                yield result
+
+        # Patch the generate_request method
+        with patch.object(
+            self.engine.tokenizer_manager,
+            "generate_request",
+            side_effect=mock_generate_request,
+        ):
+            # Run a scoring request
+            query = "What is the capital of"
+            items = ["France", "Germany"]
+            label_token_ids = [1, 2, 3]
+
+            scores = self.engine.score(
+                query=query,
+                items=items,
+                label_token_ids=label_token_ids,
+                apply_softmax=True,
+            )
+
+            # Verify we got results
+            self.assertEqual(len(scores), len(items))
+
+            # Verify the captured request has decode-avoiding properties
+            self.assertEqual(len(captured_requests), 1)
+            request = captured_requests[0]
+
+            # Key assertions for decode phase avoidance:
+            # 1. max_new_tokens should be 0 (prevents token generation)
+            # Handle both single and batch request cases
+            if isinstance(request.sampling_params, dict):
+                max_new_tokens = request.sampling_params.get("max_new_tokens", 0)
+            elif isinstance(request.sampling_params, list):
+                # For batch requests, check the first item
+                max_new_tokens = request.sampling_params[0].get("max_new_tokens", 0)
+            else:
+                max_new_tokens = getattr(request.sampling_params, "max_new_tokens", 0)
+
+            self.assertEqual(
+                max_new_tokens, 0, "max_new_tokens should be 0 to avoid decode phase"
+            )
+
+            # 2. Should have token_ids_logprob for scoring
+            # Handle both single and batch request cases
+            if (
+                isinstance(request.token_ids_logprob, list)
+                and len(request.token_ids_logprob) > 0
+                and isinstance(request.token_ids_logprob[0], list)
+            ):
+                # Batch case: token_ids_logprob is a list of lists
+                # Each item in the batch should have the same label_token_ids
+                for item_token_ids in request.token_ids_logprob:
+                    self.assertEqual(
+                        item_token_ids,
+                        label_token_ids,
+                        "Each batch item should have label_token_ids for scoring",
+                    )
+            else:
+                # Single request case
+                self.assertEqual(
+                    request.token_ids_logprob,
+                    label_token_ids,
+                    "Should have label_token_ids for scoring",
+                )
+
+            # 3. Should request logprobs but not stream
+            self.assertTrue(request.return_logprob, "Should request logprobs for scoring")
+            self.assertFalse(request.stream, "Scoring requests should not stream")
+
 if __name__ == "__main__":
     unittest.main()
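
Assuming this test file follows the same invocation convention documented in test/srt/test_tokenizer_batch_encode.py below, the new test can presumably be run in isolation with:

python3 -m unittest test_score_api.TestScoreAPI.test_score_request_construction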
test/srt/test_tokenizer_batch_encode.py (new file, 0 → 100644)
"""
Unit tests for enable_tokenizer_batch_encode feature.
This tests the batch tokenization functionality which allows processing
multiple text inputs in a single batch for improved performance.
Usage:
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_batch_validation_constraints
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncodeUnit.test_batch_tokenize_and_process_logic
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncodeLogic.test_batch_processing_path
"""
import
asyncio
import
unittest
from
typing
import
List
from
unittest.mock
import
AsyncMock
,
Mock
,
call
,
patch
from
sglang.srt.managers.io_struct
import
GenerateReqInput
,
TokenizedGenerateReqInput
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class
TestTokenizerBatchEncode
(
unittest
.
TestCase
):
"""Test cases for tokenizer batch encoding validation and setup."""
def
setUp
(
self
):
"""Set up test fixtures."""
self
.
server_args
=
ServerArgs
(
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
enable_tokenizer_batch_encode
=
True
,
)
self
.
port_args
=
PortArgs
.
init_new
(
self
.
server_args
)
with
patch
(
"zmq.asyncio.Context"
),
patch
(
"sglang.srt.utils.get_zmq_socket"
),
patch
(
"sglang.srt.hf_transformers_utils.get_tokenizer"
)
as
mock_tokenizer
:
mock_tokenizer
.
return_value
=
Mock
(
vocab_size
=
32000
)
self
.
tokenizer_manager
=
TokenizerManager
(
self
.
server_args
,
self
.
port_args
)
def
test_batch_encode_enabled
(
self
):
"""Test that batch encoding is enabled when configured."""
self
.
assertTrue
(
self
.
server_args
.
enable_tokenizer_batch_encode
)
def
test_batch_encode_disabled
(
self
):
"""Test that batch encoding can be disabled."""
server_args_disabled
=
ServerArgs
(
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
enable_tokenizer_batch_encode
=
False
,
)
self
.
assertFalse
(
server_args_disabled
.
enable_tokenizer_batch_encode
)
def
test_multimodal_input_validation
(
self
):
"""Test that multimodal inputs are rejected in batch mode."""
req
=
GenerateReqInput
(
text
=
"test"
,
image_data
=
[
"dummy"
])
req
.
contains_mm_input
=
Mock
(
return_value
=
True
)
batch_obj
=
Mock
()
batch_obj
.
__getitem__
=
lambda
self
,
i
:
req
self
.
tokenizer_manager
.
is_generation
=
True
with
self
.
assertRaises
(
ValueError
)
as
cm
:
self
.
tokenizer_manager
.
_validate_batch_tokenization_constraints
(
1
,
batch_obj
)
self
.
assertIn
(
"multimodal"
,
str
(
cm
.
exception
))
def
test_pretokenized_input_validation
(
self
):
"""Test that pre-tokenized inputs are rejected in batch mode."""
req
=
GenerateReqInput
(
input_ids
=
[
1
,
2
,
3
])
batch_obj
=
Mock
()
batch_obj
.
__getitem__
=
lambda
self
,
i
:
req
with
self
.
assertRaises
(
ValueError
)
as
cm
:
self
.
tokenizer_manager
.
_validate_batch_tokenization_constraints
(
1
,
batch_obj
)
self
.
assertIn
(
"pre-tokenized"
,
str
(
cm
.
exception
))
def
test_input_embeds_validation
(
self
):
"""Test that input embeds are rejected in batch mode."""
req
=
GenerateReqInput
(
input_embeds
=
[
0.1
,
0.2
])
batch_obj
=
Mock
()
batch_obj
.
__getitem__
=
lambda
self
,
i
:
req
with
self
.
assertRaises
(
ValueError
)
as
cm
:
self
.
tokenizer_manager
.
_validate_batch_tokenization_constraints
(
1
,
batch_obj
)
self
.
assertIn
(
"input_embeds"
,
str
(
cm
.
exception
))
def
test_valid_text_only_requests_pass_validation
(
self
):
"""Test that valid text-only requests pass validation."""
# Create valid requests (text-only)
requests
=
[]
for
i
in
range
(
3
):
req
=
GenerateReqInput
(
text
=
f
"test text
{
i
}
"
)
req
.
contains_mm_input
=
Mock
(
return_value
=
False
)
requests
.
append
(
req
)
batch_obj
=
Mock
()
batch_obj
.
__getitem__
=
Mock
(
side_effect
=
lambda
i
:
requests
[
i
])
# Should not raise any exception
try
:
self
.
tokenizer_manager
.
_validate_batch_tokenization_constraints
(
3
,
batch_obj
)
except
Exception
as
e
:
self
.
fail
(
f
"Validation failed for valid text-only requests:
{
e
}
"
)
if
__name__
==
"__main__"
:
unittest
.
main
(
verbosity
=
2
)