sglang · commit 5652c565 (unverified)

Update CI threshold & Improve code style (#2159)

Authored Nov 24, 2024 by Lianmin Zheng; committed by GitHub on Nov 24, 2024.
Parent: e3938b2f

Showing 8 changed files with 126 additions and 41 deletions (+126, -41).
.github/workflows/pr-test.yml                           +34  -22
python/sglang/bench_one_batch.py                        +1   -0
python/sglang/srt/layers/fused_moe_patch.py             +5   -0
python/sglang/srt/managers/schedule_batch.py            +14  -9
python/sglang/srt/managers/scheduler.py                 +6   -3
python/sglang/srt/model_executor/cuda_graph_runner.py   +1   -1
test/srt/test_bench_serving.py                          +6   -6
test/srt/test_srt_endpoint.py                           +59  -0
.github/workflows/pr-test.yml

@@ -50,7 +50,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
+          python3 run_suite.py --suite minimal --range-begin 0 --range-end 6

   unit-test-backend-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'

@@ -67,7 +67,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 5 --range-end 14
+          python3 run_suite.py --suite minimal --range-begin 6 --range-end 14

   unit-test-backend-part-3:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'

@@ -103,6 +103,31 @@ jobs:
           cd test/srt
           python3 run_suite.py --suite minimal --range-begin 21

+  unit-test-backend-2-gpu-part-1:
+    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
+    runs-on: 2-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Install dependencies
+        run: |
+          bash scripts/ci_install_dependency.sh
+
+      - name: Evaluate data parallelism accuracy (DP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_data_parallelism.py
+
+      - name: Evaluate MLA accuracy (TP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_mla.py
+          python3 test_mla_fp8.py
+          python3 test_dp_attention.py
+
   performance-test-1-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner

@@ -178,23 +203,23 @@ jobs:
        run: |
          bash scripts/ci_install_dependency.sh

-      - name: Benchmark offline throughput (TP=2)
+      - name: Benchmark single latency (TP=2)
         timeout-minutes: 10
         run: |
           cd test/srt
-          python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_default
+          python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default

-      - name: Benchmark offline throughput (w/o RadixAttention) (TP=2)
+      - name: Benchmark offline throughput (TP=2)
         timeout-minutes: 10
         run: |
           cd test/srt
-          python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache
+          python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_default

-      - name: Benchmark single latency (TP=2)
+      - name: Benchmark offline throughput (w/o RadixAttention) (TP=2)
         timeout-minutes: 10
         run: |
           cd test/srt
-          python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default
+          python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache

   accuracy-test-1-gpu:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'

@@ -238,23 +263,10 @@ jobs:
           cd test/srt
           python3 test_moe_eval_accuracy_large.py

-      - name: Evaluate MLA accuracy (TP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_mla.py
-          python3 test_mla_fp8.py
-          python3 test_dp_attention.py
-
-      - name: Evaluate data parallelism accuracy (DP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_data_parallelism.py
-
   finish:
     needs: [
       unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2,
-      unit-test-backend-part-3, unit-test-backend-part-4,
+      unit-test-backend-part-3, unit-test-backend-part-4, unit-test-backend-2-gpu-part-1,
       performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
       accuracy-test-1-gpu, accuracy-test-2-gpu
     ]
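Note: the first two hunks rebalance the unit-test shards by moving the part-1/part-2 boundary from 5 to 6, and the third hunk introduces a dedicated 2-GPU job that takes over the DP and MLA accuracy checks removed from a later job in this file. A minimal sketch of how --range-begin/--range-end style sharding slices an ordered suite (illustrative only; run_suite.py's actual implementation may differ, and select_shard plus the placeholder file list are assumptions):

from typing import List, Optional

def select_shard(files: List[str], begin: int, end: Optional[int]) -> List[str]:
    # Each CI job runs a contiguous slice of the ordered test-file list, so
    # moving the boundary from 5 to 6 shifts one file from part 2 into part 1.
    return files[begin : len(files) if end is None else end]

suite = [f"test_{i:02d}.py" for i in range(14)]  # placeholder list of 14 files
print(select_shard(suite, 0, 6))   # part 1 after this commit
print(select_shard(suite, 6, 14))  # part 2 after this commit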
python/sglang/bench_one_batch.py

@@ -212,6 +212,7 @@ def extend(reqs, model_runner):
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
         model_config=model_runner.model_config,
+        enable_overlap=False,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
python/sglang/srt/layers/fused_moe/patch.py → python/sglang/srt/layers/fused_moe_patch.py

"""
Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""

from typing import Callable, Optional

import torch
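Because the module moved out of the fused_moe package into a standalone fused_moe_patch module, any import of the old path must be updated; the cuda_graph_runner.py hunk later in this diff shows the one in-tree caller. For reference, the import change looks like:

# Before the rename:
# from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native

# After the rename:
from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native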
python/sglang/srt/managers/schedule_batch.py

@@ -437,9 +437,12 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None

-    # For utility
+    # Batch configs
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
+    enable_overlap: bool = False

     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None

@@ -488,10 +491,11 @@ class ScheduleBatch:
     def init_new(
         cls,
         reqs: List[Req],
-        req_to_token_pool,
-        token_to_kv_pool,
-        tree_cache,
-        model_config,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: ReqToTokenPool,
+        tree_cache: BasePrefixCache,
+        model_config: ModelConfig,
+        enable_overlap: bool,
     ):
         return cls(
             reqs=reqs,

@@ -499,6 +503,7 @@ class ScheduleBatch:
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             model_config=model_config,
+            enable_overlap=enable_overlap,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),

@@ -612,7 +617,7 @@ class ScheduleBatch:
         assert len(self.out_cache_loc) == self.extend_num_tokens

-    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
         bs = len(self.reqs)

@@ -706,7 +711,7 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=enable_overlap_schedule,
+            enable_overlap_schedule=self.enable_overlap,
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):

@@ -897,7 +902,7 @@ class ScheduleBatch:
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0

-    def prepare_for_decode(self, enable_overlap: bool = False):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
         self.input_ids = self.output_ids

@@ -914,7 +919,7 @@ class ScheduleBatch:
         else:
             locs = self.seq_lens

-        if enable_overlap:
+        if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
                 (self.req_pool_indices, locs), self.out_cache_loc
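Taken together, these hunks turn enable_overlap from a flag threaded through prepare_for_extend and prepare_for_decode into a field fixed at batch construction. A condensed before/after sketch of the resulting call pattern (argument names taken from the diff; this is an outline, not the full API):

# Before this commit: the flag was passed on each call.
# batch = ScheduleBatch.init_new(reqs, req_to_token_pool, token_to_kv_pool,
#                                tree_cache, model_config)
# batch.prepare_for_extend(enable_overlap_schedule=enable_overlap)
# batch.prepare_for_decode(enable_overlap)

# After this commit: the flag is stored once and read internally.
batch = ScheduleBatch.init_new(
    reqs,
    req_to_token_pool,
    token_to_kv_pool,
    tree_cache,
    model_config,
    enable_overlap,  # stored as batch.enable_overlap
)
batch.prepare_for_extend()  # reads self.enable_overlap
batch.prepare_for_decode()  # reads self.enable_overlap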
python/sglang/srt/managers/scheduler.py

@@ -466,6 +466,7 @@ class Scheduler:
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
         idle_batch.prepare_for_idle()
         return idle_batch

@@ -842,14 +843,15 @@ class Scheduler:
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
-        new_batch.prepare_for_extend(self.enable_overlap)
+        new_batch.prepare_for_extend()

         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.filter_batch()
             if not self.running_batch.is_empty():
-                self.running_batch.prepare_for_decode(self.enable_overlap)
+                self.running_batch.prepare_for_decode()
                 new_batch.mix_with_running(self.running_batch)
                 new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None

@@ -900,7 +902,7 @@ class Scheduler:
         self.batch_is_full = False

         # Update batch tensors
-        batch.prepare_for_decode(self.enable_overlap)
+        batch.prepare_for_decode()
         return batch

     def run_batch(self, batch: ScheduleBatch):

@@ -1055,6 +1057,7 @@ class Scheduler:
                 continue

             if self.enable_overlap and req.finished():
+                # Free the one delayed token
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -23,7 +23,7 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
+from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
test/srt/test_bench_serving.py

@@ -20,7 +20,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 2850)
+            self.assertGreater(res["output_throughput"], 3350)

     def test_offline_throughput_non_stream_small_batch_size(self):
         res = run_bench_serving(

@@ -47,7 +47,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 2900)
+            self.assertGreater(res["output_throughput"], 3350)

     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(

@@ -74,7 +74,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 2950)
+            self.assertGreater(res["output_throughput"], 3450)

     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(

@@ -85,7 +85,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 3200)
+            self.assertGreater(res["output_throughput"], 3850)

     def test_online_latency_default(self):
         res = run_bench_serving(

@@ -109,7 +109,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 1900)
+            self.assertGreater(res["output_throughput"], 2150)

     def test_moe_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(

@@ -120,7 +120,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 1950)
+            self.assertGreater(res["output_throughput"], 2150)


 if __name__ == "__main__":
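All six hunks raise throughput floors that are enforced only when is_in_ci() is true, so local runs are unaffected. A minimal sketch of that gating pattern, assuming an environment-variable-based helper (is_in_ci actually lives in sglang's test utilities; the variable name below is an assumption for illustration):

import os

def is_in_ci() -> bool:
    # Assumed convention: CI runners export a marker environment variable.
    return os.environ.get("SGLANG_IS_IN_CI", "false").lower() == "true"

def check_throughput(res: dict, floor: float) -> None:
    # Enforce performance floors only on CI hardware, where throughput is
    # reproducible; developer machines vary too much for a hard assert.
    if is_in_ci():
        assert res["output_throughput"] > floor

check_throughput({"output_throughput": 3400.0}, floor=3350)  # example values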
test/srt/test_srt_endpoint.py

@@ -6,6 +6,7 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
 import json
 import unittest

+import numpy as np
 import requests

 from sglang.srt.utils import kill_child_process

@@ -132,6 +133,7 @@ class TestSRTEndpoint(unittest.TestCase):
         )

     def test_logprob_with_chunked_prefill(self):
+        """Test a long prompt that requests output logprobs will not hit OOM."""
         new_tokens = 4
         prompts = "I have a very good idea on this. " * 8000

@@ -154,6 +156,63 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
         self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)

+    def test_logprob_match(self):
+        """Test the output logprobs are close to the input logprobs if we run a prefill again."""
+
+        def run_generate(
+            prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1
+        ):
+            if isinstance(prompt, str):
+                prompt_kwargs = {"text": prompt}
+            else:
+                prompt_kwargs = {"input_ids": prompt}
+
+            response = requests.post(
+                self.base_url + "/generate",
+                json={
+                    **prompt_kwargs,
+                    "sampling_params": {
+                        "temperature": 1.0,
+                        "max_new_tokens": max_new_tokens,
+                        "ignore_eos": True,
+                    },
+                    "return_logprob": return_logprob,
+                    "return_text_in_logprobs": True,
+                    "logprob_start_len": logprob_start_len,
+                },
+            )
+            return response.json()
+
+        prompt = "I have a very good idea on how to"
+
+        gen = run_generate(prompt, return_logprob=True, logprob_start_len=0)
+        output_logprobs = np.array(
+            [x[0] for x in gen["meta_info"]["output_token_logprobs"]]
+        )
+        num_prompts_tokens = gen["meta_info"]["prompt_tokens"]
+
+        input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]]
+        output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]]
+
+        new_prompt = input_tokens + output_tokens
+        score = run_generate(
+            new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0
+        )
+        output_logprobs_score = np.array(
+            [
+                x[0]
+                for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:]
+            ]
+        )
+
+        print(f"{output_logprobs[-10:]=}")
+        print(f"{output_logprobs_score[-10:]=}")
+
+        diff = np.abs(output_logprobs - output_logprobs_score)
+        max_diff = np.max(diff)
+        self.assertLess(max_diff, 0.2)
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
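The new test_logprob_match also doubles as a recipe for scoring a fixed token sequence with the /generate endpoint: send input_ids with max_new_tokens=0 so the server only prefills, and set return_logprob=True with logprob_start_len=0 to get a logprob for every input position. A standalone sketch of that request, mirroring the fields used in the test (the server address and token ids are placeholder assumptions):

import requests

# Prefill-only scoring request: no sampling, logprobs for all input tokens.
response = requests.post(
    "http://localhost:30000/generate",  # assumed local server address
    json={
        "input_ids": [1, 306, 505, 263],  # placeholder token ids
        "sampling_params": {
            "temperature": 1.0,
            "max_new_tokens": 0,  # prefill only, generate nothing
            "ignore_eos": True,
        },
        "return_logprob": True,
        "return_text_in_logprobs": True,
        "logprob_start_len": 0,  # score from the first token
    },
)
logprobs = [x[0] for x in response.json()["meta_info"]["input_token_logprobs"]]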