change / sglang · Commit fba8eccd (unverified)

Authored May 12, 2025 by Lianmin Zheng; committed by GitHub on May 12, 2025.

Log if cuda graph is used & extend cuda graph capture to cuda-graph-max-bs (#6201)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>

Parent: 7d3a3d45
Changes: 27 files in total. This page shows 20 changed files with 265 additions and 108 deletions (+265 −108); the remaining files are on the next page.
Changed files on this page:

python/sglang/bench_offline_throughput.py  +3 −1
python/sglang/bench_one_batch.py  +2 −2
python/sglang/bench_one_batch_server.py  +143 −15
python/sglang/bench_serving.py  +7 −5
python/sglang/srt/constrained/base_grammar_backend.py  +6 −0
python/sglang/srt/disaggregation/prefill.py  +1 −3
python/sglang/srt/entrypoints/engine.py  +1 −1
python/sglang/srt/entrypoints/http_server.py  +1 −1
python/sglang/srt/layers/attention/utils.py  +4 −2
python/sglang/srt/managers/scheduler.py  +18 −10
python/sglang/srt/managers/scheduler_output_processor_mixin.py  +7 −11
python/sglang/srt/managers/tokenizer_manager.py  +4 −3
python/sglang/srt/managers/tp_worker.py  +12 −9
python/sglang/srt/managers/tp_worker_overlap_thread.py  +16 −8
python/sglang/srt/model_executor/cuda_graph_runner.py  +8 −9
python/sglang/srt/model_executor/model_runner.py  +8 −7
python/sglang/srt/server_args.py  +1 −1
python/sglang/srt/speculative/eagle_worker.py  +15 −12
python/sglang/test/test_utils.py  +7 −7
test/srt/run_suite.py  +1 −1
python/sglang/bench_offline_throughput.py

@@ -259,7 +259,9 @@ def throughput_test_once(
             measurement_results["total_input_tokens"]
             + measurement_results["total_output_tokens"]
         )
         / latency
-    measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
+    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+        "last_gen_throughput"
+    ]
     return measurement_results
python/sglang/bench_one_batch.py

@@ -246,7 +246,7 @@ def extend(reqs, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch

@@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
python/sglang/bench_one_batch_server.py

@@ -25,6 +25,7 @@ import requests
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import is_in_ci, write_github_step_summary


 @dataclasses.dataclass

@@ -33,9 +34,13 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    temperature: float = 0.0
+    return_logprob: bool = False
+    input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
     skip_warmup: bool = False
+    show_report: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):

@@ -49,11 +54,19 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--input-len-step-percentage",
+            type=float,
+            default=BenchArgs.input_len_step_percentage,
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
         parser.add_argument("--skip-warmup", action="store_true")
+        parser.add_argument("--show-report", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):

@@ -99,36 +112,89 @@ def run_one_case(
     batch_size: int,
     input_len: int,
     output_len: int,
+    temperature: float,
+    return_logprob: bool,
+    input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
 ):
+    requests.post(url + "/flush_cache")
+    input_lens = [
+        int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
+        for i in range(batch_size)
+    ]
     input_ids = [
-        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
-        for _ in range(batch_size)
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
+        for i in range(batch_size)
     ]
+    use_structured_outputs = False
+    if use_structured_outputs:
+        texts = []
+        for _ in range(batch_size):
+            texts.append(
+                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
+                * 50
+                + "Assistant:"
+            )
+        json_schema = "$$ANY$$"
+    else:
+        json_schema = None

     tic = time.time()
     response = requests.post(
         url + "/generate",
         json={
+            # "text": texts,
             "input_ids": input_ids,
             "sampling_params": {
-                "temperature": 0,
+                "temperature": temperature,
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
+                "json_schema": json_schema,
             },
+            "return_logprob": return_logprob,
+            "stream": True,
         },
+        stream=True,
     )
-    latency = time.time() - tic
-    _ = response.json()
-    output_throughput = batch_size * output_len / latency
+
+    # The TTFT of the last request in the batch
+    ttft = 0.0
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
+            if "error" in data:
+                raise RuntimeError(f"Request has failed. {data}.")
+
+            assert (
+                data["meta_info"]["finish_reason"] is None
+                or data["meta_info"]["finish_reason"]["type"] == "length"
+            )
+            if data["meta_info"]["completion_tokens"] == 1:
+                ttft = time.time() - tic
+
+    latency = time.time() - tic
+    input_throughput = batch_size * input_len / ttft
+    output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency

+    server_info = requests.get(url + "/get_server_info").json()
+    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
+    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
+
     print(f"batch size: {batch_size}")
+    print(f"input_len: {input_len}")
+    print(f"output_len: {output_len}")
     print(f"latency: {latency:.2f} s")
-    print(f"output throughput: {output_throughput:.2f} token/s")
-    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+    print(f"ttft: {ttft:.2f} s")
+    print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
+    print(f"Input throughput: {input_throughput:.2f} tok/s")
+    if output_len != 1:
+        print(f"output throughput: {output_throughput:.2f} tok/s")

     if result_filename:
         with open(result_filename, "a") as fout:

@@ -140,9 +206,21 @@ def run_one_case(
                 "latency": round(latency, 4),
                 "output_throughput": round(output_throughput, 2),
                 "overall_throughput": round(overall_throughput, 2),
+                "last_gen_throughput": round(last_gen_throughput, 2),
             }
             fout.write(json.dumps(res) + "\n")

+    return (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    )
+

 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     if bench_args.base_url:

@@ -152,27 +230,38 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     # warmup
     if not bench_args.skip_warmup:
+        print("=" * 8 + " Warmup Begin " + "=" * 8)
         run_one_case(
             base_url,
             batch_size=16,
             input_len=1024,
             output_len=16,
+            temperature=bench_args.temperature,
+            return_logprob=bench_args.return_logprob,
+            input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
         )
+        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

     # benchmark
+    result = []
     try:
         for bs, il, ol in itertools.product(
             bench_args.batch_size, bench_args.input_len, bench_args.output_len
         ):
-            run_one_case(
-                base_url,
-                bs,
-                il,
-                ol,
-                bench_args.run_name,
-                bench_args.result_filename,
-            )
+            result.append(
+                run_one_case(
+                    base_url,
+                    bs,
+                    il,
+                    ol,
+                    temperature=bench_args.temperature,
+                    return_logprob=bench_args.return_logprob,
+                    input_len_step_percentage=bench_args.input_len_step_percentage,
+                    run_name=bench_args.run_name,
+                    result_filename=bench_args.result_filename,
+                )
+            )
     finally:
         if proc:

@@ -180,6 +269,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     print(f"\nResults are saved to {bench_args.result_filename}")

+    if not bench_args.show_report:
+        return
+
+    summary = " | batch size | latency (s) | input throughput (tok/s)  | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
+    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+
+    for (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    ) in result:
+        hourly_cost = 2 * server_args.tp_size  # $2/hour for one H100
+        input_util = 0.7
+        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
+        line = (
+            f"| {batch_size} | "
+            f"{latency:.2f} | "
+            f"{input_throughput:.2f} | "
+            f"{output_throughput:.2f} | "
+            f"{accept_length} | "
+            f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
+            f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
+            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
+        )
+        summary += line
+
+    # print metrics table
+    print(summary)
+
+    if is_in_ci():
+        write_github_step_summary(
+            f"### Test Nightly Benchmark (bench_one_batch)\n{summary}"
+        )
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
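
The benchmark now measures TTFT by streaming the response and timestamping the first generated token, then splits input and output throughput around that point. The following is a minimal sketch of the same measurement pattern, not the full benchmark; it assumes a locally running sglang server and a prepared request payload (url and payload are placeholders).

import json
import time

import requests


def measure_ttft_and_latency(url: str, payload: dict):
    # Stream the response so the first completion token can be timestamped separately.
    tic = time.time()
    response = requests.post(
        url + "/generate", json={**payload, "stream": True}, stream=True
    )
    ttft = 0.0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            # The first completion token marks time-to-first-token.
            if data["meta_info"]["completion_tokens"] == 1:
                ttft = time.time() - tic
    latency = time.time() - tic
    return ttft, latency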
python/sglang/bench_serving.py

@@ -1103,7 +1103,7 @@ async def benchmark(
     lora_names: List[str],
     extra_request_body: Dict[str, Any],
     profile: bool,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
 ):

@@ -1239,12 +1239,14 @@ async def benchmark(
     if "sglang" in backend:
         server_info = requests.get(base_url + "/get_server_info")
-        if pd_seperated:
-            accept_length = server_info.json()["decode"][0].get(
+        if pd_separated:
+            accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
                 "avg_spec_accept_length", None
             )
         else:
-            accept_length = server_info.json().get("avg_spec_accept_length", None)
+            accept_length = server_info.json()["internal_states"][0].get(
+                "avg_spec_accept_length", None
+            )
     else:
         accept_length = None

@@ -1541,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace):
             lora_names=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
-            pd_seperated=args.pd_seperated,
+            pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
         )
     )
python/sglang/srt/constrained/base_grammar_backend.py

@@ -37,6 +37,12 @@ class BaseGrammarObject:
         """
         raise NotImplementedError()

+    def rollback(self, k: int):
+        raise NotImplementedError()
+
+    def is_terminated(self):
+        raise NotImplementedError()
+
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
python/sglang/srt/disaggregation/prefill.py

@@ -277,19 +277,17 @@ class SchedulerDisaggregationPrefillMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
-            bid,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
-            result.bid,
         )

         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
+            _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
python/sglang/srt/entrypoints/engine.py

@@ -330,7 +330,7 @@ class Engine(EngineBase):
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
-            **internal_states,
+            "internal_states": internal_states,
             "version": __version__,
         }
python/sglang/srt/entrypoints/http_server.py

@@ -222,7 +222,7 @@ async def get_server_info():
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
-        **internal_states,
+        "internal_states": internal_states,
         "version": __version__,
     }
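
With internal_states now nested under its own key instead of being spread into the top-level response, callers index into it explicitly. This is a small sketch of how the benchmark scripts in this same commit read the new layout (base_url is a placeholder for a running sglang server).

import requests

base_url = "http://localhost:30000"  # placeholder
server_info = requests.get(base_url + "/get_server_info").json()
# internal_states is a list with one entry per DP rank; rank 0 is used here.
last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
accept_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)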
python/sglang/srt/layers/attention/utils.py

@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
         mask = offset < kv_end - kv_start
         data = tl.load(
             req_to_token_ptr

@@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
     num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_pages_loop):
+        # index into req_to_token_ptr needs to be int64
         paged_offset = (
-            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+            tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
         ) * PAGED_SIZE
         paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
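
The cast to tl.int64 guards against 32-bit overflow when the block offset is folded into a flat index over a large req_to_token table. A hedged plain-Python illustration of the failure mode follows; the stride and row values are made-up numbers chosen only to exceed the int32 range, not values taken from sglang.

# Hypothetical row stride and request index for a large req_to_token pool.
stride, row = 1 << 20, 3000
flat_index = stride * row            # 3_145_728_000
print(flat_index > 2**31 - 1)        # True: the flat index does not fit in int32,
                                     # so the Triton arange must be cast to int64
                                     # before it is combined into the pointer offset.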
python/sglang/srt/managers/scheduler.py

@@ -160,6 +160,7 @@ class GenerationBatchResult:
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
+    can_run_cuda_graph: bool


 @dataclass

@@ -323,13 +324,14 @@ class Scheduler(
         set_random_seed(self.random_seed)

         # Print debug info
-        logger.info(
-            f"max_total_num_tokens={self.max_total_num_tokens}, "
-            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
-            f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"max_running_requests={self.max_running_requests}, "
-            f"context_len={self.model_config.context_len}"
-        )
+        if tp_rank == 0:
+            logger.info(
+                f"max_total_num_tokens={self.max_total_num_tokens}, "
+                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
+                f"max_prefill_tokens={self.max_prefill_tokens}, "
+                f"max_running_requests={self.max_running_requests}, "
+                f"context_len={self.model_config.context_len}"
+            )

         # Init memory pool and cache
         self.init_memory_pool_and_cache()

@@ -752,6 +754,7 @@ class Scheduler(
                 extend_input_len_per_req=None,
                 extend_logprob_start_len_per_req=None,
                 bid=bids[next_mb_id],
+                can_run_cuda_graph=result.can_run_cuda_graph,
             )
             self.process_batch_result(mbs[next_mb_id], output_result)
             last_mbs[next_mb_id] = mbs[next_mb_id]

@@ -1159,7 +1162,9 @@ class Scheduler(
             self.metrics_collector.log_stats(self.stats)

-    def log_decode_stats(self, running_batch=None):
+    def log_decode_stats(
+        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+    ):
         batch = running_batch or self.running_batch

         gap_latency = time.time() - self.last_decode_stats_tic

@@ -1199,6 +1204,7 @@ class Scheduler(
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

         msg += (
+            f"cuda graph: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}"
         )

@@ -1524,11 +1530,11 @@ class Scheduler(
         if self.spec_algorithm.is_none():
             model_worker_batch = batch.get_model_worker_batch()
             if self.pp_group.is_last_rank:
-                logits_output, next_token_ids = (
+                logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             else:
-                pp_hidden_states_proxy_tensors, _ = (
+                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             bid = model_worker_batch.bid

@@ -1538,6 +1544,7 @@ class Scheduler(
                 next_token_ids,
                 bid,
                 num_accepted_tokens,
+                can_run_cuda_graph,
             ) = self.draft_worker.forward_batch_speculative_generation(batch)
             self.spec_num_total_accepted_tokens += (
                 num_accepted_tokens + batch.batch_size()

@@ -1571,6 +1578,7 @@ class Scheduler(
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
+                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
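
The decode log line now reports whether the last forward ran from a captured CUDA graph. A rough sketch of the resulting message format, with purely illustrative values:

can_run_cuda_graph = True
last_gen_throughput = 1234.56   # illustrative number, not a measured value
num_queue_reqs = 0

msg = (
    f"cuda graph: {can_run_cuda_graph}, "
    f"gen throughput (token/s): {last_gen_throughput:.2f}, "
    f"#queue-req: {num_queue_reqs}"
)
print(msg)  # -> cuda graph: True, gen throughput (token/s): 1234.56, #queue-req: 0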
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -38,20 +38,16 @@ class SchedulerOutputProcessorMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
-            bid,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
-            result.bid,
         )

         if self.enable_overlap:
-            logits_output, next_token_ids = (
-                self.tp_worker.resolve_last_batch_result(
-                    launch_done,
-                )
-            )
+            logits_output, next_token_ids, _ = (
+                self.tp_worker.resolve_last_batch_result(launch_done)
+            )
         else:
             # Move next_token_ids and logprobs to cpu

@@ -189,16 +185,16 @@ class SchedulerOutputProcessorMixin:
         result: GenerationBatchResult,
         launch_done: Optional[threading.Event] = None,
     ):
-        logits_output, next_token_ids, bid = (
+        logits_output, next_token_ids, can_run_cuda_graph = (
             result.logits_output,
             result.next_token_ids,
-            result.bid,
+            result.can_run_cuda_graph,
         )
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
-                launch_done
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.tp_worker.resolve_last_batch_result(launch_done)
             )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():

@@ -280,7 +276,7 @@ class SchedulerOutputProcessorMixin:
             self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
-            self.log_decode_stats(running_batch=batch)
+            self.log_decode_stats(can_run_cuda_graph, running_batch=batch)

     def add_input_logprob_return_values(
         self: Scheduler,
python/sglang/srt/managers/tokenizer_manager.py

@@ -923,12 +923,13 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)

-    async def get_internal_state(self) -> Dict[Any, Any]:
+    async def get_internal_state(self) -> List[Dict[Any, Any]]:
         req = GetInternalStateReq()
-        res: List[GetInternalStateReqOutput] = (
+        responses: List[GetInternalStateReqOutput] = (
             await self.get_internal_state_communicator(req)
         )
-        return res[0].internal_state
+        # Many DP ranks
+        return [res.internal_state for res in responses]

     def get_log_request_metadata(self):
         max_length = None
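
Because get_internal_state now returns one dict per data-parallel rank, callers that previously expected a single dict take the first element. A hedged usage sketch; tokenizer_manager is a placeholder for an initialized TokenizerManager, and the "last_gen_throughput" key is the one read elsewhere in this commit.

async def read_last_gen_throughput(tokenizer_manager) -> float:
    # get_internal_state now returns a list with one dict per DP rank.
    states = await tokenizer_manager.get_internal_state()
    return states[0]["last_gen_throughput"]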
python/sglang/srt/managers/tp_worker.py

@@ -20,7 +20,7 @@ from typing import Optional, Tuple, Union
 import torch

 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
+from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,

@@ -183,8 +183,11 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
+        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Tuple[
+        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
+    ]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

         pp_proxy_tensors = None

@@ -196,11 +199,11 @@ class TpModelWorker:
             )

         if self.pp_group.is_last_rank:
-            logits_output = self.model_runner.forward(
+            logits_output, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch, pp_proxy_tensors=pp_proxy_tensors
             )
-            if model_worker_batch.launch_done is not None:
-                model_worker_batch.launch_done.set()
+            if launch_done is not None:
+                launch_done.set()

             if skip_sample:
                 next_token_ids = None

@@ -209,17 +212,17 @@ class TpModelWorker:
                     logits_output, model_worker_batch
                 )

-            return logits_output, next_token_ids
+            return logits_output, next_token_ids, can_run_cuda_graph
         else:
-            pp_proxy_tensors = self.model_runner.forward(
+            pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None
+            return pp_proxy_tensors.tensors, None, can_run_cuda_graph

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
+        logits_output, _ = self.model_runner.forward(forward_batch)
         embeddings = logits_output.embeddings
         return embeddings
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -18,7 +18,7 @@ import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional
+from typing import Optional, Tuple

 import psutil
 import torch

@@ -145,8 +145,10 @@ class TpModelWorkerClient:
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

             # Run forward
-            logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.worker.forward_batch_generation(
+                    model_worker_batch, model_worker_batch.launch_done
+                )
             )

             # Update the future token ids map

@@ -171,14 +173,18 @@ class TpModelWorkerClient:
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()

-            self.output_queue.put((copy_done, logits_output, next_token_ids))
+            self.output_queue.put(
+                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
+            )

     def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
         """
         This function is called to resolve the last batch result and
         wait for the current batch to be launched. Used in overlap mode.
         """
-        copy_done, logits_output, next_token_ids = self.output_queue.get()
+        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
+            self.output_queue.get()
+        )

         if launch_done is not None:
             launch_done.wait()

@@ -193,9 +199,11 @@ class TpModelWorkerClient:
                 logits_output.input_token_logprobs.tolist()
             )
         next_token_ids = next_token_ids.tolist()
-        return logits_output, next_token_ids
+        return logits_output, next_token_ids, can_run_cuda_graph

-    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+    def forward_batch_generation(
+        self, model_worker_batch: ModelWorkerBatch
+    ) -> Tuple[None, torch.Tensor, bool]:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()

@@ -223,7 +231,7 @@ class TpModelWorkerClient:
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        return None, future_next_token_ids
+        return None, future_next_token_ids, False

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
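
In overlap mode the forward thread now enqueues the CUDA-graph flag alongside the copy event and batch outputs, and resolve_last_batch_result hands it back as a third value. The following is a condensed sketch of that producer/consumer shape only; worker, model_worker_batch, and copy_done are placeholders for the real objects, and the queue here stands in for the client's output_queue.

from queue import Queue

output_queue: Queue = Queue()


def forward_thread_step(worker, model_worker_batch, copy_done):
    # Producer: run the forward pass and push the flag through the queue
    # together with the outputs that the scheduler thread will consume.
    logits_output, next_token_ids, can_run_cuda_graph = (
        worker.forward_batch_generation(
            model_worker_batch, model_worker_batch.launch_done
        )
    )
    output_queue.put((copy_done, logits_output, next_token_ids, can_run_cuda_graph))


def resolve_last_batch_result():
    # Consumer: unpack the flag with the batch outputs and return all three.
    copy_done, logits_output, next_token_ids, can_run_cuda_graph = output_queue.get()
    return logits_output, next_token_ids, can_run_cuda_graph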
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -19,7 +19,7 @@ import bisect
 import inspect
 import os
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Optional, Union

 import torch
 import tqdm

@@ -40,15 +40,12 @@ from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
-    is_hip,
     rank0_log,
 )

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

-_is_hip = is_hip()
-

 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():

@@ -137,7 +134,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         )

     gpu_mem = get_device_memory_capacity()
-    # Batch size of each rank will not become so large when DP is on
     if gpu_mem is not None and gpu_mem > 96 * 1024:
         capture_bs += list(range(160, 257, 8))

@@ -148,12 +144,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             model_runner.req_to_token_pool.size
         ]

-    capture_bs = list(sorted(set(capture_bs)))
-    assert len(capture_bs) > 0 and capture_bs[0] > 0
-    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
+        if max(capture_bs) < server_args.cuda_graph_max_bs:
+            capture_bs += list(
+                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
+            )
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    capture_bs = list(sorted(set(capture_bs)))
+    assert len(capture_bs) > 0 and capture_bs[0] > 0
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
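
The key behavioral change here is that when --cuda-graph-max-bs exceeds the largest default capture size, the capture list is extended up to that value in steps of 16 before the final filter, dedup, and sort. The helper below is a standalone sketch of that selection logic, mirroring the diff; default_capture_bs, pool_size, and cuda_graph_max_bs are illustrative inputs, not sglang defaults.

def get_capture_batch_sizes(default_capture_bs, pool_size, cuda_graph_max_bs=None):
    capture_bs = list(default_capture_bs)
    if cuda_graph_max_bs:
        # Drop sizes above the cap, then extend toward the cap in steps of 16.
        capture_bs = [bs for bs in capture_bs if bs <= cuda_graph_max_bs]
        if max(capture_bs) < cuda_graph_max_bs:
            capture_bs += list(range(max(capture_bs), cuda_graph_max_bs + 1, 16))
    # Never capture a batch larger than the request-to-token pool.
    capture_bs = [bs for bs in capture_bs if bs <= pool_size]
    capture_bs = sorted(set(capture_bs))
    assert capture_bs and capture_bs[0] > 0
    return capture_bs


# Example: defaults topping out at 160, but --cuda-graph-max-bs 256 requested.
print(get_capture_batch_sizes([1, 2, 4, 8, 16, 32, 64, 128, 160], 4096, 256))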
python/sglang/srt/model_executor/model_runner.py

@@ -1085,32 +1085,33 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
-    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
-        if (
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
+        can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
-        ):
-            return self.cuda_graph_runner.replay(
+        )
+        if can_run_cuda_graph:
+            ret = self.cuda_graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-
-        if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+        elif forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(
+            ret = self.forward_extend(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

+        return ret, can_run_cuda_graph
+
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
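
ModelRunner.forward now returns the result together with a boolean saying whether a captured CUDA graph was replayed, which is why every caller in this commit unpacks a pair (often as "logits_output, _ = ..."). A minimal caller-side sketch, with model_runner and forward_batch as placeholders for the real objects:

def run_one_forward(model_runner, forward_batch):
    # forward() now returns (output, can_run_cuda_graph) instead of just the output.
    logits_output, can_run_cuda_graph = model_runner.forward(forward_batch)
    if not can_run_cuda_graph:
        # The batch fell back to the eager decode/extend/idle path.
        pass
    return logits_output, can_run_cuda_graph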
python/sglang/srt/server_args.py

@@ -1086,7 +1086,7 @@ class ServerArgs:
             "--cuda-graph-max-bs",
             type=int,
             default=ServerArgs.cuda_graph_max_bs,
-            help="Set the maximum batch size for cuda graph.",
+            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
         )
         parser.add_argument(
             "--cuda-graph-bs",
python/sglang/srt/speculative/eagle_worker.py

@@ -251,8 +251,8 @@ class EAGLEWorker(TpModelWorker):
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
-            logits_output, verify_output, model_worker_batch = self.verify(
-                batch, spec_info
+            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
+                self.verify(batch, spec_info)
             )

             # If it is None, it means all requests are finished

@@ -264,21 +264,22 @@ class EAGLEWorker(TpModelWorker):
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph,
             )
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(model_worker_batch)
             )

-            return logits_output, next_token_ids, model_worker_batch.bid, 0
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids
                 )
-            return logits_output, next_token_ids, bid, 0
+            return logits_output, next_token_ids, bid, 0, False

     def forward_target_extend(
         self, batch: ScheduleBatch

@@ -297,7 +298,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
         return logits_output, next_token_ids, model_worker_batch.bid

@@ -478,8 +479,10 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
-        logits_output, _ = self.target_worker.forward_batch_generation(
-            model_worker_batch, skip_sample=True
+        logits_output, _, can_run_cuda_graph = (
+            self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
         )
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states

@@ -504,7 +507,7 @@ class EAGLEWorker(TpModelWorker):
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)

-        return logits_output, res, model_worker_batch
+        return logits_output, res, model_worker_batch, can_run_cuda_graph

     def add_logprob_values(
         self,

@@ -590,7 +593,7 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch, self.draft_model_runner
         )
         forward_batch.return_logprob = False
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info

@@ -617,7 +620,7 @@ class EAGLEWorker(TpModelWorker):
         )

         # Run
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         self.capture_for_decode(logits_output, forward_batch.spec_info)
python/sglang/test/test_utils.py

@@ -395,12 +395,12 @@ def popen_launch_server(
     other_args: list[str] = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]

-    if pd_seperated:
+    if pd_separated:
         command = "sglang.launch_pd_server"
     else:
         command = "sglang.launch_server"

@@ -414,7 +414,7 @@ def popen_launch_server(
         *[str(x) for x in other_args],
     ]

-    if pd_seperated:
+    if pd_separated:
         command.extend(
             [
                 "--lb-host",

@@ -656,7 +656,7 @@ def get_benchmark_args(
     disable_stream=False,
     disable_ignore_eos=False,
     seed: int = 0,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
 ):
     return SimpleNamespace(
         backend="sglang",

@@ -686,7 +686,7 @@ def get_benchmark_args(
         profile=None,
         lora_name=None,
         prompt_suffix="",
-        pd_seperated=pd_seperated,
+        pd_separated=pd_separated,
     )

@@ -750,7 +750,7 @@ def run_bench_serving_multi(
     other_server_args,
     benchmark_args,
     need_warmup=False,
-    pd_seperated=False,
+    pd_separated=False,
 ):
     # Launch the server
     process = popen_launch_server(

@@ -758,7 +758,7 @@ def run_bench_serving_multi(
         base_url,
         timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
         other_args=other_server_args,
-        pd_seperated=pd_seperated,
+        pd_separated=pd_separated,
     )

     # run benchmark for all
test/srt/run_suite.py

@@ -101,8 +101,8 @@ suites = {
         # TestFile("test_deepep_intranode.py", 50),
         # TestFile("test_deepep_low_latency.py", 50),
         # TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
-        # TestFile("test_disaggregation.py", 90),
         TestFile("test_local_attn.py", 250),
+        TestFile("test_disaggregation.py", 90),
         TestFile("test_full_deepseek_v3.py", 250),
         TestFile("test_pp_single_node.py", 150),
     ],