change / sglang · Commits · 04b262cd

Unverified commit 04b262cd, authored Oct 04, 2024 by Ying Sheng, committed by GitHub on Oct 04, 2024

[Fix] Fix major performance bug in certain cases (#1563)

Co-authored-by: hnyls2002 <hnyls2002@gmail.com>

parent 2432ad40
Showing 5 changed files with 50 additions and 18 deletions (+50, -18):

  .github/workflows/pr-test.yml             +6   -0
  python/sglang/bench_serving.py            +2   -2
  python/sglang/srt/managers/scheduler.py   +10  -8
  python/sglang/test/test_utils.py          +14  -5
  test/srt/test_bench_serving.py            +18  -3
.github/workflows/pr-test.yml

@@ -130,6 +130,12 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default
 
+      - name: Benchmark Offline Throughput (Non-streaming, small batch size)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
+
   performance-test-1-gpu-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
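For reference, a minimal sketch of running the newly added CI step by hand (assumptions: the working directory is test/srt inside an sglang checkout, with a GPU environment comparable to the 1-gpu-runner, since the test launches a real server):

    # Load and run the same test the CI step invokes, from Python instead of the shell.
    import unittest

    suite = unittest.defaultTestLoader.loadTestsFromName(
        "test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)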
python/sglang/bench_serving.py

@@ -845,6 +845,7 @@ def run_benchmark(args_: argparse.Namespace):
     tokenizer = get_tokenizer(tokenizer_id)
 
     if args.dataset_name == "sharegpt":
+        assert args.random_input_len is None and args.random_output_len is None
         input_requests = sample_sharegpt_requests(
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,

@@ -852,6 +853,7 @@ def run_benchmark(args_: argparse.Namespace):
             fixed_output_len=args.sharegpt_output_len,
         )
     elif args.dataset_name == "random":
+        assert args.random_input_len is not None and args.random_output_len is not None
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,

@@ -964,13 +966,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--random-input-len",
         type=int,
-        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
         type=int,
-        default=128,
         help="Number of output tokens per request, used only for random dataset.",
     )
     parser.add_argument(
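The two new asserts pair with the removed argparse defaults: once --random-input-len and --random-output-len no longer default to 1024/128, an unset flag stays None, which lets the script detect whether the user actually set it and reject flags that don't apply to the chosen dataset. A standalone sketch of that pattern (the parser below is illustrative, not the real bench_serving parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-name", default="sharegpt")
    # No defaults: an unset flag is None, so "user supplied" is distinguishable.
    parser.add_argument("--random-input-len", type=int)
    parser.add_argument("--random-output-len", type=int)
    args = parser.parse_args(["--dataset-name", "sharegpt"])

    if args.dataset_name == "sharegpt":
        # Random-length flags are meaningless for a ShareGPT trace; reject them.
        assert args.random_input_len is None and args.random_output_len is None
    elif args.dataset_name == "random":
        assert args.random_input_len is not None and args.random_output_len is not None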
python/sglang/srt/managers/scheduler.py

@@ -222,7 +222,7 @@ class Scheduler:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.do_not_get_new_batch = False
+        self.batch_is_full = False
 
     def event_loop(self):
         while True:

@@ -261,12 +261,10 @@ class Scheduler:
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(
                 recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
             ):
                 self.handle_embedding_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):

@@ -279,11 +277,12 @@ class Scheduler:
     @torch.inference_mode()
     def forward_step(self):
-        if self.do_not_get_new_batch and self.current_inflight_req is None:
+        if (
+            self.batch_is_full or len(self.waiting_queue) == 0
+        ) and self.current_inflight_req is None:
             new_batch = None
         else:
             new_batch = self.get_new_prefill_batch()
-        self.do_not_get_new_batch = False
 
         if new_batch is not None:
             # Run a new prefill batch

@@ -447,6 +446,7 @@ class Scheduler:
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
+            self.batch_is_full = True
             return None
 
         # Get priority queue

@@ -490,9 +490,11 @@ class Scheduler:
                 )
                 > self.max_loras_per_batch
             ):
+                self.batch_is_full = True
                 break
 
             if adder.no_remaining_tokens():
+                self.batch_is_full = True
                 break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)

@@ -500,6 +502,7 @@ class Scheduler:
                 not res
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
+                self.batch_is_full = True
                 break
 
         can_run_list = adder.can_run_list

@@ -810,9 +813,6 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        if not has_finished:
-            self.do_not_get_new_batch = True
-
         self.handle_finished_requests(batch)
 
     def handle_finished_requests(self, batch: ScheduleBatch):

@@ -833,6 +833,8 @@ class Scheduler:
         for i, req in enumerate(batch.reqs):
             if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
+            else:
+                self.batch_is_full = False
 
             if req.finished() or (
                 req.stream
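Taken together, these hunks replace the per-step do_not_get_new_batch heuristic (set whenever a decode step finished nothing, cleared unconditionally after each step) with a batch_is_full flag that is set exactly when admission fails for a capacity reason and cleared as soon as any request finishes. Judging from the diff alone, the new flag both stops the scheduler from skipping prefill admission while slots are actually free and avoids repeated no-op admission passes over a long waiting queue while the batch really is full. A minimal, self-contained sketch of the new pattern (a deliberate simplification; the real Scheduler also tracks token budgets, LoRA slots, and in-flight chunked prefills):

    class TinyScheduler:
        def __init__(self, max_running_requests):
            self.max_running_requests = max_running_requests
            self.waiting_queue = []
            self.running = []
            self.batch_is_full = False

        def try_get_new_batch(self):
            # Skip only when we already know admission would fail:
            # the batch is full, or there is nothing waiting.
            if self.batch_is_full or not self.waiting_queue:
                return None
            if len(self.running) >= self.max_running_requests:
                self.batch_is_full = True  # remember why admission failed
                return None
            req = self.waiting_queue.pop(0)
            self.running.append(req)
            return req

        def on_request_finished(self, req):
            self.running.remove(req)
            self.batch_is_full = False  # capacity freed: allow admission again

    sched = TinyScheduler(max_running_requests=2)
    sched.waiting_queue = ["a", "b", "c"]
    sched.try_get_new_batch()        # admits "a"
    sched.try_get_new_batch()        # admits "b"
    sched.try_get_new_batch()        # None; sets batch_is_full
    sched.try_get_new_batch()        # None, without rescanning the queue
    sched.on_request_finished("a")   # clears batch_is_full
    sched.try_get_new_batch()        # admits "c" immediately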
python/sglang/test/test_utils.py

@@ -514,7 +514,16 @@ def get_similarities(vec1, vec2):
     return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
 
 
-def run_bench_serving(model, num_prompts, request_rate, other_server_args):
+def run_bench_serving(
+    model,
+    num_prompts,
+    request_rate,
+    other_server_args,
+    dataset_name="random",
+    random_input_len=4096,
+    random_output_len=2048,
+    disable_stream=False,
+):
     # Launch the server
     base_url = DEFAULT_URL_FOR_TEST
     process = popen_launch_server(

@@ -530,21 +539,21 @@ def run_bench_serving(model, num_prompts, request_rate, other_server_args):
         base_url=base_url,
         host=None,
         port=None,
-        dataset_name="random",
+        dataset_name=dataset_name,
         dataset_path="",
         model=None,
         tokenizer=None,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
-        random_input_len=4096,
-        random_output_len=2048,
+        random_input_len=random_input_len,
+        random_output_len=random_output_len,
         random_range_ratio=0.0,
         request_rate=request_rate,
         multi=None,
         seed=0,
         output_file=None,
         disable_tqdm=False,
-        disable_stream=False,
+        disable_stream=disable_stream,
         disable_ignore_eos=False,
         extra_request_body=None,
     )
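With the new keyword arguments, a test can drive the exact configuration that exposes the bug: ShareGPT prompts, streaming disabled, and a small --max-running-requests cap. A hypothetical standalone call mirroring the new test below (run_bench_serving launches a real server process, so a GPU is required):

    from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, run_bench_serving

    res = run_bench_serving(
        model=DEFAULT_MODEL_NAME_FOR_TEST,
        num_prompts=200,
        request_rate=float("inf"),
        other_server_args=["--max-running-requests", "10"],
        dataset_name="sharegpt",   # new kwarg: use the ShareGPT trace
        random_input_len=None,     # new kwargs: must be None for sharegpt
        random_output_len=None,
        disable_stream=True,       # new kwarg: non-streaming mode
    )
    print(res["output_throughput"])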
test/srt/test_bench_serving.py

@@ -20,7 +20,22 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2830
 
+    def test_offline_throughput_non_stream_small_batch_size(self):
+        res = run_bench_serving(
+            model=DEFAULT_MODEL_NAME_FOR_TEST,
+            num_prompts=200,
+            request_rate=float("inf"),
+            dataset_name="sharegpt",
+            random_input_len=None,
+            random_output_len=None,
+            disable_stream=True,
+            other_server_args=["--max-running-requests", "10"],
+        )
+
+        if is_in_ci():
+            assert res["output_throughput"] > 1000
+
     def test_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(

@@ -31,7 +46,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2800
+            assert res["output_throughput"] > 2880
 
     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(

@@ -58,7 +73,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2930
 
     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(