fba8eccd
Unverified
Commit
fba8eccd
authored
May 12, 2025
by
Lianmin Zheng
Committed by
GitHub
May 12, 2025
Browse files
Log if cuda graph is used & extend cuda graph capture to cuda-graph-max-bs (#6201)
Co-authored-by:
SangBin Cho
<
rkooo567@gmail.com
>
parent
7d3a3d45
Changes: 27 files in this commit. This page shows 20 changed files with 265 additions and 108 deletions (+265, -108); the remaining files are on the second page of the diff.
Changed files on this page:

  python/sglang/bench_offline_throughput.py                        +3    -1
  python/sglang/bench_one_batch.py                                  +2    -2
  python/sglang/bench_one_batch_server.py                           +143  -15
  python/sglang/bench_serving.py                                    +7    -5
  python/sglang/srt/constrained/base_grammar_backend.py             +6    -0
  python/sglang/srt/disaggregation/prefill.py                       +1    -3
  python/sglang/srt/entrypoints/engine.py                           +1    -1
  python/sglang/srt/entrypoints/http_server.py                      +1    -1
  python/sglang/srt/layers/attention/utils.py                       +4    -2
  python/sglang/srt/managers/scheduler.py                           +18   -10
  python/sglang/srt/managers/scheduler_output_processor_mixin.py    +7    -11
  python/sglang/srt/managers/tokenizer_manager.py                   +4    -3
  python/sglang/srt/managers/tp_worker.py                           +12   -9
  python/sglang/srt/managers/tp_worker_overlap_thread.py            +16   -8
  python/sglang/srt/model_executor/cuda_graph_runner.py             +8    -9
  python/sglang/srt/model_executor/model_runner.py                  +8    -7
  python/sglang/srt/server_args.py                                  +1    -1
  python/sglang/srt/speculative/eagle_worker.py                     +15   -12
  python/sglang/test/test_utils.py                                  +7    -7
  test/srt/run_suite.py                                             +1    -1
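At a glance, the commit does two things: ModelRunner.forward now also reports whether the batch was served by a captured CUDA graph, and that flag is threaded through the workers so the scheduler's periodic decode log can print it; separately, the CUDA-graph capture list is padded up to --cuda-graph-max-bs. A minimal sketch of the flag-propagation idea, using illustrative names rather than the actual sglang classes:

    # Sketch only: the runner returns (output, can_run_cuda_graph) and the caller
    # forwards the flag to its log line, mirroring the pattern in this commit.
    from typing import Any, Tuple


    class TinyRunner:
        def __init__(self, graph_batch_sizes):
            self.graph_batch_sizes = set(graph_batch_sizes)

        def forward(self, batch_size: int) -> Tuple[Any, bool]:
            can_run_cuda_graph = batch_size in self.graph_batch_sizes
            output = f"logits for bs={batch_size}"   # placeholder result
            return output, can_run_cuda_graph        # the flag travels with the output


    def decode_step(runner: TinyRunner, batch_size: int) -> None:
        output, can_run_cuda_graph = runner.forward(batch_size)
        print(f"Decode batch. cuda graph: {can_run_cuda_graph}, bs: {batch_size}")


    if __name__ == "__main__":
        runner = TinyRunner(graph_batch_sizes=[1, 2, 4, 8, 16])
        decode_step(runner, 8)    # cuda graph: True
        decode_step(runner, 24)   # cuda graph: False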
python/sglang/bench_offline_throughput.py

@@ -259,7 +259,9 @@ def throughput_test_once(
         measurement_results["total_input_tokens"]
         + measurement_results["total_output_tokens"]
     ) / latency
-    measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
+    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+        "last_gen_throughput"
+    ]
     return measurement_results
python/sglang/bench_one_batch.py

@@ -246,7 +246,7 @@ def extend(reqs, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch

@@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
python/sglang/bench_one_batch_server.py

@@ -25,6 +25,7 @@ import requests
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import is_in_ci, write_github_step_summary


 @dataclasses.dataclass
@@ -33,9 +34,13 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    temperature: float = 0.0
+    return_logprob: bool = False
+    input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
     skip_warmup: bool = False
+    show_report: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -49,11 +54,19 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--input-len-step-percentage",
+            type=float,
+            default=BenchArgs.input_len_step_percentage,
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
         parser.add_argument("--skip-warmup", action="store_true")
+        parser.add_argument("--show-report", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -99,36 +112,89 @@ def run_one_case(
     batch_size: int,
     input_len: int,
     output_len: int,
+    temperature: float,
+    return_logprob: bool,
+    input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
 ):
     requests.post(url + "/flush_cache")
+    input_lens = [
+        int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
+        for i in range(batch_size)
+    ]
     input_ids = [
-        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
-        for _ in range(batch_size)
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
+        for i in range(batch_size)
     ]
+    use_structured_outputs = False
+    if use_structured_outputs:
+        texts = []
+        for _ in range(batch_size):
+            texts.append(
+                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
+                * 50
+                + "Assistant:"
+            )
+        json_schema = "$$ANY$$"
+    else:
+        json_schema = None

     tic = time.time()
     response = requests.post(
         url + "/generate",
         json={
+            # "text": texts,
             "input_ids": input_ids,
             "sampling_params": {
-                "temperature": 0,
+                "temperature": temperature,
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
+                "json_schema": json_schema,
             },
+            "return_logprob": return_logprob,
+            "stream": True,
         },
+        stream=True,
     )
-    latency = time.time() - tic
-    _ = response.json()
-    output_throughput = batch_size * output_len / latency
+
+    # The TTFT of the last request in the batch
+    ttft = 0.0
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
+            if "error" in data:
+                raise RuntimeError(f"Request has failed. {data}.")
+
+            assert (
+                data["meta_info"]["finish_reason"] is None
+                or data["meta_info"]["finish_reason"]["type"] == "length"
+            )
+            if data["meta_info"]["completion_tokens"] == 1:
+                ttft = time.time() - tic
+
+    latency = time.time() - tic
+    input_throughput = batch_size * input_len / ttft
+    output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency

+    server_info = requests.get(url + "/get_server_info").json()
+    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
+    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
+
     print(f"batch size: {batch_size}")
+    print(f"input_len: {input_len}")
+    print(f"output_len: {output_len}")
     print(f"latency: {latency:.2f} s")
-    print(f"output throughput: {output_throughput:.2f} token/s")
-    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+    print(f"ttft: {ttft:.2f} s")
+    print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
+    print(f"Input throughput: {input_throughput:.2f} tok/s")
+    if output_len != 1:
+        print(f"output throughput: {output_throughput:.2f} tok/s")

     if result_filename:
         with open(result_filename, "a") as fout:
@@ -140,9 +206,21 @@ def run_one_case(
             "latency": round(latency, 4),
             "output_throughput": round(output_throughput, 2),
             "overall_throughput": round(overall_throughput, 2),
+            "last_gen_throughput": round(last_gen_throughput, 2),
         }
         fout.write(json.dumps(res) + "\n")

+    return (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    )
+

 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     if bench_args.base_url:
@@ -152,27 +230,38 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     # warmup
     if not bench_args.skip_warmup:
         print("=" * 8 + " Warmup Begin " + "=" * 8)
         run_one_case(
             base_url,
             batch_size=16,
             input_len=1024,
             output_len=16,
+            temperature=bench_args.temperature,
+            return_logprob=bench_args.return_logprob,
+            input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
         )
+        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

     # benchmark
+    result = []
     try:
         for bs, il, ol in itertools.product(
             bench_args.batch_size, bench_args.input_len, bench_args.output_len
         ):
-            run_one_case(
-                base_url,
-                bs,
-                il,
-                ol,
-                bench_args.run_name,
-                bench_args.result_filename,
-            )
+            result.append(
+                run_one_case(
+                    base_url,
+                    bs,
+                    il,
+                    ol,
+                    temperature=bench_args.temperature,
+                    return_logprob=bench_args.return_logprob,
+                    input_len_step_percentage=bench_args.input_len_step_percentage,
+                    run_name=bench_args.run_name,
+                    result_filename=bench_args.result_filename,
+                )
+            )
     finally:
         if proc:
@@ -180,6 +269,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     print(f"\nResults are saved to {bench_args.result_filename}")

+    if not bench_args.show_report:
+        return
+
+    summary = " | batch size | latency (s) | input throughput (tok/s)  | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
+    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+
+    for (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    ) in result:
+        hourly_cost = 2 * server_args.tp_size  # $2/hour for one H100
+        input_util = 0.7
+        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
+        line = (
+            f"| {batch_size} | "
+            f"{latency:.2f} | "
+            f"{input_throughput:.2f} | "
+            f"{output_throughput:.2f} | "
+            f"{accept_length} | "
+            f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
+            f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
+            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
+        )
+        summary += line
+
+    # print metrics table
+    print(summary)
+
+    if is_in_ci():
+        write_github_step_summary(
+            f"### Test Nightly Benchmark (bench_one_batch)\n{summary}"
+        )
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
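The core of the bench_one_batch_server.py change is that a case is now timed against a streaming request: time-to-first-token (TTFT) is recorded when the first completion token arrives, and output throughput is computed over latency - ttft instead of total latency. A hedged, standalone sketch of that measurement pattern against an SSE-style endpoint (the URL, payload, and meta_info fields mirror the benchmark script but are assumptions about the backend, not a guaranteed API):

    import json
    import time
    from typing import Tuple

    import requests


    def measure_ttft_and_latency(url: str, payload: dict) -> Tuple[float, float]:
        """Stream a generate request and return (ttft, total latency) in seconds."""
        tic = time.time()
        ttft = 0.0
        response = requests.post(url, json=payload, stream=True)
        for chunk in response.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if not chunk.startswith("data:"):
                continue
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip())
            # The first generated token marks time-to-first-token.
            if ttft == 0.0 and data.get("meta_info", {}).get("completion_tokens") == 1:
                ttft = time.time() - tic
        latency = time.time() - tic
        return ttft, latency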
python/sglang/bench_serving.py

@@ -1103,7 +1103,7 @@ async def benchmark(
     lora_names: List[str],
     extra_request_body: Dict[str, Any],
     profile: bool,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
 ):
@@ -1239,12 +1239,14 @@ async def benchmark(
     if "sglang" in backend:
         server_info = requests.get(base_url + "/get_server_info")
-        if pd_seperated:
-            accept_length = server_info.json()["decode"][0].get(
+        if pd_separated:
+            accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
                 "avg_spec_accept_length", None
             )
         else:
-            accept_length = server_info.json().get("avg_spec_accept_length", None)
+            accept_length = server_info.json()["internal_states"][0].get(
+                "avg_spec_accept_length", None
+            )
     else:
         accept_length = None
@@ -1541,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace):
             lora_names=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
-            pd_seperated=args.pd_seperated,
+            pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
         )
     )
python/sglang/srt/constrained/base_grammar_backend.py

@@ -37,6 +37,12 @@ class BaseGrammarObject:
         """
         raise NotImplementedError()

+    def rollback(self, k: int):
+        raise NotImplementedError()
+
+    def is_terminated(self):
+        raise NotImplementedError()
+
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
python/sglang/srt/disaggregation/prefill.py

@@ -277,19 +277,17 @@ class SchedulerDisaggregationPrefillMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
             bid,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
             result.bid,
         )

         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
-            # wait
-            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
+            _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
python/sglang/srt/entrypoints/engine.py

@@ -330,7 +330,7 @@ class Engine(EngineBase):
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
-            **internal_states,
+            "internal_states": internal_states,
             "version": __version__,
         }
python/sglang/srt/entrypoints/http_server.py

@@ -222,7 +222,7 @@ async def get_server_info():
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
-        **internal_states,
+        "internal_states": internal_states,
         "version": __version__,
     }
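A side effect of the engine.py and http_server.py changes is that /get_server_info no longer flattens the scheduler's internal state into the top-level response; clients read it from an "internal_states" list with one entry per data-parallel rank. A hedged client-side sketch (the base URL is a placeholder for a locally running server):

    import requests

    base_url = "http://127.0.0.1:30000"  # placeholder address

    info = requests.get(base_url + "/get_server_info").json()

    # "internal_states" is a list with one dict per DP rank.
    state = info["internal_states"][0]
    print("last_gen_throughput:", state["last_gen_throughput"])
    print("avg_spec_accept_length:", state.get("avg_spec_accept_length"))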
python/sglang/srt/layers/attention/utils.py

@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
         mask = offset < kv_end - kv_start
         data = tl.load(
             req_to_token_ptr
@@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
     num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_pages_loop):
+        # index into req_to_token_ptr needs to be int64
         paged_offset = (
-            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+            tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
         ) * PAGED_SIZE
         paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
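The .to(tl.int64) casts guard the req_to_token indexing against 32-bit overflow, which becomes reachable once large capture batch sizes and long contexts make the flat row*stride offset exceed 2**31 - 1 elements. A small NumPy illustration of the failure mode the cast avoids (the table shape is illustrative, not sglang's actual pool size):

    import numpy as np

    # Illustrative request-to-token table: row 35,000 with a stride of 65,536 slots.
    row, row_stride, col = 35_000, 65_536, 1_000

    # 32-bit arithmetic wraps past 2**31 - 1 (NumPy emits an overflow warning and
    # the result goes negative), while 64-bit arithmetic keeps the correct offset.
    flat_i32 = np.int32(row) * np.int32(row_stride) + np.int32(col)
    flat_i64 = np.int64(row) * np.int64(row_stride) + np.int64(col)

    print(flat_i32)  # negative, wrapped value
    print(flat_i64)  # 2293761000, the correct flat offset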
python/sglang/srt/managers/scheduler.py

@@ -160,6 +160,7 @@ class GenerationBatchResult:
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
+    can_run_cuda_graph: bool


 @dataclass
@@ -323,13 +324,14 @@ class Scheduler(
         set_random_seed(self.random_seed)

         # Print debug info
-        logger.info(
-            f"max_total_num_tokens={self.max_total_num_tokens}, "
-            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
-            f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"max_running_requests={self.max_running_requests}, "
-            f"context_len={self.model_config.context_len}"
-        )
+        if tp_rank == 0:
+            logger.info(
+                f"max_total_num_tokens={self.max_total_num_tokens}, "
+                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
+                f"max_prefill_tokens={self.max_prefill_tokens}, "
+                f"max_running_requests={self.max_running_requests}, "
+                f"context_len={self.model_config.context_len}"
+            )

         # Init memory pool and cache
         self.init_memory_pool_and_cache()
@@ -752,6 +754,7 @@ class Scheduler(
                 extend_input_len_per_req=None,
                 extend_logprob_start_len_per_req=None,
                 bid=bids[next_mb_id],
+                can_run_cuda_graph=result.can_run_cuda_graph,
             )
             self.process_batch_result(mbs[next_mb_id], output_result)
             last_mbs[next_mb_id] = mbs[next_mb_id]
@@ -1159,7 +1162,9 @@ class Scheduler(
             self.metrics_collector.log_stats(self.stats)

-    def log_decode_stats(self, running_batch=None):
+    def log_decode_stats(
+        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+    ):
         batch = running_batch or self.running_batch

         gap_latency = time.time() - self.last_decode_stats_tic
@@ -1199,6 +1204,7 @@ class Scheduler(
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

         msg += (
+            f"cuda graph: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}"
         )
@@ -1524,11 +1530,11 @@ class Scheduler(
         if self.spec_algorithm.is_none():
             model_worker_batch = batch.get_model_worker_batch()

             if self.pp_group.is_last_rank:
-                logits_output, next_token_ids = (
+                logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             else:
-                pp_hidden_states_proxy_tensors, _ = (
+                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             bid = model_worker_batch.bid
@@ -1538,6 +1544,7 @@ class Scheduler(
                 next_token_ids,
                 bid,
                 num_accepted_tokens,
+                can_run_cuda_graph,
             ) = self.draft_worker.forward_batch_speculative_generation(batch)
             self.spec_num_total_accepted_tokens += (
                 num_accepted_tokens + batch.batch_size()
@@ -1571,6 +1578,7 @@ class Scheduler(
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
+                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
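With the scheduler changes, every decode-interval log line states whether that batch actually replayed a captured CUDA graph, which makes it easy to spot batches that silently fall back to eager mode. A sketch of what assembling such a line looks like outside sglang (all values are made up):

    import logging

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    logger = logging.getLogger("scheduler-sketch")

    # Illustrative values; in sglang they come from the scheduler state.
    can_run_cuda_graph = True
    last_gen_throughput = 1234.56
    num_queue_reqs = 3

    msg = "Decode batch. "
    msg += (
        f"cuda graph: {can_run_cuda_graph}, "
        f"gen throughput (token/s): {last_gen_throughput:.2f}, "
        f"#queue-req: {num_queue_reqs}"
    )
    logger.info(msg)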
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -38,20 +38,16 @@ class SchedulerOutputProcessorMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
             bid,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
             result.bid,
         )

         if self.enable_overlap:
-            logits_output, next_token_ids = (
-                self.tp_worker.resolve_last_batch_result(
-                    launch_done,
-                )
+            logits_output, next_token_ids, _ = (
+                self.tp_worker.resolve_last_batch_result(launch_done)
             )
         else:
             # Move next_token_ids and logprobs to cpu
@@ -189,16 +185,16 @@ class SchedulerOutputProcessorMixin:
         result: GenerationBatchResult,
         launch_done: Optional[threading.Event] = None,
     ):
-        logits_output, next_token_ids, bid = (
+        logits_output, next_token_ids, can_run_cuda_graph = (
             result.logits_output,
             result.next_token_ids,
-            result.bid,
+            result.can_run_cuda_graph,
         )
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
-                launch_done
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.tp_worker.resolve_last_batch_result(launch_done)
             )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
@@ -280,7 +276,7 @@ class SchedulerOutputProcessorMixin:
                 self.attn_tp_rank == 0
                 and self.forward_ct_decode % self.server_args.decode_log_interval == 0
             ):
-                self.log_decode_stats(running_batch=batch)
+                self.log_decode_stats(can_run_cuda_graph, running_batch=batch)

     def add_input_logprob_return_values(
         self: Scheduler,
python/sglang/srt/managers/tokenizer_manager.py

@@ -923,12 +923,13 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)

-    async def get_internal_state(self) -> Dict[Any, Any]:
+    async def get_internal_state(self) -> List[Dict[Any, Any]]:
         req = GetInternalStateReq()
-        res: List[GetInternalStateReqOutput] = (
+        responses: List[GetInternalStateReqOutput] = (
             await self.get_internal_state_communicator(req)
         )
-        return res[0].internal_state
+        # Many DP ranks
+        return [res.internal_state for res in responses]

     def get_log_request_metadata(self):
         max_length = None
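get_internal_state now returns one state dict per data-parallel rank instead of only rank 0's, which is what lets the HTTP layer expose the full "internal_states" list. A tiny sketch of consuming such a list (the values are illustrative):

    # One internal-state dict per DP rank, as the tokenizer manager now returns.
    per_rank_states = [
        {"last_gen_throughput": 1500.0},
        {"last_gen_throughput": 1450.0},
    ]

    # Callers that only care about rank 0 keep working ...
    print(per_rank_states[0]["last_gen_throughput"])

    # ... and callers can now aggregate across ranks.
    total = sum(s["last_gen_throughput"] for s in per_rank_states)
    print(f"aggregate gen throughput: {total:.2f} token/s")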
python/sglang/srt/managers/tp_worker.py

@@ -20,7 +20,7 @@ from typing import Optional, Tuple, Union
 import torch

 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
+from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
@@ -183,8 +183,11 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Tuple[
+        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
+    ]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

         pp_proxy_tensors = None
@@ -196,11 +199,11 @@ class TpModelWorker:
             )

         if self.pp_group.is_last_rank:
-            logits_output = self.model_runner.forward(
+            logits_output, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch, pp_proxy_tensors=pp_proxy_tensors
             )
-            if model_worker_batch.launch_done is not None:
-                model_worker_batch.launch_done.set()
+            if launch_done is not None:
+                launch_done.set()

             if skip_sample:
                 next_token_ids = None
@@ -209,17 +212,17 @@ class TpModelWorker:
                     logits_output, model_worker_batch
                 )

-            return logits_output, next_token_ids
+            return logits_output, next_token_ids, can_run_cuda_graph
         else:
-            pp_proxy_tensors = self.model_runner.forward(
+            pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None
+            return pp_proxy_tensors.tensors, None, can_run_cuda_graph

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
+        logits_output, _ = self.model_runner.forward(forward_batch)
         embeddings = logits_output.embeddings
         return embeddings
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -18,7 +18,7 @@ import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional
+from typing import Optional, Tuple

 import psutil
 import torch
@@ -145,8 +145,10 @@ class TpModelWorkerClient:
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

             # Run forward
-            logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.worker.forward_batch_generation(
+                    model_worker_batch, model_worker_batch.launch_done
+                )
             )

             # Update the future token ids map
@@ -171,14 +173,18 @@ class TpModelWorkerClient:
                 next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()

-            self.output_queue.put((copy_done, logits_output, next_token_ids))
+            self.output_queue.put(
+                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
+            )

     def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
         """
         This function is called to resolve the last batch result and
        wait for the current batch to be launched. Used in overlap mode.
         """
-        copy_done, logits_output, next_token_ids = self.output_queue.get()
+        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
+            self.output_queue.get()
+        )
         if launch_done is not None:
             launch_done.wait()
@@ -193,9 +199,11 @@ class TpModelWorkerClient:
                 logits_output.input_token_logprobs.tolist()
             )
         next_token_ids = next_token_ids.tolist()
-        return logits_output, next_token_ids
+        return logits_output, next_token_ids, can_run_cuda_graph

-    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+    def forward_batch_generation(
+        self, model_worker_batch: ModelWorkerBatch
+    ) -> Tuple[None, torch.Tensor, bool]:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -223,7 +231,7 @@ class TpModelWorkerClient:
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        return None, future_next_token_ids
+        return None, future_next_token_ids, False

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
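In overlap mode the flag has to survive one extra hop: the forward thread enqueues it next to the batch outputs, and resolve_last_batch_result hands it back to the scheduler one step later. A stripped-down sketch of that queue hand-off with plain threads and no CUDA events (names and the capture limit are illustrative):

    import threading
    from queue import Queue
    from typing import Tuple

    output_queue: "Queue[Tuple[str, bool]]" = Queue()


    def forward_thread() -> None:
        # Producer: run each batch and enqueue the result plus the cuda-graph flag.
        for bs in (1, 8, 200):
            can_run_cuda_graph = bs <= 160  # pretend 160 is the largest captured size
            output_queue.put((f"tokens for bs={bs}", can_run_cuda_graph))


    def resolve_last_batch_result() -> Tuple[str, bool]:
        # Consumer: called one step later by the scheduler in overlap mode.
        return output_queue.get()


    t = threading.Thread(target=forward_thread)
    t.start()
    for _ in range(3):
        tokens, flag = resolve_last_batch_result()
        print(tokens, "| cuda graph:", flag)
    t.join()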
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -19,7 +19,7 @@ import bisect
 import inspect
 import os
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Optional, Union

 import torch
 import tqdm
@@ -40,15 +40,12 @@ from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     is_hip,
-    rank0_log,
 )

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

 _is_hip = is_hip()


 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
@@ -137,7 +134,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     )
     gpu_mem = get_device_memory_capacity()
     # Batch size of each rank will not become so large when DP is on
     if gpu_mem is not None and gpu_mem > 96 * 1024:
         capture_bs += list(range(160, 257, 8))
@@ -148,12 +144,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         capture_bs += [model_runner.req_to_token_pool.size]

-    capture_bs = list(sorted(set(capture_bs)))
-    assert len(capture_bs) > 0 and capture_bs[0] > 0
-    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
+        if max(capture_bs) < server_args.cuda_graph_max_bs:
+            capture_bs += list(
+                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
+            )
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    capture_bs = list(sorted(set(capture_bs)))
+    assert len(capture_bs) > 0 and capture_bs[0] > 0
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
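The effect of the cuda_graph_runner.py change is that the capture list is padded in steps of 16 up to --cuda-graph-max-bs instead of stopping at the largest default size, so a user-requested maximum is actually captured. A standalone sketch of the list construction (the default list, pool size, and step are illustrative):

    def batch_sizes_to_capture(default_bs, cuda_graph_max_bs, req_pool_size, step=16):
        """Sketch of extending the CUDA-graph capture list up to cuda_graph_max_bs."""
        capture_bs = list(default_bs)
        if cuda_graph_max_bs:
            capture_bs = [bs for bs in capture_bs if bs <= cuda_graph_max_bs]
            if max(capture_bs) < cuda_graph_max_bs:
                # Pad with extra sizes so the requested maximum is reached.
                capture_bs += list(range(max(capture_bs), cuda_graph_max_bs + 1, step))
        capture_bs = [bs for bs in capture_bs if bs <= req_pool_size]
        return sorted(set(capture_bs))


    # Example: the defaults stop at 160, but the user asked for a maximum of 256.
    print(batch_sizes_to_capture([1, 2, 4, 8, 16, 32, 64, 128, 160], 256, 512))
    # -> [1, 2, 4, 8, 16, 32, 64, 128, 160, 176, 192, 208, 224, 240, 256]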
python/sglang/srt/model_executor/model_runner.py

@@ -1085,32 +1085,33 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
-    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
-        if (
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
+        can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
-        ):
-            return self.cuda_graph_runner.replay(
+        )
+        if can_run_cuda_graph:
+            ret = self.cuda_graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-
-        if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+        elif forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(
+            ret = self.forward_extend(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

+        return ret, can_run_cuda_graph
+
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
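Because ModelRunner.forward now returns a pair, every internal caller in this commit is migrated from assigning a single value to unpacking two, with "_" used where the flag is not needed. A trivial sketch of the call-site pattern (DummyRunner stands in for the real class):

    class DummyRunner:
        def forward(self, batch):
            logits = [0.1, 0.2, 0.7]     # placeholder output
            can_run_cuda_graph = False   # placeholder flag
            return logits, can_run_cuda_graph


    runner = DummyRunner()

    # Before: logits_output = runner.forward(batch)
    # After:  callers that ignore the flag unpack it into "_".
    logits_output, _ = runner.forward(batch=None)
    print(logits_output)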
python/sglang/srt/server_args.py

@@ -1086,7 +1086,7 @@ class ServerArgs:
             "--cuda-graph-max-bs",
             type=int,
             default=ServerArgs.cuda_graph_max_bs,
-            help="Set the maximum batch size for cuda graph.",
+            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
         )
         parser.add_argument(
             "--cuda-graph-bs",
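With the updated help text, --cuda-graph-max-bs is no longer just an upper clamp; it actively extends the set of captured batch sizes. A hedged usage sketch through the offline engine, which forwards ServerArgs fields as keyword arguments (the model path is a placeholder and 512 is only an example value):

    import sglang as sgl

    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        cuda_graph_max_bs=512,  # capture CUDA graphs for batch sizes up to 512
    )
    print(llm.generate("Hello", {"max_new_tokens": 8}))
    llm.shutdown()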
python/sglang/srt/speculative/eagle_worker.py

@@ -251,8 +251,8 @@ class EAGLEWorker(TpModelWorker):
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
-            logits_output, verify_output, model_worker_batch = self.verify(
-                batch, spec_info
+            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
+                self.verify(batch, spec_info)
             )

             # If it is None, it means all requests are finished
@@ -264,21 +264,22 @@ class EAGLEWorker(TpModelWorker):
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph,
             )
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(model_worker_batch)
             )

-            return logits_output, next_token_ids, model_worker_batch.bid, 0
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids
                 )
-            return logits_output, next_token_ids, bid, 0
+            return logits_output, next_token_ids, bid, 0, False

     def forward_target_extend(
         self, batch: ScheduleBatch
@@ -297,7 +298,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
         return logits_output, next_token_ids, model_worker_batch.bid
@@ -478,8 +479,10 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
-        logits_output, _ = self.target_worker.forward_batch_generation(
-            model_worker_batch, skip_sample=True
+        logits_output, _, can_run_cuda_graph = (
+            self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
         )
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
@@ -504,7 +507,7 @@ class EAGLEWorker(TpModelWorker):
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)

-        return logits_output, res, model_worker_batch
+        return logits_output, res, model_worker_batch, can_run_cuda_graph

     def add_logprob_values(
         self,
@@ -590,7 +593,7 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch, self.draft_model_runner
         )
         forward_batch.return_logprob = False
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
@@ -617,7 +620,7 @@ class EAGLEWorker(TpModelWorker):
         )

         # Run
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         self.capture_for_decode(logits_output, forward_batch.spec_info)
python/sglang/test/test_utils.py

@@ -395,12 +395,12 @@ def popen_launch_server(
     other_args: list[str] = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]

-    if pd_seperated:
+    if pd_separated:
         command = "sglang.launch_pd_server"
     else:
         command = "sglang.launch_server"
@@ -414,7 +414,7 @@ def popen_launch_server(
         *[str(x) for x in other_args],
     ]

-    if pd_seperated:
+    if pd_separated:
         command.extend(
             [
                 "--lb-host",
@@ -656,7 +656,7 @@ def get_benchmark_args(
     disable_stream=False,
     disable_ignore_eos=False,
     seed: int = 0,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
 ):
     return SimpleNamespace(
         backend="sglang",
@@ -686,7 +686,7 @@ def get_benchmark_args(
         profile=None,
         lora_name=None,
         prompt_suffix="",
-        pd_seperated=pd_seperated,
+        pd_separated=pd_separated,
     )
@@ -750,7 +750,7 @@ def run_bench_serving_multi(
     other_server_args,
     benchmark_args,
     need_warmup=False,
-    pd_seperated=False,
+    pd_separated=False,
 ):
     # Launch the server
     process = popen_launch_server(
@@ -758,7 +758,7 @@ def run_bench_serving_multi(
         base_url,
         timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
         other_args=other_server_args,
-        pd_seperated=pd_seperated,
+        pd_separated=pd_separated,
     )

     # run benchmark for all
test/srt/run_suite.py

@@ -101,8 +101,8 @@ suites = {
         # TestFile("test_deepep_intranode.py", 50),
         # TestFile("test_deepep_low_latency.py", 50),
         # TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
-        # TestFile("test_disaggregation.py", 90),
         TestFile("test_local_attn.py", 250),
+        TestFile("test_disaggregation.py", 90),
         TestFile("test_full_deepseek_v3.py", 250),
         TestFile("test_pp_single_node.py", 150),
     ],