sglang / Commits / 287d07a6
Commit 287d07a6
Misc fixes for eagle (flush_cache, CPU overhead) (#3014)
Authored Jan 20, 2025 by Lianmin Zheng
Parent: d2571dd5

Showing 11 changed files with 132 additions and 95 deletions
python/sglang/bench_offline_throughput.py                  +17  -11
python/sglang/bench_serving.py                             +47  -44
python/sglang/srt/managers/scheduler.py                    +10   -1
python/sglang/srt/model_executor/forward_batch_info.py      +3   -5
python/sglang/srt/server.py                                 +2   -2
python/sglang/srt/speculative/eagle_utils.py               +28  -15
python/sglang/srt/speculative/eagle_worker.py              +11  -13
python/sglang/srt/utils.py                                  +7   -0
python/sglang/test/test_programs.py                         +2   -1
python/sglang/test/test_utils.py                            +4   -3
test/lang/test_srt_backend.py                               +1   -0
python/sglang/bench_offline_throughput.py

@@ -49,12 +49,13 @@ class BenchArgs:
     gsp_system_prompt_len: int = 2048
     gsp_question_len: int = 128
     gsp_output_len: int = 256
-    seed: int = 1
     disable_ignore_eos: bool = False
     extra_request_body: Optional[str] = None
+    seed: int = 1
+    apply_chat_template: bool = False
+    profile: bool = False
     skip_warmup: bool = False
     do_not_exit: bool = False
-    profile: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):

@@ -141,20 +142,31 @@ class BenchArgs:
             default=BenchArgs.gsp_output_len,
             help="Target length in tokens for outputs in generated-shared-prefix dataset",
         )
-        parser.add_argument("--seed", type=int, default=1, help="The random seed.")
         parser.add_argument(
             "--disable-ignore-eos",
-            type=bool,
-            default=BenchArgs.disable_ignore_eos,
+            action="store_true",
             help="Disable ignore EOS token",
         )
         parser.add_argument(
             "--extra-request-body",
             metavar='{"key1": "value1", "key2": "value2"}',
             type=str,
+            default=BenchArgs.extra_request_body,
             help="Append given JSON object to the request payload. You can use this to specify"
             "additional generate params like sampling params.",
         )
+        parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+        parser.add_argument(
+            "--apply-chat-template",
+            action="store_true",
+            help="Apply chat template",
+        )
+        parser.add_argument(
+            "--profile",
+            action="store_true",
+            help="Use Torch Profiler. The endpoint must be launched with "
+            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+        )
         parser.add_argument(
             "--skip-warmup",
             action="store_true",

@@ -165,12 +177,6 @@ class BenchArgs:
             action="store_true",
             help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
         )
-        parser.add_argument(
-            "--profile",
-            action="store_true",
-            help="Use Torch Profiler. The endpoint must be launched with "
-            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
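A note on the `--disable-ignore-eos` change above: `type=bool` does not do what it looks like in argparse, because the parser passes the raw string through `bool()`, and any non-empty string (including "False") is truthy. The commit switches the flag to `action="store_true"`. The following standalone sketch (hypothetical flag names, not part of the commit) shows the difference:

import argparse

parser = argparse.ArgumentParser()
# Buggy pattern: bool("False") is True, so the option can never be switched
# off from the command line the way a user would expect.
parser.add_argument("--buggy-flag", type=bool, default=False)
# Idiomatic pattern: absent -> False, present -> True.
parser.add_argument("--fixed-flag", action="store_true")

args = parser.parse_args(["--buggy-flag", "False"])
print(args.buggy_flag)  # True, because bool("False") is True
print(args.fixed_flag)  # False, the flag was not passed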
python/sglang/bench_serving.py

@@ -453,6 +453,7 @@ def get_dataset(args, tokenizer):
             tokenizer=tokenizer,
             fixed_output_len=args.sharegpt_output_len,
             context_len=args.sharegpt_context_len,
+            apply_chat_template=args.apply_chat_template,
         )
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(

@@ -517,6 +518,7 @@ class BenchmarkMetrics:
     median_e2e_latency_ms: float
     std_e2e_latency_ms: float
     p99_e2e_latency_ms: float
+    concurrency: float


 SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"

@@ -562,6 +564,7 @@ def sample_sharegpt_requests(
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
     context_len: Optional[int] = None,
+    apply_chat_template=False,
 ) -> List[Tuple[str, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")

@@ -592,6 +595,15 @@ def sample_sharegpt_requests(
         # Tokenize the prompts and completions.
         prompt = dataset[i][0]
+
+        if apply_chat_template:
+            prompt = tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+            prompt = prompt.replace(tokenizer.bos_token, "")
+
         prompt_token_ids = tokenizer.encode(prompt)
         completion = dataset[i][1]
         completion_token_ids = tokenizer.encode(completion)

@@ -600,7 +612,7 @@ def sample_sharegpt_requests(
             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
         )
-        if prompt_len < 1 or output_len < 1:
+        if prompt_len < 2 or output_len < 2:
             # Prune too short sequences.
             continue

@@ -880,6 +892,7 @@ def calculate_metrics(
         median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
         std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
         p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
+        concurrency=np.sum(e2e_latencies) / dur_s,
     )

     return metrics, output_lens

@@ -1031,6 +1044,7 @@ async def benchmark(
             "Total token throughput (tok/s):", metrics.total_throughput
         )
     )
+    print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
     print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
     print(
         "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)

@@ -1062,13 +1076,24 @@ async def benchmark(
         and metrics.output_throughput is not None
     ):
         result = {
+            # Arguments
             "backend": args.backend,
             "dataset_name": args.dataset_name,
             "request_rate": request_rate,
             "max_concurrency": max_concurrency,
+            "sharegpt_output_len": args.sharegpt_output_len,
+            "random_input_len": args.random_input_len,
+            "random_output_len": args.random_output_len,
+            "random_range_ratio": args.random_range_ratio,
+            # Results
+            "duration": benchmark_duration,
+            "completed": metrics.completed,
             "total_input_tokens": metrics.total_input,
             "total_output_tokens": metrics.total_output,
             "total_output_tokens_retokenized": metrics.total_output_retokenized,
+            "request_throughput": metrics.request_throughput,
+            "input_throughput": metrics.input_throughput,
+            "output_throughput": metrics.output_throughput,
             "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
             "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
             "std_e2e_latency_ms": metrics.std_e2e_latency_ms,

@@ -1085,14 +1110,7 @@ async def benchmark(
             "median_itl_ms": metrics.median_itl_ms,
             "std_itl_ms": metrics.std_itl_ms,
             "p99_itl_ms": metrics.p99_itl_ms,
-            "input_throughput": metrics.input_throughput,
-            "output_throughput": metrics.output_throughput,
-            "sharegpt_output_len": args.sharegpt_output_len,
-            "random_input_len": args.random_input_len,
-            "random_output_len": args.random_output_len,
-            "random_range_ratio": args.random_range_ratio,
-            "duration": benchmark_duration,
-            "completed": metrics.completed,
+            "concurrency": metrics.concurrency,
         }
     else:
         print(f"Error running benchmark for request rate: {request_rate}")

@@ -1112,36 +1130,16 @@ async def benchmark(
     with open(output_file_name, "a") as file:
         file.write(json.dumps(result) + "\n")

-    result = {
-        "duration": benchmark_duration,
-        "completed": metrics.completed,
-        "total_input_tokens": metrics.total_input,
-        "total_output_tokens": metrics.total_output,
-        "total_output_tokens_retokenized": metrics.total_output_retokenized,
-        "request_throughput": metrics.request_throughput,
-        "input_throughput": metrics.input_throughput,
-        "output_throughput": metrics.output_throughput,
-        "mean_ttft_ms": metrics.mean_ttft_ms,
-        "median_ttft_ms": metrics.median_ttft_ms,
-        "std_ttft_ms": metrics.std_ttft_ms,
-        "p99_ttft_ms": metrics.p99_ttft_ms,
-        "mean_tpot_ms": metrics.mean_tpot_ms,
-        "median_tpot_ms": metrics.median_tpot_ms,
-        "std_tpot_ms": metrics.std_tpot_ms,
-        "p99_tpot_ms": metrics.p99_tpot_ms,
-        "mean_itl_ms": metrics.mean_itl_ms,
-        "median_itl_ms": metrics.median_itl_ms,
-        "std_itl_ms": metrics.std_itl_ms,
-        "p99_itl_ms": metrics.p99_itl_ms,
-        "input_lens": [output.prompt_len for output in outputs],
-        "output_lens": output_lens,
-        "ttfts": [output.ttft for output in outputs],
-        "itls": [output.itl for output in outputs],
-        "generated_texts": [output.generated_text for output in outputs],
-        "errors": [output.error for output in outputs],
-        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
-        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
-    }
+    result.update(
+        {
+            "input_lens": [output.prompt_len for output in outputs],
+            "output_lens": output_lens,
+            "ttfts": [output.ttft for output in outputs],
+            "itls": [output.itl for output in outputs],
+            "generated_texts": [output.generated_text for output in outputs],
+            "errors": [output.error for output in outputs],
+        }
+    )
     return result

@@ -1422,7 +1420,6 @@ if __name__ == "__main__":
         "actual request rate may be lower than specified with --request-rate, "
         "if the server is not processing requests fast enough to keep up.",
     )
-    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
         "--multi",
         action="store_true",

@@ -1446,14 +1443,15 @@ if __name__ == "__main__":
         help="Disable streaming mode.",
     )
     parser.add_argument(
-        "--disable-ignore-eos",
+        "--return-logprob",
         action="store_true",
-        help="Disable ignoring EOS.",
+        help="Return logprob.",
     )
+    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     parser.add_argument(
-        "--return-logprob",
+        "--disable-ignore-eos",
         action="store_true",
-        help="Return logprob.",
+        help="Disable ignoring EOS.",
     )
     parser.add_argument(
         "--extra-request-body",

@@ -1462,6 +1460,11 @@ if __name__ == "__main__":
         help="Append given JSON object to the request payload. You can use this to specify"
         "additional generate params like sampling params.",
     )
+    parser.add_argument(
+        "--apply-chat-template",
+        action="store_true",
+        help="Apply chat template",
+    )
     parser.add_argument(
         "--profile",
         action="store_true",
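The new `concurrency` metric added above is the sum of per-request end-to-end latencies divided by the benchmark duration, which is the time-averaged number of requests in flight (Little's law). A standalone sketch with made-up numbers, not taken from the commit:

import numpy as np

# Hypothetical end-to-end latencies (seconds) of four completed requests.
e2e_latencies = np.array([2.0, 3.0, 4.0, 1.0])
dur_s = 5.0  # total benchmark wall-clock duration in seconds

# Total request-seconds divided by wall-clock seconds gives the average
# number of requests that were in flight at any instant.
concurrency = np.sum(e2e_latencies) / dur_s
print(concurrency)  # 2.0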
python/sglang/srt/managers/scheduler.py

@@ -1023,7 +1023,7 @@ class Scheduler:
         )

         # Check for jump-forward
-        if not self.disable_jump_forward:
+        if not self.disable_jump_forward and batch.has_grammar:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():

@@ -1564,6 +1564,15 @@ class Scheduler:
             self.grammar_backend.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
+
+            if not self.spec_algorithm.is_none():
+                self.draft_worker.model_runner.req_to_token_pool.clear()
+                self.draft_worker.model_runner.token_to_kv_pool.clear()
+
+            self.num_generated_tokens = 0
+            self.forward_ct_decode = 0
+            self.spec_num_total_accepted_tokens = 0
+            self.spec_num_total_forward_ct = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
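With speculative decoding enabled, flushing the cache must also clear the draft model's pools and reset the acceptance counters, otherwise the draft KV cache keeps stale entries and accept-length statistics carry over between runs. A minimal toy sketch of that pattern (simplified names, not the real Scheduler):

class _Pool:
    def __init__(self):
        self.entries = {0: "kv"}

    def clear(self):
        self.entries.clear()


class TinyScheduler:
    def __init__(self, target_pools, draft_pools=()):
        self.target_pools = list(target_pools)  # target model req/token pools
        self.draft_pools = list(draft_pools)    # draft model pools (speculative decoding)

    def flush_cache(self):
        # Clear both the target pools and the draft pools.
        for pool in self.target_pools + self.draft_pools:
            pool.clear()
        # Reset counters so the next run starts from clean statistics.
        self.num_generated_tokens = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0


sched = TinyScheduler([_Pool(), _Pool()], [_Pool(), _Pool()])
sched.flush_cache()
print(all(not p.entries for p in sched.target_pools + sched.draft_pools))  # True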
python/sglang/srt/model_executor/forward_batch_info.py

@@ -282,6 +282,9 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            attn_backend=model_runner.attn_backend,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,

@@ -336,11 +339,6 @@ class ForwardBatch:
         if model_runner.model_is_mrope:
             ret.compute_mrope_positions(model_runner, batch)

-        # Init attention information
-        ret.req_to_token_pool = model_runner.req_to_token_pool
-        ret.token_to_kv_pool = model_runner.token_to_kv_pool
-        ret.attn_backend = model_runner.attn_backend
-
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)
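The forward_batch_info.py change passes the memory pools and attention backend to the `ForwardBatch` constructor instead of assigning them on the instance afterwards, so the object comes out of `init_new` fully populated. A small dataclass sketch of the pattern (hypothetical field names, not the real class):

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class TinyForwardBatch:
    input_ids: Any
    req_to_token_pool: Optional[Any] = None
    token_to_kv_pool: Optional[Any] = None
    attn_backend: Optional[Any] = None

    @classmethod
    def init_new(cls, batch, model_runner):
        # Everything goes through the constructor; no post-hoc attribute writes.
        return cls(
            input_ids=batch.input_ids,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            attn_backend=model_runner.attn_backend,
        )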
python/sglang/srt/server.py

@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================

-# Some shortcuts for backward compatbility.
+# Some shortcuts for backward compatibility.
 # They will be removed in new versions.
 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server
python/sglang/srt/speculative/eagle_utils.py

@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
 class EAGLEDraftInput(SpecInfo):
     def __init__(self):
         self.prev_mode = ForwardMode.DECODE
-        self.sample_output = None

         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []

@@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo):
         self.cache_list: List[torch.Tenor] = []
         self.iter = 0

+        # shape: (b, hidden_size)
         self.hidden_states: torch.Tensor = None
+        # shape: (b,)
         self.verified_id: torch.Tensor = None
+        # shape: (b, vocab_size)
+        self.sample_output: torch.Tensor = None
         self.positions: torch.Tensor = None
         self.accept_length: torch.Tensor = None
-        self.has_finished: bool = False
+        self.accept_length_cpu: List[int] = None
+        self.unfinished_index: List[int] = None

     def load_server_args(self, server_args: ServerArgs):
         self.topk: int = server_args.speculative_eagle_topk

@@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo):
                 :pre_len
             ] = req.prefix_indices

-            batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )

@@ -295,7 +298,9 @@ class EAGLEDraftInput(SpecInfo):
         self.cache_list.append(batch.out_cache_loc)
         self.positions = (
             batch.seq_lens[:, None]
-            + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+            + torch.full(
+                [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
+            )
         ).flatten()

         bs = len(batch.seq_lens)

@@ -312,24 +317,25 @@ class EAGLEDraftInput(SpecInfo):
     def prepare_extend_after_decode(self, batch: ScheduleBatch):
         batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
-        batch.extend_lens = (self.accept_length + 1).tolist()
+        accept_length_cpu = batch.spec_info.accept_length_cpu
+        batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        seq_lens_cpu = batch.seq_lens.tolist()

         pt = 0
-        seq_lens = batch.seq_lens.tolist()
         i = 0

         for req in batch.reqs:
             if req.finished():
                 continue

             # assert seq_len - pre_len == req.extend_input_len
-            input_len = self.accept_length[i] + 1
-            seq_len = seq_lens[i]
+            input_len = batch.extend_lens[i]
+            seq_len = seq_lens_cpu[i]
             batch.req_to_token_pool.req_to_token[req.req_pool_idx][
                 seq_len - input_len : seq_len
             ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1

+        assert pt == batch.out_cache_loc.shape[0]
+
         self.positions = torch.empty_like(self.verified_id)
         new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)

@@ -345,7 +351,7 @@ class EAGLEDraftInput(SpecInfo):
             triton.next_power_of_2(self.spec_steps + 1),
         )
-        batch.seq_lens_sum = sum(batch.seq_lens)
+        batch.seq_lens_sum = sum(seq_lens_cpu)
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id

@@ -573,6 +579,8 @@ class EagleVerifyInput(SpecInfo):
         finished_extend_len = {}  # {rid:accept_length + 1}
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
+        has_finished = False
+
         # iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):

@@ -586,7 +594,7 @@ class EagleVerifyInput(SpecInfo):
                     finished_extend_len[req.rid] = j + 1
                     req.check_finished()
                     if req.finished():
-                        draft_input.has_finished = True
+                        has_finished = True
                         # set all tokens after finished token to -1 and break
                         accept_index[i, j + 1 :] = -1
                         break

@@ -600,7 +608,6 @@ class EagleVerifyInput(SpecInfo):
         accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
-        verified_id_cpu = verified_id.tolist()

         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False

@@ -622,7 +629,13 @@ class EagleVerifyInput(SpecInfo):
             draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
-            draft_input.unfinished_index = unfinished_index
+            draft_input.accept_length_cpu = [
+                accept_length_cpu[i] for i in unfinished_index
+            ]
+            if has_finished:
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
+            else:
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens

         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
         return (
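Several of the eagle_utils.py changes above address the CPU-overhead part of the commit title: CPU copies (`accept_length_cpu`, `seq_lens_cpu`) are made once and reused, instead of pulling values out of CUDA tensors inside per-request loops, where each access forces a device-to-host synchronization. A minimal sketch of the pattern with hypothetical values:

import torch

# Stand-in for a tensor of accepted draft lengths (would live on CUDA in practice).
accept_length = torch.tensor([3, 1, 2])

# Slower pattern: indexing the tensor per request inside a Python loop,
# e.g. `input_len = accept_length[i] + 1`, syncs with the device every iteration.

# Faster pattern: copy to host once, then work with plain Python ints.
accept_length_cpu = accept_length.tolist()
extend_lens = [x + 1 for x in accept_length_cpu]
seq_lens_sum = sum(extend_lens)  # pure-Python sum, no further GPU syncs
print(extend_lens, seq_lens_sum)  # [4, 2, 3] 9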
python/sglang/srt/speculative/eagle_worker.py

@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
+from sglang.srt.utils import rank0_print


 class EAGLEWorker(TpModelWorker):

@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):
     def forward_draft_decode(self, batch: ScheduleBatch):
         batch.spec_info.prepare_for_decode(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)

     def forward_draft_extend(self, batch: ScheduleBatch):
         self._set_mem_pool(batch, self.model_runner)
         batch.spec_info.prepare_for_extend(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
         self._set_mem_pool(batch, self.target_worker.model_runner)

@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
         batch.req_to_token_pool = runner.req_to_token_pool

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        seq_lens_backup = batch.seq_lens
+
         self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        if batch.spec_info.has_finished:
-            index = batch.spec_info.unfinished_index
-            seq_lens = batch.seq_lens
-            batch.seq_lens = batch.seq_lens[index]
-
         batch.spec_info.prepare_extend_after_decode(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
-
-        batch.spec_info.hidden_states = logits_output.hidden_states
         self.capture_for_decode(logits_output, forward_batch)
-        batch.forward_mode = ForwardMode.DECODE
-        if batch.spec_info.has_finished:
-            batch.seq_lens = seq_lens
         self._set_mem_pool(batch, self.target_worker.model_runner)

+        # Restore backup.
+        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.forward_mode = ForwardMode.DECODE
+        batch.seq_lens = seq_lens_backup
+
     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ):
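In eagle_worker.py, `forward_draft_extend_after_decode` now snapshots `batch.seq_lens` before the draft pass and restores it at the end, since `prepare_extend_after_decode` may overwrite it for the draft model. A generic sketch of the save/restore pattern (toy objects, not the real worker):

class ToyBatch:
    def __init__(self, seq_lens):
        self.seq_lens = seq_lens


def draft_extend_after_decode(batch, prepare_extend):
    seq_lens_backup = batch.seq_lens   # snapshot before the draft pass
    prepare_extend(batch)              # may overwrite batch.seq_lens for the draft model
    # ... the draft forward pass would run here ...
    batch.seq_lens = seq_lens_backup   # the target model sees the original lengths again


batch = ToyBatch([5, 7, 9])
draft_extend_after_decode(batch, lambda b: setattr(b, "seq_lens", [2, 3]))
print(batch.seq_lens)  # [5, 7, 9]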
python/sglang/srt/utils.py

@@ -1442,3 +1442,10 @@ def is_valid_ipv6_address(address: str) -> bool:
         return True
     except ValueError:
         return False
+
+
+def rank0_print(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        print(msg, flush=True)
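The new `rank0_print` helper prints only from tensor-parallel rank 0, which keeps multi-GPU logs from repeating the same line once per rank. A hedged usage sketch; it assumes it runs inside worker code where the tensor-parallel group is already initialized, otherwise `get_tensor_model_parallel_rank()` has nothing to query:

from sglang.srt.utils import rank0_print

# On an 8-GPU tensor-parallel run, this prints once instead of eight times.
rank0_print("eagle accept length looks healthy")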
python/sglang/test/test_programs.py

@@ -535,7 +535,8 @@ def test_hellaswag_select():
     # Compute accuracy
     accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
-    assert np.abs(accuracy_gen - accuracy) < 0.1
+    print(f"{accuracy=}, {accuracy_gen=}")
+    assert np.abs(accuracy_gen - accuracy) < 0.05
     assert np.abs(latency_gen - latency) < 1

     return accuracy, latency
python/sglang/test/test_utils.py

@@ -567,15 +567,16 @@ def run_bench_serving(
         random_range_ratio=0.0,
         request_rate=request_rate,
         multi=None,
-        seed=0,
         output_file=None,
         disable_tqdm=False,
         disable_stream=disable_stream,
-        disable_ignore_eos=False,
         return_logprob=False,
-        lora_name=None,
+        seed=0,
+        disable_ignore_eos=False,
         extra_request_body=None,
+        apply_chat_template=False,
         profile=None,
+        lora_name=None,
     )

     try:
test/lang/test_srt_backend.py

 """
 Usage:
 python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens
+python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select
 """

 import unittest