zhaoyu6 / sglang · Commit 1701b0db (unverified)
Authored Oct 24, 2024 by Lianmin Zheng; committed by GitHub on Oct 24, 2024
Enhance the test case for chunked prefill (#1785)
Parent: 384d85ba

Showing 6 changed files with 161 additions and 106 deletions (+161 −106):

  .github/workflows/pr-test.yml          +25  −5
  python/sglang/test/test_utils.py       +93  −3
  test/srt/run_suite.py                  +1   −1
  test/srt/test_chunked_prefill.py       +10  −44
  test/srt/test_large_max_new_tokens.py  +25  −8
  test/srt/test_overlap_schedule.py      +7   −45
.github/workflows/pr-test.yml

@@ -33,7 +33,7 @@ jobs:
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

       - name: Run test
-        timeout-minutes: 20
+        timeout-minutes: 10
         run: |
           cd test/lang
           python3 run_suite.py --suite minimal
@@ -73,7 +73,7 @@ jobs:
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 20
         run: |
           cd test/srt
           python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
@@ -93,10 +93,30 @@ jobs:
         pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 20
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 17
+          python3 run_suite.py --suite minimal --range-begin 17 --range-end 20
+
+  unit-test-backend-part-4:
+    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
+    runs-on: 1-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Install dependencies
+        run: |
+          pip install --upgrade pip
+          pip install -e "python[dev]"
+          pip install transformers==4.45.2
+          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
+
+      - name: Run test
+        timeout-minutes: 20
+        run: |
+          cd test/srt
+          python3 run_suite.py --suite minimal --range-begin 20

   performance-test-1-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -263,7 +283,7 @@ jobs:

   finish:
     needs: [
-        unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3,
+        unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3, unit-test-backend-part-4,
         performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
         accuracy-test-1-gpu, accuracy-test-2-gpu
       ]
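The backend parts split one ordered test list by index via run_suite.py's --range-begin/--range-end flags: the commands above give one job the slice [5, 17), another [17, 20), and the new part 4 everything from index 20 on. A minimal sketch of the presumed slicing logic (illustrative only; the file list and argument wiring are assumptions, not the repository's actual run_suite.py):

    import argparse

    # Each CI job runs a half-open slice [range_begin, range_end) of the suite.
    parser = argparse.ArgumentParser()
    parser.add_argument("--suite", default="minimal")
    parser.add_argument("--range-begin", type=int, default=0)
    parser.add_argument("--range-end", type=int, default=None)
    args = parser.parse_args()

    suite = [f"test_{i}.py" for i in range(22)]  # stand-in for the real file list
    for test_file in suite[args.range_begin : args.range_end]:
        print("would run", test_file)

Splitting by index keeps each job under its timeout-minutes budget, and a new partition is added by appending one more slice.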
python/sglang/test/test_utils.py

@@ -3,6 +3,7 @@
 import argparse
 import asyncio
 import os
+import random
 import subprocess
 import threading
 import time
@@ -20,6 +21,7 @@ from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
 from sglang.utils import get_exception_traceback

 DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
@@ -400,7 +402,7 @@ def popen_launch_server(
     api_key: Optional[str] = None,
     other_args: tuple = (),
     env: Optional[dict] = None,
-    return_stdout_stderr: bool = False,
+    return_stdout_stderr: Optional[tuple] = None,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
@@ -423,8 +425,8 @@ def popen_launch_server(
     if return_stdout_stderr:
         process = subprocess.Popen(
             command,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
+            stdout=return_stdout_stderr[0],
+            stderr=return_stdout_stderr[1],
             env=env,
             text=True,
         )
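The two hunks above change popen_launch_server's contract: return_stdout_stderr is no longer a boolean that selects subprocess.PIPE, but a pair of caller-owned writable file objects that receive the server's stdout and stderr. A hedged usage sketch (model and base_url stand for values set as elsewhere in this file):

    # Sketch: launch with output redirected to files the test owns and can tail.
    stdout = open("stdout.txt", "w")
    stderr = open("stderr.txt", "w")
    process = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        return_stdout_stderr=(stdout, stderr),  # files, not a bool, after this change
    )

Redirecting to files rather than pipes also means the server can never block on a full pipe buffer when no one is reading.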
@@ -631,3 +633,91 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2):
         rouge_l_scores.append(fmeasure)

     return rouge_l_scores
+
+
+STDOUT_FILENAME = "stdout.txt"
+STDERR_FILENAME = "stderr.txt"
+
+
+def read_output(output_lines):
+    pt = 0
+    while pt >= 0:
+        if pt > 0 and not os.path.exists(STDERR_FILENAME):
+            break
+        lines = open(STDERR_FILENAME).readlines()
+        output_lines[:] = lines
+        for line in lines[pt:]:
+            print(line, end="", flush=True)
+            pt += 1
+
+
+def run_mmlu_test(
+    disable_radix_cache,
+    enable_mixed_chunk=False,
+    enable_overlap=False,
+    chunked_prefill_size=32,
+):
+    other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+    if disable_radix_cache:
+        other_args += ["--disable-radix-cache"]
+    if enable_mixed_chunk:
+        other_args += ["--enable-mixed-chunk"]
+    if enable_overlap:
+        other_args += ["--enable-overlap-scheduler"]
+
+    model = DEFAULT_MODEL_NAME_FOR_TEST
+    port = random.randint(4000, 5000)
+    base_url = f"http://127.0.0.1:{port}"
+
+    # Create files and launch the server
+    stdout = open(STDOUT_FILENAME, "w")
+    stderr = open(STDERR_FILENAME, "w")
+    process = popen_launch_server(
+        model,
+        base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_args,
+        return_stdout_stderr=(stdout, stderr),
+    )
+
+    # Launch a thread to stream the output
+    output_lines = []
+    t = threading.Thread(target=read_output, args=(output_lines,))
+    t.start()
+
+    # Run the eval
+    args = SimpleNamespace(
+        base_url=base_url,
+        model=model,
+        eval_name="mmlu",
+        num_examples=128,
+        num_threads=128,
+    )
+
+    try:
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        assert metrics["score"] >= 0.65
+    finally:
+        pass
+
+    # Clean up everything
+    kill_child_process(process.pid)
+    kill_child_process(process.pid)
+    stdout.close()
+    stderr.close()
+    os.remove(STDOUT_FILENAME)
+    os.remove(STDERR_FILENAME)
+    t.join()
+
+    # Assert success
+    has_new_server = False
+    has_leak = False
+    for line in output_lines:
+        if "The server is fired" in line:
+            has_new_server = True
+        if "leak" in line:
+            has_leak = True
+
+    assert has_new_server
+    # assert not has_leak
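read_output implements a simple file tail: re-read STDERR_FILENAME, print only the lines past the last index already seen, and stop once the cleanup code deletes the file. A self-contained sketch of the same pattern (file name and timings are arbitrary, not from the commit):

    import os
    import threading
    import time

    LOG = "demo.log"  # hypothetical file for this sketch

    def tail(path, collected):
        # Re-read the file and print lines not yet seen; exit when it is removed.
        pt = 0
        while os.path.exists(path):
            try:
                lines = open(path).readlines()
            except FileNotFoundError:
                break  # deleted between the existence check and the open
            collected[:] = lines
            for line in lines[pt:]:
                print(line, end="", flush=True)
                pt += 1
            time.sleep(0.05)

    open(LOG, "w").close()
    collected = []
    t = threading.Thread(target=tail, args=(LOG, collected))
    t.start()
    with open(LOG, "a") as f:
        for i in range(3):
            f.write(f"line {i}\n")
            f.flush()
            time.sleep(0.1)
    time.sleep(0.2)
    os.remove(LOG)  # signals the tailer to stop, as run_mmlu_test's cleanup does
    t.join()
    assert collected == [f"line {i}\n" for i in range(3)]

Tailing the file, instead of reading the child's pipe directly, lets the eval thread, the printer thread, and the final assertions all see the same complete log.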
test/srt/run_suite.py

@@ -15,7 +15,7 @@ suites = {
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
         "test_json_constrained.py",
-        # "test_large_max_new_tokens.py",  # This test hangs on CI due to unknown reasons
+        "test_large_max_new_tokens.py",
         "test_openai_server.py",
         "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
test/srt/test_chunked_prefill.py

 """
 python3 -m unittest test_chunked_prefill.TestChunkedPrefill.test_mixed_chunked_prefill_without_radix_cache
 """

 import unittest
 from types import SimpleNamespace

 from sglang.srt.utils import kill_child_process
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
     run_bench_serving,
+    run_mmlu_test,
 )


 class TestChunkedPrefill(unittest.TestCase):
-    def run_mmlu(
-        self, disable_radix_cache, enable_mixed_chunk, chunked_prefill_size=32
-    ):
-        other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
-        if disable_radix_cache:
-            other_args += ["--disable-radix-cache"]
-        if enable_mixed_chunk:
-            other_args += ["--enable-mixed-chunk"]
-
-        model = DEFAULT_MODEL_NAME_FOR_TEST
-        base_url = DEFAULT_URL_FOR_TEST
-        process = popen_launch_server(
-            model,
-            base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=other_args,
-        )
-
-        args = SimpleNamespace(
-            base_url=base_url,
-            model=model,
-            eval_name="mmlu",
-            num_examples=64,
-            num_threads=32,
-        )
-
-        try:
-            metrics = run_eval(args)
-            assert metrics["score"] >= 0.65
-        finally:
-            kill_child_process(process.pid)
-
     def test_chunked_prefill(self):
-        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=False)
+        run_mmlu_test(disable_radix_cache=False, enable_mixed_chunk=False)

     def test_mixed_chunked_prefill(self):
-        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=True)
+        run_mmlu_test(disable_radix_cache=False, enable_mixed_chunk=True)

     def test_chunked_prefill_without_radix_cache(self):
-        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=False)
+        run_mmlu_test(disable_radix_cache=True, enable_mixed_chunk=False)

     def test_mixed_chunked_prefill_without_radix_cache(self):
-        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True)
+        run_mmlu_test(disable_radix_cache=True, enable_mixed_chunk=True)

     def test_no_chunked_prefill(self):
-        self.run_mmlu(
+        run_mmlu_test(
             disable_radix_cache=False, enable_mixed_chunk=False, chunked_prefill_size=-1
         )
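In these tests, chunked_prefill_size=-1 disables chunked prefill entirely, while a deliberately tiny value like 32 forces prompts to be split into many chunks. A toy illustration of how such a sentinel typically behaves (illustrative only, not sglang's scheduler code):

    # Illustrative: a non-positive chunk size means "prefill in one shot".
    def split_prefill(tokens, chunked_prefill_size):
        if chunked_prefill_size <= 0:
            return [tokens]
        return [
            tokens[i : i + chunked_prefill_size]
            for i in range(0, len(tokens), chunked_prefill_size)
        ]

    assert split_prefill(list(range(5)), 2) == [[0, 1], [2, 3], [4]]
    assert split_prefill(list(range(5)), -1) == [list(range(5))]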
test/srt/test_large_max_new_tokens.py

+"""
+python3 -m unittest test_large_max_new_tokens.TestLargeMaxNewTokens.test_chat_completion
+"""
+
 import os
 import unittest
 from concurrent.futures import ThreadPoolExecutor

@@ -20,6 +24,10 @@ class TestLargeMaxNewTokens(unittest.TestCase):
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.api_key = "sk-123456"
+
+        cls.stdout = open("stdout.txt", "w")
+        cls.stderr = open("stderr.txt", "w")
+
         cls.process = popen_launch_server(
             cls.model,
             cls.base_url,
@@ -27,7 +35,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             api_key=cls.api_key,
             other_args=("--max-total-token", "1024", "--context-len", "8192"),
             env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
-            return_stdout_stderr=True,
+            return_stdout_stderr=(cls.stdout, cls.stderr),
         )
         cls.base_url += "/v1"
         cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
@@ -35,6 +43,10 @@ class TestLargeMaxNewTokens(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
         kill_child_process(cls.process.pid)
+        cls.stdout.close()
+        cls.stderr.close()
+        os.remove("stdout.txt")
+        os.remove("stderr.txt")

     def run_chat_completion(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -56,16 +68,21 @@ class TestLargeMaxNewTokens(unittest.TestCase):
         futures = []
         with ThreadPoolExecutor(num_requests) as executor:
             # Send multiple requests
             for i in range(num_requests):
                 futures.append(executor.submit(self.run_chat_completion))

             all_requests_running = False
-            for line in iter(self.process.stderr.readline, ""):
-                line = str(line)
-                print(line, end="")
-                if f"#running-req: {num_requests}" in line:
-                    all_requests_running = True
-                    break
+            pt = 0
+            while pt >= 0:
+                lines = open("stderr.txt").readlines()
+                for line in lines[pt:]:
+                    print(line, end="", flush=True)
+                    if f"#running-req: {num_requests}" in line:
+                        all_requests_running = True
+                        pt = -1
+                        break
+                    pt += 1

             # Ensure that they are running concurrently
             assert all_requests_running
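The old assertion loop read the server's stderr through a pipe (self.process.stderr); with output now redirected to stderr.txt, the test polls the file instead. One practical benefit of the file-based approach (a standalone illustration, not code from the commit): a child writing to an undrained pipe blocks once the OS pipe buffer fills (around 64 KiB on Linux), while a regular file never back-pressures the writer.

    import subprocess
    import sys
    import tempfile

    # A child that writes well past the typical pipe-buffer size on stderr.
    child = [sys.executable, "-c", "import sys; sys.stderr.write('x' * 200_000)"]

    # Redirected to a file, the child finishes even if we never read a byte.
    with tempfile.TemporaryFile(mode="w+") as log:
        p = subprocess.Popen(child, stderr=log, text=True)
        p.wait()
        print("exit code:", p.returncode)  # 0, no deadlock

With stderr=subprocess.PIPE and no reader, the same child would hang in wait().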
test/srt/test_overlap_schedule.py

 """
 Usage:
-SGLANG_IS_IN_CI=true python3 -m unittest test_overlap_schedule.TestOverlapSchedule.test_radix_attention_chunked_prefill
-SGLANG_IS_IN_CI=true python3 test_overlap_schedule.py
+python3 -m unittest test_overlap_schedule.TestOverlapSchedule.test_radix_attention_chunked_prefill
+python3 test_overlap_schedule.py
 """

 import unittest
-from types import SimpleNamespace

-from sglang.srt.utils import kill_child_process
-from sglang.test.run_eval import run_eval
-from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
-    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-    DEFAULT_URL_FOR_TEST,
-    popen_launch_server,
-)
+from sglang.test.test_utils import run_mmlu_test


 class TestOverlapSchedule(unittest.TestCase):
-    def run_mmlu(self, disable_radix_cache, chunked_prefill_size=32):
-        other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
-        if disable_radix_cache:
-            other_args += ["--disable-radix-cache"]
-        other_args += ["--enable-overlap-schedule"]
-
-        model = DEFAULT_MODEL_NAME_FOR_TEST
-        base_url = DEFAULT_URL_FOR_TEST
-        process = popen_launch_server(
-            model,
-            base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=other_args,
-        )
-
-        args = SimpleNamespace(
-            base_url=base_url,
-            model=model,
-            eval_name="mmlu",
-            num_examples=64,
-            num_threads=32,
-        )
-
-        try:
-            metrics = run_eval(args)
-            assert metrics["score"] >= 0.65
-        finally:
-            kill_child_process(process.pid)
-
     def test_no_radix_attention_chunked_prefill(self):
-        self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=32)
+        run_mmlu_test(
+            disable_radix_cache=True, enable_overlap=True, chunked_prefill_size=32
+        )

     def test_no_radix_attention_no_chunked_prefill(self):
-        self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=-1)
+        run_mmlu_test(
+            disable_radix_cache=True, enable_overlap=True, chunked_prefill_size=-1
+        )

     def test_radix_attention_chunked_prefill(self):
-        self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=32)
+        run_mmlu_test(
+            disable_radix_cache=False, enable_overlap=True, chunked_prefill_size=32
+        )

     def test_radix_attention_no_chunked_prefill(self):
-        self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=-1)
+        run_mmlu_test(
+            disable_radix_cache=False, enable_overlap=True, chunked_prefill_size=-1
+        )


 if __name__ == "__main__":
     unittest.main()
-
-# @unittest.skip("did not support")
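For reference, run_mmlu_test (added to test_utils.py above) turns these keyword arguments into server CLI flags; tracing the overlap case by hand, using only the logic from this diff:

    # Tracing run_mmlu_test(disable_radix_cache=True, enable_overlap=True,
    #                       chunked_prefill_size=32) through the code above:
    other_args = ["--chunked-prefill-size", "32"]
    other_args += ["--disable-radix-cache"]       # disable_radix_cache=True
    other_args += ["--enable-overlap-scheduler"]  # enable_overlap=True
    assert other_args == [
        "--chunked-prefill-size",
        "32",
        "--disable-radix-cache",
        "--enable-overlap-scheduler",
    ]

So each test here exercises the overlap scheduler against the same MMLU accuracy bar (score >= 0.65) as the chunked-prefill tests.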