sglang · Commits · a027a9b4

Commit a027a9b4 (unverified), authored Aug 13, 2025 by Sundara Raman Ramachandran, committed via GitHub on Aug 14, 2025. Parent: 9e426466

[Generative Score API] Optimization to Remove Decode. (#8840)

Showing 6 changed files with 843 additions and 20 deletions (+843, −20)
benchmark/score/bench_score.py                    +603  −0
python/sglang/srt/managers/schedule_batch.py        +7  −0
python/sglang/srt/managers/scheduler.py             +3  −2
python/sglang/srt/managers/tokenizer_manager.py    +27  −18
test/srt/test_score_api.py                         +82  −0
test/srt/test_tokenizer_batch_encode.py           +121  −0
benchmark/score/bench_score.py (new file, mode 0 → 100644)
"""
SGLang Scoring Benchmark Script
This script benchmarks SGLang's scoring API performance using HTTP requests.
Current Features:
- HTTP-only implementation (open source compatible)
- Uses /v1/score API endpoint directly
- Single item scoring with batching support
- Configurable RPS, duration, and batch sizes
- Progress tracking and detailed metrics
- Poisson and constant request distributions
Usage:
- Update configuration variables at the top of the file
- Ensure SGLang server is running on the configured HTTP_URL
- Run: python bench_score.py
- Each request will contain ITEM_COUNT_VALUES items for batch scoring
"""
import
asyncio
import
concurrent.futures
# For parallel prompt generation
import
json
import
os
import
random
from
statistics
import
mean
import
aiohttp
import
numpy
as
np
from
tqdm
import
tqdm
from
transformers
import
AutoTokenizer
###############################################################################
# CONFIG
###############################################################################
# Server Configuration
SERVER_TYPE
=
"HTTP"
# Fixed to HTTP for open source
# HTTP Configuration
HTTP_URL
=
"http://localhost:30000/v1/score"
# Use score API directly
# Score API Config
# ITEM_COUNT_VALUES determines number of items per score request (batch size)
SCORE_QUERY_TOKENS
=
120
SCORE_ITEM_TOKENS
=
180
SCORE_MODEL_PATH
=
"Qwen/Qwen3-0.6B"
SCORE_LABEL_TOKEN_IDS
=
[
9454
,
2753
]
# Yes/No token IDs
# Array of RPS values to test
RPS_VALUES
=
[
70
]
# Array of duration values to test
DURATION_SECS_VALUES
=
[
60
]
# Duration values in seconds
# Array of item count values to test
ITEM_COUNT_VALUES
=
[
10
]
# Number of items per request
# Number of unique requests to generate (will be reused)
NUM_UNIQUE_REQUESTS
=
100
DISTRIBUTION
=
"POISSON"
# Options: "CONSTANT", "POISSON"
# Profiling Configuration
PROFILE
=
False
# Enable profiling with START_PROFILE/STOP_PROFILE prompts
# Directory for profiler output
SGLANG_TORCH_PROFILER_DIR
=
"/shared/user/sglang-oss-trace/remove-decode"
if
PROFILE
:
os
.
environ
[
"SGLANG_TORCH_PROFILER_DIR"
]
=
SGLANG_TORCH_PROFILER_DIR
# Special token to replicate for precise token counting
SPECIAL_REPLICATED_TOKEN
=
"<|im_start|>"
###############################################################################
# REQUEST GENERATION (in parallel)
###############################################################################
def
prepare_all_requests_parallel
(
num_requests
,
item_count
):
"""
Generates unique requests in parallel, then reuses them to create the
full request list. Returns a list of str prompts for HTTP.
"""
# Load tokenizer once here to verify special token and get precise counts
print
(
"Loading tokenizer..."
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
SCORE_MODEL_PATH
)
# Verify that our special token produces exactly 1 token
special_token_count
=
len
(
tokenizer
.
encode
(
SPECIAL_REPLICATED_TOKEN
,
add_special_tokens
=
False
)
)
print
(
f
"Special token '
{
SPECIAL_REPLICATED_TOKEN
}
' produces "
f
"
{
special_token_count
}
token(s)"
)
def
generate_text_with_token_count
(
num_toks
):
"""Generate text with precise token count using replicated token."""
if
special_token_count
==
1
:
# Simple case: token maps to exactly 1 token
return
SPECIAL_REPLICATED_TOKEN
*
num_toks
else
:
print
(
f
"Special token '
{
SPECIAL_REPLICATED_TOKEN
}
' produces more than 1 token!!!"
)
# Handle case where special token produces multiple tokens
# Repeat the token enough times to get at least num_toks tokens
repetitions
=
(
num_toks
+
special_token_count
-
1
)
//
special_token_count
text
=
SPECIAL_REPLICATED_TOKEN
*
repetitions
# Verify we got the expected token count (approximately)
actual_tokens
=
len
(
tokenizer
.
encode
(
text
,
add_special_tokens
=
False
))
if
actual_tokens
<
num_toks
:
print
(
f
"Warning: Generated
{
actual_tokens
}
tokens, "
f
"expected
{
num_toks
}
"
)
return
text
def
build_request
(
index
):
"""Build a single request using the shared tokenizer."""
try
:
# Generate query and items for score API
query
=
generate_text_with_token_count
(
SCORE_QUERY_TOKENS
)
items
=
[
generate_text_with_token_count
(
SCORE_ITEM_TOKENS
)
for
_
in
range
(
item_count
)
]
# Return as dict for score API format
score_data
=
{
"query"
:
query
,
"items"
:
items
,
"label_token_ids"
:
SCORE_LABEL_TOKEN_IDS
,
"model"
:
SCORE_MODEL_PATH
,
}
return
(
index
,
score_data
)
except
Exception
as
e
:
print
(
f
"Error building request
{
index
}
:
{
e
}
"
)
return
(
index
,
None
)
# Generate only the unique requests
unique_requests
=
[
None
]
*
NUM_UNIQUE_REQUESTS
# Use ThreadPoolExecutor instead of ProcessPoolExecutor to avoid
# tokenizer loading issues across processes
max_workers
=
min
(
8
,
os
.
cpu_count
()
or
1
)
# Limit to 8 threads max
with
concurrent
.
futures
.
ThreadPoolExecutor
(
max_workers
=
max_workers
)
as
executor
:
futures
=
[]
for
i
in
tqdm
(
range
(
NUM_UNIQUE_REQUESTS
),
desc
=
"Submitting prompt generation tasks"
):
future
=
executor
.
submit
(
build_request
,
i
)
futures
.
append
(
future
)
# Collect results as they complete
for
f
in
tqdm
(
concurrent
.
futures
.
as_completed
(
futures
),
desc
=
"Building unique requests"
,
total
=
NUM_UNIQUE_REQUESTS
,
):
try
:
index
,
req_data
=
f
.
result
()
if
req_data
is
not
None
:
unique_requests
[
index
]
=
req_data
else
:
print
(
f
"Failed to build request
{
index
}
"
)
except
Exception
as
e
:
print
(
f
"Error processing request result:
{
e
}
"
)
# Check if we have any valid requests
valid_requests
=
[
req
for
req
in
unique_requests
if
req
is
not
None
]
if
not
valid_requests
:
raise
RuntimeError
(
"Failed to generate any valid requests"
)
print
(
f
"Successfully generated
{
len
(
valid_requests
)
}
out of "
f
"
{
NUM_UNIQUE_REQUESTS
}
unique requests"
)
# Create the full request list by cycling through unique requests
print
(
f
"Reusing
{
len
(
valid_requests
)
}
unique requests to create "
f
"
{
num_requests
}
total requests..."
)
all_requests
=
[]
for
i
in
tqdm
(
range
(
num_requests
),
desc
=
"Reusing requests"
):
unique_index
=
i
%
len
(
valid_requests
)
all_requests
.
append
(
valid_requests
[
unique_index
])
print
(
"All prompts/requests prepared.
\n
"
)
return
all_requests
###############################################################################
# PROFILING HELPERS
###############################################################################
async
def
send_profile_request
(
profile_text
,
item_count
,
session
=
None
):
"""Send a profile request and wait for completion."""
try
:
if
session
:
print
(
f
"Sending
{
profile_text
}
request via HTTP..."
)
# Determine the correct endpoint
base_url
=
HTTP_URL
.
rsplit
(
"/"
,
2
)[
0
]
# Remove /v1/score
if
profile_text
==
"START_PROFILE"
:
endpoint_url
=
f
"
{
base_url
}
/start_profile"
elif
profile_text
==
"STOP_PROFILE"
:
endpoint_url
=
f
"
{
base_url
}
/stop_profile"
else
:
print
(
f
"Unknown profile request:
{
profile_text
}
"
)
return
headers
=
{
"Content-Type"
:
"application/json"
}
async
with
session
.
post
(
endpoint_url
,
headers
=
headers
)
as
resp
:
resp_text
=
await
resp
.
text
()
if
resp
.
status
==
200
:
print
(
f
"
{
profile_text
}
request completed"
)
else
:
print
(
f
"
{
profile_text
}
request failed with status "
f
"
{
resp
.
status
}
:
{
resp_text
}
"
)
else
:
print
(
f
"Cannot send
{
profile_text
}
request - missing session"
)
except
Exception
as
e
:
print
(
f
"Error sending
{
profile_text
}
request:
{
e
}
"
)
###############################################################################
# HTTP CALLS
###############################################################################
def
build_http_request_json
(
score_data
):
"""Build HTTP request JSON for /v1/score endpoint.
Score API format:
{
"query": "Generated query text with SCORE_QUERY_TOKENS tokens",
"items": ["item1", "item2", ...], # Items to score with SCORE_ITEM_TOKENS each
"label_token_ids": [token_id1, token_id2], # Target token IDs
"model": "/path/to/model"
}
Args:
score_data: A dict containing query, items, label_token_ids, and model
"""
# score_data is already in the correct format from build_request
return
json
.
dumps
(
score_data
)
async
def
make_http_call
(
session
,
score_data
,
request_id
,
results_queue
):
"""HTTP call to /v1/score endpoint."""
try
:
start_time
=
asyncio
.
get_event_loop
().
time
()
request_json
=
build_http_request_json
(
score_data
)
headers
=
{
"Content-Type"
:
"application/json"
}
async
with
session
.
post
(
HTTP_URL
,
data
=
request_json
,
headers
=
headers
)
as
resp
:
resp_text
=
await
resp
.
text
()
if
resp
.
status
!=
200
:
print
(
f
"[HTTP] Request
{
request_id
}
failed with status "
f
"
{
resp
.
status
}
:
{
resp_text
}
"
)
completion_time
=
asyncio
.
get_event_loop
().
time
()
await
results_queue
.
put
((
request_id
,
0
,
False
,
completion_time
))
return
# Parse score API response
try
:
response_data
=
json
.
loads
(
resp_text
)
# Score API returns scores for each item
# For now, just verify we got a valid response
if
"scores"
in
response_data
or
"logprobs"
in
response_data
:
success
=
True
else
:
print
(
f
"[HTTP] Request
{
request_id
}
missing expected fields in response"
)
success
=
False
except
json
.
JSONDecodeError
:
print
(
f
"[HTTP] Request
{
request_id
}
failed to parse JSON response"
)
success
=
False
completion_time
=
asyncio
.
get_event_loop
().
time
()
elapsed_time
=
(
completion_time
-
start_time
)
*
1000
await
results_queue
.
put
((
request_id
,
elapsed_time
,
success
,
completion_time
))
except
Exception
as
e
:
print
(
f
"[HTTP] Error for request
{
request_id
}
:
{
e
}
"
)
completion_time
=
asyncio
.
get_event_loop
().
time
()
await
results_queue
.
put
((
request_id
,
0
,
False
,
completion_time
))
###############################################################################
# RESULTS
###############################################################################
async
def
process_results
(
results_queue
,
num_requests
,
send_duration
,
total_duration
,
rps
,
duration_secs
,
item_count
,
test_start_time
,
):
"""Processes results and groups them by minute intervals.
Returns a list of dictionaries, one for each minute."""
all_results
=
[]
# Collect all results
for
_
in
range
(
num_requests
):
result
=
await
results_queue
.
get
()
request_id
,
elapsed_time
,
success
,
completion_time
=
result
all_results
.
append
(
{
"request_id"
:
request_id
,
"elapsed_time"
:
elapsed_time
,
"success"
:
success
,
"completion_time"
:
completion_time
,
}
)
# Group results by minute intervals
minute_results
=
[]
num_minutes
=
int
(
duration_secs
//
60
)
+
(
1
if
duration_secs
%
60
>
0
else
0
)
for
minute
in
range
(
num_minutes
):
minute_start
=
test_start_time
+
(
minute
*
60
)
minute_end
=
test_start_time
+
((
minute
+
1
)
*
60
)
# Filter results that completed in this minute
minute_data
=
[
r
for
r
in
all_results
if
minute_start
<=
r
[
"completion_time"
]
<
minute_end
]
response_times
=
[
r
[
"elapsed_time"
]
for
r
in
minute_data
if
r
[
"success"
]]
successful_requests
=
len
([
r
for
r
in
minute_data
if
r
[
"success"
]])
failed_requests
=
len
([
r
for
r
in
minute_data
if
not
r
[
"success"
]])
avg_response_time
=
mean
(
response_times
)
if
response_times
else
0
# Calculate percentiles using numpy
if
response_times
:
p50
=
np
.
percentile
(
response_times
,
50
)
p90
=
np
.
percentile
(
response_times
,
90
)
p99
=
np
.
percentile
(
response_times
,
99
)
else
:
p50
=
p90
=
p99
=
0
minute_result
=
{
"test_duration_secs"
:
duration_secs
,
"minute_interval"
:
minute
+
1
,
"target_rps"
:
rps
,
"item_count"
:
item_count
,
"server_type"
:
SERVER_TYPE
,
"distribution"
:
DISTRIBUTION
,
"unique_requests"
:
NUM_UNIQUE_REQUESTS
,
"total_requests"
:
len
(
minute_data
),
"successful_requests"
:
successful_requests
,
"failed_requests"
:
failed_requests
,
"send_duration_secs"
:
send_duration
,
"total_duration_secs"
:
total_duration
,
"avg_response_time_ms"
:
avg_response_time
,
"p50_response_time_ms"
:
p50
,
"p90_response_time_ms"
:
p90
,
"p99_response_time_ms"
:
p99
,
}
minute_results
.
append
(
minute_result
)
print
(
f
"
\n
Minute
{
minute
+
1
}
Summary for RPS
{
rps
}
, "
f
"Duration
{
duration_secs
}
s, Item Count
{
item_count
}
:"
)
print
(
f
" Requests completed in minute:
{
len
(
minute_data
)
}
"
)
print
(
f
" Successful requests:
{
successful_requests
}
"
)
print
(
f
" Failed requests:
{
failed_requests
}
"
)
print
(
f
" Average response time:
{
avg_response_time
:.
2
f
}
ms"
)
print
(
f
" P50 response time:
{
p50
:.
2
f
}
ms"
)
print
(
f
" P90 response time:
{
p90
:.
2
f
}
ms"
)
print
(
f
" P99 response time:
{
p99
:.
2
f
}
ms"
)
# Also print overall summary
all_response_times
=
[
r
[
"elapsed_time"
]
for
r
in
all_results
if
r
[
"success"
]]
total_successful
=
len
([
r
for
r
in
all_results
if
r
[
"success"
]])
total_failed
=
len
([
r
for
r
in
all_results
if
not
r
[
"success"
]])
overall_avg
=
mean
(
all_response_times
)
if
all_response_times
else
0
if
all_response_times
:
overall_p50
=
np
.
percentile
(
all_response_times
,
50
)
overall_p90
=
np
.
percentile
(
all_response_times
,
90
)
overall_p99
=
np
.
percentile
(
all_response_times
,
99
)
else
:
overall_p50
=
overall_p90
=
overall_p99
=
0
print
(
f
"
\n
Overall Summary for RPS
{
rps
}
, Duration
{
duration_secs
}
s, "
f
"Item Count
{
item_count
}
:"
)
print
(
f
" Test duration:
{
duration_secs
}
seconds"
)
print
(
f
" Server type:
{
SERVER_TYPE
}
"
)
print
(
f
" HTTP mode: SINGLE_ITEM_SCORING"
)
print
(
f
" Target RPS:
{
rps
}
"
)
print
(
f
" Item count:
{
item_count
}
"
)
print
(
f
" Distribution:
{
DISTRIBUTION
}
"
)
print
(
f
" Unique requests generated:
{
NUM_UNIQUE_REQUESTS
}
"
)
print
(
f
" Total requests sent:
{
num_requests
}
"
)
print
(
f
" Successful requests:
{
total_successful
}
"
)
print
(
f
" Failed requests:
{
total_failed
}
"
)
print
(
f
" Time to send all requests:
{
send_duration
:.
2
f
}
seconds"
)
print
(
f
" Time for all requests to complete:
{
total_duration
:.
2
f
}
seconds"
)
print
(
f
" Average response time:
{
overall_avg
:.
2
f
}
ms"
)
print
(
f
" P50 response time:
{
overall_p50
:.
2
f
}
ms"
)
print
(
f
" P90 response time:
{
overall_p90
:.
2
f
}
ms"
)
print
(
f
" P99 response time:
{
overall_p99
:.
2
f
}
ms
\n
"
)
return
minute_results
###############################################################################
# MAIN
###############################################################################
async
def
run_benchmark
(
rps
,
duration_secs
,
item_count
):
"""Run a single benchmark with the given RPS value."""
num_requests
=
int
(
rps
*
duration_secs
)
print
(
f
"Starting benchmark with RPS=
{
rps
}
, Duration=
{
duration_secs
}
s, "
f
"Item Count=
{
item_count
}
, num_requests=
{
num_requests
}
"
)
print
(
f
"Server Type:
{
SERVER_TYPE
}
"
)
print
(
f
"HTTP Mode: SINGLE_ITEM_SCORING"
)
print
(
f
"Profiling Enabled:
{
PROFILE
}
"
)
# Build requests in parallel (unmeasured)
all_requests
=
prepare_all_requests_parallel
(
num_requests
,
item_count
)
results_queue
=
asyncio
.
Queue
()
tasks
=
[]
# Track timing for sending requests
send_start_time
=
asyncio
.
get_event_loop
().
time
()
# HTTP implementation (open source only supports HTTP with /v1/score API)
async
with
aiohttp
.
ClientSession
(
timeout
=
aiohttp
.
ClientTimeout
(
total
=
300
)
)
as
session
:
# Send START_PROFILE if profiling is enabled
if
PROFILE
:
await
send_profile_request
(
"START_PROFILE"
,
item_count
,
session
=
session
)
# Add progress bar for sending requests
with
tqdm
(
total
=
len
(
all_requests
),
desc
=
f
"Sending HTTP score requests at
{
rps
}
RPS"
,
unit
=
"req"
,
)
as
pbar
:
for
i
,
score_data
in
enumerate
(
all_requests
):
request_id
=
i
+
1
tasks
.
append
(
asyncio
.
create_task
(
make_http_call
(
session
,
score_data
,
request_id
,
results_queue
)
)
)
# Update progress bar
pbar
.
update
(
1
)
# Throttle based on distribution
if
i
<
len
(
all_requests
)
-
1
:
if
DISTRIBUTION
==
"CONSTANT"
:
interval
=
1
/
rps
await
asyncio
.
sleep
(
interval
)
elif
DISTRIBUTION
==
"POISSON"
:
# For Poisson process, inter-arrival times follow
# exponential distribution
interval
=
random
.
expovariate
(
rps
)
await
asyncio
.
sleep
(
interval
)
else
:
raise
ValueError
(
f
"Unknown distribution:
{
DISTRIBUTION
}
. "
f
"Use 'CONSTANT' or 'POISSON'."
)
send_end_time
=
asyncio
.
get_event_loop
().
time
()
send_duration
=
send_end_time
-
send_start_time
# Wait for all requests to complete with progress tracking
print
(
f
"Waiting for
{
len
(
tasks
)
}
HTTP score requests to complete..."
)
with
tqdm
(
total
=
len
(
tasks
),
desc
=
"Completing HTTP score requests"
,
unit
=
"req"
)
as
completion_pbar
:
completed_tasks
=
[]
for
task
in
asyncio
.
as_completed
(
tasks
):
await
task
completed_tasks
.
append
(
task
)
completion_pbar
.
update
(
1
)
# Send STOP_PROFILE if profiling is enabled
if
PROFILE
:
await
send_profile_request
(
"STOP_PROFILE"
,
item_count
,
session
=
session
)
completion_end_time
=
asyncio
.
get_event_loop
().
time
()
total_duration
=
completion_end_time
-
send_start_time
return
await
process_results
(
results_queue
,
num_requests
,
send_duration
,
total_duration
,
rps
,
duration_secs
,
item_count
,
send_start_time
,
)
async
def
main
():
"""Main function that runs benchmarks for all RPS values."""
total_combinations
=
(
len
(
DURATION_SECS_VALUES
)
*
len
(
RPS_VALUES
)
*
len
(
ITEM_COUNT_VALUES
)
)
print
(
f
"Running benchmarks for
{
len
(
DURATION_SECS_VALUES
)
}
duration "
f
"values,
{
len
(
RPS_VALUES
)
}
RPS values, and "
f
"
{
len
(
ITEM_COUNT_VALUES
)
}
item count values = "
f
"
{
total_combinations
}
total combinations"
)
print
(
f
"Server Type:
{
SERVER_TYPE
}
"
)
print
(
f
"HTTP Mode: SINGLE_ITEM_SCORING"
)
print
(
f
"Score API URL:
{
HTTP_URL
}
"
)
print
(
f
"Query tokens per request:
{
SCORE_QUERY_TOKENS
}
"
)
print
(
f
"Item tokens per item:
{
SCORE_ITEM_TOKENS
}
"
)
print
(
f
"Items per request (batch size):
{
ITEM_COUNT_VALUES
}
"
)
print
(
f
"Profiling Enabled:
{
PROFILE
}
"
)
print
(
f
"Duration values:
{
DURATION_SECS_VALUES
}
"
)
print
(
f
"RPS values:
{
RPS_VALUES
}
"
)
print
(
f
"Item count values:
{
ITEM_COUNT_VALUES
}
"
)
print
(
"="
*
80
)
all_results
=
[]
for
duration_secs
in
DURATION_SECS_VALUES
:
for
rps
in
RPS_VALUES
:
for
item_count
in
ITEM_COUNT_VALUES
:
result
=
await
run_benchmark
(
rps
,
duration_secs
,
item_count
)
all_results
.
extend
(
result
)
# Extend with minute results
# Print CSV header and results
print
(
"
\n
"
+
"="
*
80
)
print
(
"FINAL CSV RESULTS:"
)
print
(
"="
*
80
)
# CSV Header
headers
=
[
"test_duration_secs"
,
"minute_interval"
,
"target_rps"
,
"item_count"
,
"server_type"
,
"distribution"
,
"unique_requests"
,
"total_requests"
,
"successful_requests"
,
"failed_requests"
,
"send_duration_secs"
,
"total_duration_secs"
,
"avg_response_time_ms"
,
"p50_response_time_ms"
,
"p90_response_time_ms"
,
"p99_response_time_ms"
,
]
print
(
","
.
join
(
headers
))
# CSV Data
for
result
in
all_results
:
row
=
[
result
[
"test_duration_secs"
],
result
[
"minute_interval"
],
result
[
"target_rps"
],
result
[
"item_count"
],
result
[
"server_type"
],
result
[
"distribution"
],
result
[
"unique_requests"
],
result
[
"total_requests"
],
result
[
"successful_requests"
],
result
[
"failed_requests"
],
f
"
{
result
[
'send_duration_secs'
]:.
2
f
}
"
,
f
"
{
result
[
'total_duration_secs'
]:.
2
f
}
"
,
f
"
{
result
[
'avg_response_time_ms'
]:.
2
f
}
"
,
f
"
{
result
[
'p50_response_time_ms'
]:.
2
f
}
"
,
f
"
{
result
[
'p90_response_time_ms'
]:.
2
f
}
"
,
f
"
{
result
[
'p99_response_time_ms'
]:.
2
f
}
"
,
]
print
(
","
.
join
(
map
(
str
,
row
)))
if
__name__
==
"__main__"
:
asyncio
.
run
(
main
())
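For a quick sanity check outside the benchmark harness, the same endpoint can be exercised with a single hand-written request. The sketch below is illustrative only: it assumes a local SGLang server on port 30000 and reuses the payload shape documented in build_http_request_json; the query/item strings are placeholders, and the Yes/No token IDs are the ones assumed in the benchmark config, not part of this commit.

# Minimal sketch of one /v1/score call (assumed local server, illustrative payload).
import json
import urllib.request

payload = {
    "query": "Is the following review positive?\n",       # illustrative query text
    "items": ["Great product!", "Terrible experience."],  # illustrative items
    "label_token_ids": [9454, 2753],                      # Yes/No token IDs from the benchmark config
    "model": "Qwen/Qwen3-0.6B",
}
req = urllib.request.Request(
    "http://localhost:30000/v1/score",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    # The benchmark treats a response containing "scores" (or "logprobs") as success.
    print(json.loads(resp.read().decode("utf-8")))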
python/sglang/srt/managers/schedule_batch.py
@@ -913,6 +913,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False

+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     # hicache pointer for synchronizing data loading from CPU to GPU
     hicache_consumer_index: int = 0

@@ -953,6 +956,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            is_prefill_only=all(
+                req.sampling_params.max_new_tokens == 0 for req in reqs
+            ),
             chunked_req=chunked_req,
         )

@@ -1796,6 +1802,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             is_extend_in_batch=self.is_extend_in_batch,
+            is_prefill_only=self.is_prefill_only,
         )

     def _evict_tree_cache_if_needed(self, num_tokens: int):
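The new field is a simple predicate over the batch: it is true exactly when every request in the batch asks for zero new tokens, which is how score requests now arrive from the tokenizer manager. A standalone sketch of the same check, using hypothetical stand-ins for the real request objects:

# Sketch of the is_prefill_only predicate with hypothetical stand-in classes.
from dataclasses import dataclass


@dataclass
class FakeSamplingParams:
    max_new_tokens: int


@dataclass
class FakeReq:
    sampling_params: FakeSamplingParams


scoring_batch = [FakeReq(FakeSamplingParams(0)), FakeReq(FakeSamplingParams(0))]
mixed_batch = scoring_batch + [FakeReq(FakeSamplingParams(16))]

print(all(r.sampling_params.max_new_tokens == 0 for r in scoring_batch))  # True -> prefill-only
print(all(r.sampling_params.max_new_tokens == 0 for r in mixed_batch))    # False -> still needs decode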
python/sglang/srt/managers/scheduler.py
@@ -1466,8 +1466,9 @@ class Scheduler(
         if self.last_batch.batch_size() < last_bs:
             self.running_batch.batch_is_full = False

-        # Merge the new batch into the running batch
-        if not self.last_batch.is_empty():
+        # Merge the new batch into the running batch.
+        # For prefill-only batch, we can avoid going through decoding step.
+        if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
             if self.running_batch.is_empty():
                 self.running_batch = self.last_batch
             else:
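The effect of the extra condition is that a finished prefill-only batch is never merged into running_batch, so it never enters the decode loop; its label-token logprobs were already produced during prefill. A condensed sketch of that decision, written against hypothetical batch stubs rather than the real Scheduler types:

# Condensed sketch of the merge decision (hypothetical batch objects, not the real Scheduler).
def merge_last_batch(last_batch, running_batch):
    # Prefill-only batches (e.g. /v1/score with max_new_tokens=0) already have their
    # label-token logprobs after prefill, so there is nothing left to decode.
    if last_batch.is_empty() or last_batch.is_prefill_only:
        return running_batch  # skip the decode path entirely
    if running_batch.is_empty():
        return last_batch
    running_batch.merge_batch(last_batch)  # assumed merge helper on the stub
    return running_batch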
python/sglang/srt/managers/tokenizer_manager.py
@@ -699,7 +699,7 @@ class TokenizerManager:
         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
-            self._validate_token_len(obj[i], input_ids_list[i])
+            self._validate_one_request(obj[i], input_ids_list[i])
             tokenized_objs.append(
                 self._create_tokenized_object(
                     req, req.text, input_ids_list[i], None, None

@@ -1892,6 +1892,13 @@ class TokenizerManager:
                 f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
             )

+        batch_request = GenerateReqInput(
+            token_ids_logprob=label_token_ids,
+            return_logprob=True,
+            stream=False,
+            sampling_params={"max_new_tokens": 0},
+        )
+
         # Handle string or tokenized query/items
         if isinstance(query, str) and (
             isinstance(items, str)

@@ -1903,13 +1910,9 @@ class TokenizerManager:
                 prompts = [f"{item}{query}" for item in items_list]
             else:
                 prompts = [f"{query}{item}" for item in items_list]
-            batch_request = GenerateReqInput(
-                text=prompts,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+            batch_request.text = prompts
         elif (
             isinstance(query, list)
             and isinstance(items, list)

@@ -1921,13 +1924,8 @@ class TokenizerManager:
                 input_ids_list = [item + query for item in items]
             else:
                 input_ids_list = [query + item for item in items]
-            batch_request = GenerateReqInput(
-                input_ids=input_ids_list,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+            batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."

@@ -1939,9 +1937,20 @@ class TokenizerManager:
         for result in results:
             # Get logprobs for each token
             logprobs = {}
-            for logprob, token_id, _ in result["meta_info"].get(
-                "output_token_ids_logprobs", []
-            )[0]:
+
+            # For scoring requests, we read from output_token_ids_logprobs since we want
+            # the logprobs for specific tokens mentioned in the label_token_ids at
+            # the next position after the last token in the prompt
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            # Throw an error here if output_logprobs is None
+            if output_logprobs is None:
+                raise RuntimeError(
+                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
+                    "This usually indicates a problem with the scoring request or the backend output."
+                )
+
+            for logprob, token_id, _ in output_logprobs[0]:
                 if token_id in label_token_ids:
                     logprobs[token_id] = logprob
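Downstream of this loop, the per-label logprobs collected into logprobs are converted into the scores the API returns; with apply_softmax they are normalized over the label set, which is why the existing test expects each item's scores to sum to 1. A small illustrative sketch of that conversion (not the exact helper used in tokenizer_manager.py):

# Sketch: turning label-token logprobs into normalized per-item scores.
import math


def logprobs_to_scores(logprobs: dict, label_token_ids: list) -> list:
    # Softmax restricted to the requested label tokens.
    exps = [math.exp(logprobs[token_id]) for token_id in label_token_ids]
    total = sum(exps)
    return [e / total for e in exps]


print(logprobs_to_scores({9454: -0.2, 2753: -1.8}, [9454, 2753]))  # e.g. [0.83..., 0.16...]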
test/srt/test_score_api.py
@@ -213,6 +213,88 @@ class TestScoreAPI(CustomTestCase):
             1.0, sum(score_list), 6, "Scores should sum to 1"
         )

+    def test_score_request_construction(self):
+        """Test that scoring requests are constructed to avoid decode phase."""
+        from unittest.mock import patch
+
+        # Capture the internal request to verify optimization
+        captured_requests = []
+        original_gen = self.engine.tokenizer_manager.generate_request
+
+        async def mock_generate_request(req, request=None):
+            captured_requests.append(req)
+            async for result in original_gen(req, request):
+                yield result
+
+        # Patch the generate_request method
+        with patch.object(
+            self.engine.tokenizer_manager,
+            "generate_request",
+            side_effect=mock_generate_request,
+        ):
+            # Run a scoring request
+            query = "What is the capital of"
+            items = ["France", "Germany"]
+            label_token_ids = [1, 2, 3]
+
+            scores = self.engine.score(
+                query=query,
+                items=items,
+                label_token_ids=label_token_ids,
+                apply_softmax=True,
+            )
+
+            # Verify we got results
+            self.assertEqual(len(scores), len(items))
+
+            # Verify the captured request has decode-avoiding properties
+            self.assertEqual(len(captured_requests), 1)
+            request = captured_requests[0]
+
+            # Key assertions for decode phase avoidance:
+            # 1. max_new_tokens should be 0 (prevents token generation)
+            # Handle both single and batch request cases
+            if isinstance(request.sampling_params, dict):
+                max_new_tokens = request.sampling_params.get("max_new_tokens", 0)
+            elif isinstance(request.sampling_params, list):
+                # For batch requests, check the first item
+                max_new_tokens = request.sampling_params[0].get("max_new_tokens", 0)
+            else:
+                max_new_tokens = getattr(request.sampling_params, "max_new_tokens", 0)
+
+            self.assertEqual(
+                max_new_tokens, 0, "max_new_tokens should be 0 to avoid decode phase"
+            )
+
+            # 2. Should have token_ids_logprob for scoring
+            # Handle both single and batch request cases
+            if (
+                isinstance(request.token_ids_logprob, list)
+                and len(request.token_ids_logprob) > 0
+                and isinstance(request.token_ids_logprob[0], list)
+            ):
+                # Batch case: token_ids_logprob is a list of lists
+                # Each item in the batch should have the same label_token_ids
+                for item_token_ids in request.token_ids_logprob:
+                    self.assertEqual(
+                        item_token_ids,
+                        label_token_ids,
+                        "Each batch item should have label_token_ids for scoring",
+                    )
+            else:
+                # Single request case
+                self.assertEqual(
+                    request.token_ids_logprob,
+                    label_token_ids,
+                    "Should have label_token_ids for scoring",
+                )
+
+            # 3. Should request logprobs but not stream
+            self.assertTrue(request.return_logprob, "Should request logprobs for scoring")
+            self.assertFalse(request.stream, "Scoring requests should not stream")
+

 if __name__ == "__main__":
     unittest.main()
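For reference, the public entry point this test exercises is Engine.score, driven offline rather than over HTTP. A minimal usage sketch with the same argument names as the test; the model path and label token IDs here are illustrative choices, not values taken from this commit:

# Minimal offline scoring sketch (illustrative model path and label token IDs).
import sglang as sgl

engine = sgl.Engine(model_path="Qwen/Qwen3-0.6B")
scores = engine.score(
    query="What is the capital of",
    items=["France", "Germany"],
    label_token_ids=[9454, 2753],
    apply_softmax=True,
)
print(scores)  # expected: one list of per-label probabilities per item
engine.shutdown()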
test/srt/test_tokenizer_batch_encode.py (new file, mode 0 → 100644)
"""
Unit tests for enable_tokenizer_batch_encode feature.
This tests the batch tokenization functionality which allows processing
multiple text inputs in a single batch for improved performance.
Usage:
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_batch_validation_constraints
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncodeUnit.test_batch_tokenize_and_process_logic
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncodeLogic.test_batch_processing_path
"""
import
asyncio
import
unittest
from
typing
import
List
from
unittest.mock
import
AsyncMock
,
Mock
,
call
,
patch
from
sglang.srt.managers.io_struct
import
GenerateReqInput
,
TokenizedGenerateReqInput
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class
TestTokenizerBatchEncode
(
unittest
.
TestCase
):
"""Test cases for tokenizer batch encoding validation and setup."""
def
setUp
(
self
):
"""Set up test fixtures."""
self
.
server_args
=
ServerArgs
(
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
enable_tokenizer_batch_encode
=
True
,
)
self
.
port_args
=
PortArgs
.
init_new
(
self
.
server_args
)
with
patch
(
"zmq.asyncio.Context"
),
patch
(
"sglang.srt.utils.get_zmq_socket"
),
patch
(
"sglang.srt.hf_transformers_utils.get_tokenizer"
)
as
mock_tokenizer
:
mock_tokenizer
.
return_value
=
Mock
(
vocab_size
=
32000
)
self
.
tokenizer_manager
=
TokenizerManager
(
self
.
server_args
,
self
.
port_args
)
def
test_batch_encode_enabled
(
self
):
"""Test that batch encoding is enabled when configured."""
self
.
assertTrue
(
self
.
server_args
.
enable_tokenizer_batch_encode
)
def
test_batch_encode_disabled
(
self
):
"""Test that batch encoding can be disabled."""
server_args_disabled
=
ServerArgs
(
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
enable_tokenizer_batch_encode
=
False
,
)
self
.
assertFalse
(
server_args_disabled
.
enable_tokenizer_batch_encode
)
def
test_multimodal_input_validation
(
self
):
"""Test that multimodal inputs are rejected in batch mode."""
req
=
GenerateReqInput
(
text
=
"test"
,
image_data
=
[
"dummy"
])
req
.
contains_mm_input
=
Mock
(
return_value
=
True
)
batch_obj
=
Mock
()
batch_obj
.
__getitem__
=
lambda
self
,
i
:
req
self
.
tokenizer_manager
.
is_generation
=
True
with
self
.
assertRaises
(
ValueError
)
as
cm
:
self
.
tokenizer_manager
.
_validate_batch_tokenization_constraints
(
1
,
batch_obj
)
self
.
assertIn
(
"multimodal"
,
str
(
cm
.
exception
))
def
test_pretokenized_input_validation
(
self
):
"""Test that pre-tokenized inputs are rejected in batch mode."""
req
=
GenerateReqInput
(
input_ids
=
[
1
,
2
,
3
])
batch_obj
=
Mock
()
batch_obj
.
__getitem__
=
lambda
self
,
i
:
req
with
self
.
assertRaises
(
ValueError
)
as
cm
:
self
.
tokenizer_manager
.
_validate_batch_tokenization_constraints
(
1
,
batch_obj
)
self
.
assertIn
(
"pre-tokenized"
,
str
(
cm
.
exception
))
def
test_input_embeds_validation
(
self
):
"""Test that input embeds are rejected in batch mode."""
req
=
GenerateReqInput
(
input_embeds
=
[
0.1
,
0.2
])
batch_obj
=
Mock
()
batch_obj
.
__getitem__
=
lambda
self
,
i
:
req
with
self
.
assertRaises
(
ValueError
)
as
cm
:
self
.
tokenizer_manager
.
_validate_batch_tokenization_constraints
(
1
,
batch_obj
)
self
.
assertIn
(
"input_embeds"
,
str
(
cm
.
exception
))
def
test_valid_text_only_requests_pass_validation
(
self
):
"""Test that valid text-only requests pass validation."""
# Create valid requests (text-only)
requests
=
[]
for
i
in
range
(
3
):
req
=
GenerateReqInput
(
text
=
f
"test text
{
i
}
"
)
req
.
contains_mm_input
=
Mock
(
return_value
=
False
)
requests
.
append
(
req
)
batch_obj
=
Mock
()
batch_obj
.
__getitem__
=
Mock
(
side_effect
=
lambda
i
:
requests
[
i
])
# Should not raise any exception
try
:
self
.
tokenizer_manager
.
_validate_batch_tokenization_constraints
(
3
,
batch_obj
)
except
Exception
as
e
:
self
.
fail
(
f
"Validation failed for valid text-only requests:
{
e
}
"
)
if
__name__
==
"__main__"
:
unittest
.
main
(
verbosity
=
2
)
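The feature under test is toggled through ServerArgs. A minimal sketch of constructing server arguments with batch tokenization enabled, mirroring the fixtures above (the model path is illustrative; the tests use DEFAULT_SMALL_MODEL_NAME_FOR_TEST):

# Sketch: enabling tokenizer batch encoding through ServerArgs (illustrative model path).
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="Qwen/Qwen3-0.6B",
    enable_tokenizer_batch_encode=True,
)
print(server_args.enable_tokenizer_batch_encode)  # True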