Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a2e0424a
Unverified
Commit
a2e0424a
authored
Oct 31, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 31, 2024
Browse files
Fix memory leak for chunked prefill 2 (#1858)
Co-authored-by:
Liangsheng Yin
<
hnyls2002@gmail.com
>
parent
8ce202a4
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
138 additions
and
30 deletions
+138
-30
.github/workflows/pr-test.yml
.github/workflows/pr-test.yml
+3
-3
docs/hyperparameter_tuning.md
docs/hyperparameter_tuning.md
+0
-1
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+3
-3
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+18
-20
scripts/killall_sglang.sh
scripts/killall_sglang.sh
+1
-3
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_radix_attention.py
test/srt/test_radix_attention.py
+112
-0
No files found.
.github/workflows/pr-test.yml
View file @
a2e0424a
...
...
@@ -50,7 +50,7 @@ jobs:
timeout-minutes
:
20
run
:
|
cd test/srt
python3 run_suite.py --suite minimal --range-begin 0 --range-end
5
python3 run_suite.py --suite minimal --range-begin 0 --range-end
4
unit-test-backend-part-2
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
...
@@ -67,7 +67,7 @@ jobs:
timeout-minutes
:
20
run
:
|
cd test/srt
python3 run_suite.py --suite minimal --range-begin
5
--range-end 1
7
python3 run_suite.py --suite minimal --range-begin
4
--range-end 1
4
unit-test-backend-part-3
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
...
@@ -84,7 +84,7 @@ jobs:
timeout-minutes
:
20
run
:
|
cd test/srt
python3 run_suite.py --suite minimal --range-begin 1
7
--range-end 20
python3 run_suite.py --suite minimal --range-begin 1
4
--range-end 20
unit-test-backend-part-4
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
...
docs/hyperparameter_tuning.md
View file @
a2e0424a
# Guide on Hyperparameter Tuning
## Achieving Peak Throughput
Achieving a large batch size is the most important thing for attaining high throughput.
When the server is running at full load, look for the following in the log:
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
a2e0424a
...
...
@@ -221,7 +221,7 @@ class Req:
self
.
prefix_indices
=
[]
self
.
extend_input_len
=
0
self
.
last_node
=
None
self
.
is_
inflight_req
=
0
self
.
is_
being_chunked
=
0
# Logprobs (arguments)
self
.
return_logprob
=
False
...
...
@@ -888,7 +888,7 @@ class ScheduleBatch:
def
filter_batch
(
self
,
current_inflight
_req
:
Optional
[
Req
]
=
None
,
being_chunked
_req
:
Optional
[
Req
]
=
None
,
keep_indices
:
Optional
[
List
[
int
]]
=
None
,
):
if
keep_indices
is
None
:
...
...
@@ -896,7 +896,7 @@ class ScheduleBatch:
i
for
i
in
range
(
len
(
self
.
reqs
))
if
not
self
.
reqs
[
i
].
finished
()
and
self
.
reqs
[
i
]
is
not
current_inflight
_req
and
self
.
reqs
[
i
]
is
not
being_chunked
_req
]
if
keep_indices
is
None
or
len
(
keep_indices
)
==
0
:
...
...
python/sglang/srt/managers/scheduler.py
View file @
a2e0424a
...
...
@@ -231,7 +231,7 @@ class Scheduler:
# Init chunked prefill
self
.
chunked_prefill_size
=
server_args
.
chunked_prefill_size
self
.
current_inflight
_req
=
None
self
.
being_chunked
_req
=
None
self
.
is_mixed_chunk
=
(
self
.
chunked_prefill_size
is
not
None
and
server_args
.
enable_mixed_chunk
)
...
...
@@ -551,13 +551,13 @@ class Scheduler:
and
not
self
.
last_batch
.
forward_mode
.
is_decode
()
and
not
self
.
last_batch
.
is_empty
()
):
if
self
.
current_inflight
_req
:
if
self
.
being_chunked
_req
:
self
.
last_batch
.
filter_batch
(
current_inflight_req
=
self
.
current_inflight
_req
being_chunked_req
=
self
.
being_chunked
_req
)
self
.
tree_cache
.
cache_unfinished_req
(
self
.
current_inflight
_req
)
self
.
tree_cache
.
cache_unfinished_req
(
self
.
being_chunked
_req
)
# Inflight request keeps its rid but will get a new req_pool_idx.
self
.
req_to_token_pool
.
free
(
self
.
current_inflight
_req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
self
.
being_chunked
_req
.
req_pool_idx
)
self
.
batch_is_full
=
False
if
not
self
.
last_batch
.
is_empty
():
if
self
.
running_batch
is
None
:
...
...
@@ -588,7 +588,7 @@ class Scheduler:
# Handle the cases where prefill is not allowed
if
(
self
.
batch_is_full
or
len
(
self
.
waiting_queue
)
==
0
)
and
self
.
current_inflight
_req
is
None
:
)
and
self
.
being_chunked
_req
is
None
:
return
None
running_bs
=
len
(
self
.
running_batch
.
reqs
)
if
self
.
running_batch
else
0
...
...
@@ -611,13 +611,11 @@ class Scheduler:
num_mixed_running
,
)
has_inflight
=
self
.
current_inflight
_req
is
not
None
has_inflight
=
self
.
being_chunked
_req
is
not
None
if
has_inflight
:
self
.
current_inflight_req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
self
.
being_chunked_req
.
init_next_round_input
()
self
.
being_chunked_req
=
adder
.
add_inflight_req
(
self
.
being_chunked_req
)
if
self
.
lora_paths
:
...
...
@@ -661,11 +659,11 @@ class Scheduler:
]
if
adder
.
new_inflight_req
is
not
None
:
assert
self
.
current_inflight
_req
is
None
self
.
current_inflight
_req
=
adder
.
new_inflight_req
assert
self
.
being_chunked
_req
is
None
self
.
being_chunked
_req
=
adder
.
new_inflight_req
if
self
.
current_inflight
_req
:
self
.
current_inflight_req
.
is_inflight_req
+=
1
if
self
.
being_chunked
_req
:
self
.
being_chunked_req
.
is_being_chunked
+=
1
# Print stats
if
self
.
tp_rank
==
0
:
...
...
@@ -833,8 +831,8 @@ class Scheduler:
# Check finish conditions
logprob_pt
=
0
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
req
.
is_
inflight_req
>
0
:
req
.
is_
inflight_req
-=
1
if
req
.
is_
being_chunked
>
0
:
req
.
is_
being_chunked
-=
1
else
:
# Inflight reqs' prefill is not finished
req
.
completion_tokens_wo_jump_forward
+=
1
...
...
@@ -860,8 +858,8 @@ class Scheduler:
# Check finish conditions
for
i
,
req
in
enumerate
(
batch
.
reqs
):
req
.
embedding
=
embeddings
[
i
]
if
req
.
is_
inflight_req
>
0
:
req
.
is_
inflight_req
-=
1
if
req
.
is_
being_chunked
>
0
:
req
.
is_
being_chunked
-=
1
else
:
# Inflight reqs' prefill is not finished
# dummy output token for embedding models
...
...
scripts/killall_sglang.sh
View file @
a2e0424a
"""
Kill all SGLang processes and free the GPU memory.
"""
# Kill all SGLang processes and free the GPU memory.
kill
-9
$(
ps aux |
grep
'multiprocessing.spawn'
|
grep
-v
'grep'
|
awk
'{print $2}'
)
kill
-9
$(
ps aux |
grep
'sglang.launch_server'
|
grep
-v
'grep'
|
awk
'{print $2}'
)
test/srt/run_suite.py
View file @
a2e0424a
...
...
@@ -19,6 +19,7 @@ suites = {
"test_openai_server.py"
,
"test_overlap_schedule.py"
,
"test_pytorch_sampling_backend.py"
,
"test_radix_attention.py"
,
"test_retract_decode.py"
,
"test_server_args.py"
,
"test_skip_tokenizer_init.py"
,
...
...
test/srt/test_radix_attention.py
0 → 100644
View file @
a2e0424a
import
os
import
random
import
unittest
import
requests
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
kill_child_process
,
popen_launch_server
,
)
def gen_radix_tree(num_nodes=400, chunk_len=256):
    """Generate a shuffled list of requests whose prompts share prefixes.

    The list starts from a fixed root request (117 copies of token 37, decode
    length 217). Every additional request extends the ``input_ids`` of an
    already generated request with a run of one repeated random token.

    The additional requests are produced in two phases:

    * the first ``num_nodes // 2`` each pick their own random parent;
    * the remainder are produced in sibling groups of 1 to 10 requests that
      all extend the same randomly chosen parent.

    Args:
        num_nodes: Number of requests to generate in addition to the root.
        chunk_len: Inclusive upper bound for both the number of tokens
            appended to the parent prompt and the decode length.

    Returns:
        A list of dicts with keys ``"input_ids"`` (list of ints) and
        ``"decode_len"`` (int), in random order. For non-negative
        ``num_nodes`` the list has ``num_nodes + 1`` entries.
    """
    num0 = num_nodes // 2
    num1 = num_nodes - num0

    # Root request: every other prompt starts with these tokens.
    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]

    # Phase 1: each child picks its own parent.
    for _ in range(num0):
        nodes.append(_make_child(random.choice(nodes), chunk_len))

    # Phase 2: groups of siblings share one parent.
    while num1 > 0:
        num_branch = random.randint(1, min(num1, 10))
        # The parent is drawn once per group, after the group size, so the
        # sequence of random draws matches the original inline code.
        parent = random.choice(nodes)
        for _ in range(num_branch):
            nodes.append(_make_child(parent, chunk_len))
        num1 -= num_branch

    # Shuffle so that list order does not follow parent-before-child order.
    random.shuffle(nodes)
    return nodes


def _make_child(parent, chunk_len):
    """Return a new request dict that extends ``parent``'s prompt."""
    # Draw order (unique_len, decode_len, token_id) is kept fixed so that a
    # seeded ``random`` yields the same tree as the previous implementation.
    unique_len = random.randint(0, chunk_len)
    decode_len = random.randint(0, chunk_len)
    # NOTE(review): 32000 is presumably an upper bound on valid token ids for
    # the test model - confirm against the tokenizer.
    token_id = random.randint(0, 32000)
    return {
        "input_ids": parent["input_ids"] + [token_id] * unique_len,
        "decode_len": decode_len,
    }
def run_test(base_url, nodes):
    """Send all ``nodes`` in a single POST to ``<base_url>/generate``.

    Args:
        base_url: URL prefix of the server; ``"/generate"`` is appended.
        nodes: Iterable of dicts with ``"input_ids"`` and ``"decode_len"``
            keys.

    Raises:
        AssertionError: If the response status code is not 200.
    """
    prompts = []
    sampling_params = []
    for node in nodes:
        prompts.append(node["input_ids"])
        # One sampling config per prompt: the decode length caps the number
        # of new tokens, and the temperature is fixed at 0.
        sampling_params.append(
            {"max_new_tokens": node["decode_len"], "temperature": 0}
        )

    payload = {"input_ids": prompts, "sampling_params": sampling_params}
    response = requests.post(base_url + "/generate", json=payload)
    assert response.status_code == 200
class TestRadixCacheFCFS(unittest.TestCase):
    """Run the random prefix-tree workload against a server using ``fcfs``.

    The server is launched once per class with a chunked prefill size of 128
    and at most 20000 total tokens. Subclasses select a different scheduling
    policy by overriding ``schedule_policy`` instead of copying
    ``setUpClass``.
    """

    # Value passed to the server's ``--schedule-policy`` flag.
    schedule_policy = "fcfs"

    @classmethod
    def setUpClass(cls):
        """Launch the server process shared by all tests in the class."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                cls.schedule_policy,
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the launched server process and its children."""
        kill_child_process(cls.process.pid, include_self=True)

    def test_radix_attention(self):
        """Send a freshly generated prefix tree and expect HTTP 200."""
        nodes = gen_radix_tree()
        run_test(self.base_url, nodes)


class TestRadixCacheLPM(TestRadixCacheFCFS):
    """Same workload as ``TestRadixCacheFCFS`` with the ``lpm`` policy."""

    schedule_policy = "lpm"
if __name__ == "__main__":
    # Set the flag only when the file is executed directly, before the test
    # runner starts.
    # NOTE(review): nothing in this file reads SGLANG_TEST_RETRACT; presumably
    # the launched server or the test utilities consult it - confirm.
    os.environ.update({"SGLANG_TEST_RETRACT": "true"})
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment