Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
869f1c02
Unverified
Commit
869f1c02
authored
Oct 13, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 13, 2024
Browse files
Add a test case to test retract (#1662)
parent
2725f8da
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
1 deletion
+50
-1
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+3
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+5
-1
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_retract_decode.py
test/srt/test_retract_decode.py
+41
-0
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
869f1c02
...
...
@@ -590,9 +590,11 @@ class ScheduleBatch:
retracted_reqs
=
[]
seq_lens_cpu
=
self
.
seq_lens
.
cpu
().
numpy
()
first_iter
=
True
while
(
self
.
token_to_kv_pool
.
available_size
()
<
len
(
sorted_indices
)
*
global_config
.
retract_decode_steps
or
first_iter
):
if
len
(
sorted_indices
)
==
1
:
# Corner case: only one request left
...
...
@@ -601,6 +603,7 @@ class ScheduleBatch:
),
"No space left for only one request"
break
first_iter
=
False
idx
=
sorted_indices
.
pop
()
req
=
self
.
reqs
[
idx
]
retracted_reqs
.
append
(
req
)
...
...
python/sglang/srt/managers/scheduler.py
View file @
869f1c02
...
...
@@ -77,6 +77,9 @@ logger = logging.getLogger(__name__)
# Crash on warning if we are running CI tests
crash_on_warning
=
os
.
getenv
(
"SGLANG_IS_IN_CI"
,
"false"
)
==
"true"
# Test retract decode
test_retract
=
os
.
getenv
(
"SGLANG_TEST_RETRACT"
,
"false"
)
==
"true"
class
Scheduler
:
"""A scheduler that manages a tensor parallel GPU worker."""
...
...
@@ -611,10 +614,11 @@ class Scheduler:
return
new_batch
def
update_running_batch
(
self
):
global
test_retract
batch
=
self
.
running_batch
# Check if decode out of memory
if
not
batch
.
check_decode_mem
():
if
not
batch
.
check_decode_mem
()
or
(
test_retract
and
batch
.
batch_size
()
>
10
)
:
old_ratio
=
self
.
new_token_ratio
retracted_reqs
,
new_token_ratio
=
batch
.
retract_decode
()
...
...
test/srt/run_suite.py
View file @
869f1c02
...
...
@@ -17,6 +17,7 @@ suites = {
"test_large_max_new_tokens.py"
,
"test_openai_server.py"
,
"test_pytorch_sampling_backend.py"
,
"test_retract_decode.py"
,
"test_server_args.py"
,
"test_skip_tokenizer_init.py"
,
"test_srt_engine.py"
,
...
...
test/srt/test_retract_decode.py
0 → 100644
View file @
869f1c02
import os
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)
class TestRetractDecode(unittest.TestCase):
    """Check that MMLU accuracy holds up while decode retraction is forced.

    The server is started with ``SGLANG_TEST_RETRACT=true`` in the
    environment so that the scheduler's test hook is active; without that
    variable the server runs normally and retraction is only hit if the KV
    cache happens to run out during the eval.
    """

    # Name of the environment variable the scheduler reads to force retraction.
    _RETRACT_ENV = "SGLANG_TEST_RETRACT"

    @classmethod
    def _restore_retract_env(cls):
        """Put the retract environment variable back to its pre-test value."""
        previous = getattr(cls, "_saved_retract_env", None)
        if previous is None:
            os.environ.pop(cls._RETRACT_ENV, None)
        else:
            os.environ[cls._RETRACT_ENV] = previous

    @classmethod
    def setUpClass(cls):
        """Enable the retract test hook and launch the server under test."""
        # Must be set before the server process is spawned.
        # NOTE(review): assumes popen_launch_server lets the child inherit
        # this process's environment — confirm against its implementation.
        cls._saved_retract_env = os.environ.get(cls._RETRACT_ENV)
        os.environ[cls._RETRACT_ENV] = "true"
        try:
            cls.model = DEFAULT_MODEL_NAME_FOR_TEST
            cls.base_url = DEFAULT_URL_FOR_TEST
            cls.process = popen_launch_server(
                cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
            )
        except Exception:
            # unittest skips tearDownClass when setUpClass raises, so undo
            # the environment change here.
            cls._restore_retract_env()
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server and restore the environment."""
        try:
            kill_child_process(cls.process.pid)
        finally:
            cls._restore_retract_env()

    def test_mmlu(self):
        """Run a small MMLU eval against the server and check the score."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        # unittest assertion instead of a bare assert: it is not stripped
        # under ``python -O`` and reports both values on failure.
        self.assertGreaterEqual(metrics["score"], 0.65)
# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment