Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b2ccf36d
Unverified
Commit
b2ccf36d
authored
Nov 28, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 28, 2024
Browse files
Fix memory leak during abort (#2238)
parent
d4fc1a70
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
84 additions
and
7 deletions
+84
-7
.github/workflows/pr-test.yml
.github/workflows/pr-test.yml
+4
-4
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+5
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+6
-2
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+14
-1
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_abort.py
test/srt/test_abort.py
+54
-0
No files found.
.github/workflows/pr-test.yml
View file @
b2ccf36d
...
@@ -50,7 +50,7 @@ jobs:
...
@@ -50,7 +50,7 @@ jobs:
timeout-minutes
:
25
timeout-minutes
:
25
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 run_suite.py --suite minimal --range-begin 0 --range-end
5
python3 run_suite.py --suite minimal --range-begin 0 --range-end
6
unit-test-backend-part-2
:
unit-test-backend-part-2
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
@@ -67,7 +67,7 @@ jobs:
...
@@ -67,7 +67,7 @@ jobs:
timeout-minutes
:
25
timeout-minutes
:
25
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 run_suite.py --suite minimal --range-begin
5
--range-end 1
4
python3 run_suite.py --suite minimal --range-begin
6
--range-end 1
5
unit-test-backend-part-3
:
unit-test-backend-part-3
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
@@ -84,7 +84,7 @@ jobs:
...
@@ -84,7 +84,7 @@ jobs:
timeout-minutes
:
25
timeout-minutes
:
25
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 run_suite.py --suite minimal --range-begin 1
4
--range-end 2
3
python3 run_suite.py --suite minimal --range-begin 1
5
--range-end 2
4
unit-test-backend-part-4
:
unit-test-backend-part-4
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
@@ -101,7 +101,7 @@ jobs:
...
@@ -101,7 +101,7 @@ jobs:
timeout-minutes
:
25
timeout-minutes
:
25
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 run_suite.py --suite minimal --range-begin 2
3
python3 run_suite.py --suite minimal --range-begin 2
4
unit-test-backend-2-gpu-part-1
:
unit-test-backend-2-gpu-part-1
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
b2ccf36d
...
@@ -231,6 +231,7 @@ class Req:
...
@@ -231,6 +231,7 @@ class Req:
self
.
tokenizer
=
None
self
.
tokenizer
=
None
self
.
finished_reason
=
None
self
.
finished_reason
=
None
self
.
stream
=
False
self
.
stream
=
False
self
.
to_abort
=
False
# For incremental decoding
# For incremental decoding
# ----- | --------- read_ids -------|
# ----- | --------- read_ids -------|
...
@@ -368,6 +369,10 @@ class Req:
...
@@ -368,6 +369,10 @@ class Req:
if
self
.
finished
():
if
self
.
finished
():
return
return
if
self
.
to_abort
:
self
.
finished_reason
=
FINISH_ABORT
()
return
if
len
(
self
.
output_ids
)
>=
self
.
sampling_params
.
max_new_tokens
:
if
len
(
self
.
output_ids
)
>=
self
.
sampling_params
.
max_new_tokens
:
self
.
finished_reason
=
FINISH_LENGTH
(
self
.
finished_reason
=
FINISH_LENGTH
(
length
=
self
.
sampling_params
.
max_new_tokens
length
=
self
.
sampling_params
.
max_new_tokens
...
...
python/sglang/srt/managers/scheduler.py
View file @
b2ccf36d
...
@@ -579,6 +579,8 @@ class Scheduler:
...
@@ -579,6 +579,8 @@ class Scheduler:
"Image request length is longer than the KV cache pool size or "
"Image request length is longer than the KV cache pool size or "
"the max context length aborting because you cannot truncate the image embeds"
"the max context length aborting because you cannot truncate the image embeds"
)
)
req
.
image_inputs
=
None
req
.
origin_input_ids
=
[
0
]
req
.
sampling_params
.
max_new_tokens
=
0
req
.
sampling_params
.
max_new_tokens
=
0
self
.
waiting_queue
.
append
(
req
)
self
.
waiting_queue
.
append
(
req
)
return
return
...
@@ -1350,13 +1352,15 @@ class Scheduler:
...
@@ -1350,13 +1352,15 @@ class Scheduler:
if
to_del
is
not
None
:
if
to_del
is
not
None
:
del
self
.
waiting_queue
[
to_del
]
del
self
.
waiting_queue
[
to_del
]
logger
.
debug
(
f
"Abort queued request.
{
req
.
rid
=
}
"
)
return
# Delete requests in the running batch
# Delete requests in the running batch
if
self
.
running_batch
:
if
self
.
running_batch
:
for
req
in
self
.
running_batch
.
reqs
:
for
req
in
self
.
running_batch
.
reqs
:
if
req
.
rid
==
recv_req
.
rid
and
not
req
.
finished
():
if
req
.
rid
==
recv_req
.
rid
and
not
req
.
finished
():
req
.
finished_reason
=
FINISH_ABORT
(
)
logger
.
debug
(
f
"Abort running request.
{
req
.
rid
=
}
"
)
self
.
tree_cache
.
cache_finished_req
(
req
)
req
.
to_abort
=
True
break
break
def
update_weights
(
self
,
recv_req
:
UpdateWeightReqInput
):
def
update_weights
(
self
,
recv_req
:
UpdateWeightReqInput
):
...
...
python/sglang/test/test_utils.py
View file @
b2ccf36d
...
@@ -677,8 +677,14 @@ def run_and_check_memory_leak(
...
@@ -677,8 +677,14 @@ def run_and_check_memory_leak(
enable_mixed_chunk
,
enable_mixed_chunk
,
disable_overlap
,
disable_overlap
,
chunked_prefill_size
,
chunked_prefill_size
,
assert_has_abort
,
):
):
other_args
=
[
"--chunked-prefill-size"
,
str
(
chunked_prefill_size
)]
other_args
=
[
"--chunked-prefill-size"
,
str
(
chunked_prefill_size
),
"--log-level"
,
"debug"
,
]
if
disable_radix_cache
:
if
disable_radix_cache
:
other_args
+=
[
"--disable-radix-cache"
]
other_args
+=
[
"--disable-radix-cache"
]
if
enable_mixed_chunk
:
if
enable_mixed_chunk
:
...
@@ -723,14 +729,19 @@ def run_and_check_memory_leak(
...
@@ -723,14 +729,19 @@ def run_and_check_memory_leak(
# Assert success
# Assert success
has_new_server
=
False
has_new_server
=
False
has_leak
=
False
has_leak
=
False
has_abort
=
False
for
line
in
output_lines
:
for
line
in
output_lines
:
if
"The server is fired"
in
line
:
if
"The server is fired"
in
line
:
has_new_server
=
True
has_new_server
=
True
if
"leak"
in
line
:
if
"leak"
in
line
:
has_leak
=
True
has_leak
=
True
if
"Abort"
in
line
:
has_abort
=
True
assert
has_new_server
assert
has_new_server
assert
not
has_leak
assert
not
has_leak
if
assert_has_abort
:
assert
has_abort
def
run_mmlu_test
(
def
run_mmlu_test
(
...
@@ -761,6 +772,7 @@ def run_mmlu_test(
...
@@ -761,6 +772,7 @@ def run_mmlu_test(
enable_mixed_chunk
,
enable_mixed_chunk
,
disable_overlap
,
disable_overlap
,
chunked_prefill_size
,
chunked_prefill_size
,
assert_has_abort
=
False
,
)
)
...
@@ -800,4 +812,5 @@ def run_mulit_request_test(
...
@@ -800,4 +812,5 @@ def run_mulit_request_test(
enable_mixed_chunk
,
enable_mixed_chunk
,
enable_overlap
,
enable_overlap
,
chunked_prefill_size
,
chunked_prefill_size
,
assert_has_abort
=
False
,
)
)
test/srt/run_suite.py
View file @
b2ccf36d
...
@@ -10,6 +10,7 @@ suites = {
...
@@ -10,6 +10,7 @@ suites = {
"models/test_lora.py"
,
"models/test_lora.py"
,
"models/test_reward_models.py"
,
"models/test_reward_models.py"
,
"sampling/penaltylib"
,
"sampling/penaltylib"
,
"test_abort.py"
,
"test_chunked_prefill.py"
,
"test_chunked_prefill.py"
,
"test_double_sparsity.py"
,
"test_double_sparsity.py"
,
"test_embedding_openai_server.py"
,
"test_embedding_openai_server.py"
,
...
...
test/srt/test_abort.py
0 → 100644
View file @
b2ccf36d
import multiprocessing
import time
import unittest
from concurrent.futures import ThreadPoolExecutor

import requests

from sglang.test.test_utils import run_and_check_memory_leak


class TestAbort(unittest.TestCase):
    """Check that aborting in-flight requests does not leak KV-cache memory."""

    def workload_func(self, base_url, model):
        # Fire a burst of generation requests from a child process, then kill
        # that process mid-flight so the server must abort the requests.

        def child_workload():
            def send_request(_):
                prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""
                resp = requests.post(
                    f"{base_url}/generate",
                    json={
                        "text": prompt,
                        "sampling_params": {
                            "temperature": 0,
                            "max_new_tokens": 2048,
                        },
                    },
                )
                # Parse the body (same side effect as the original; value unused).
                resp.json()

            with ThreadPoolExecutor(16) as pool:
                # Drain the map so all 16 requests are actually issued.
                list(pool.map(send_request, list(range(16))))

        proc = multiprocessing.Process(target=child_workload)
        proc.start()
        # Let the requests reach the server, then kill the client abruptly.
        time.sleep(0.5)
        proc.terminate()
        # Give the server time to abort and clean up the orphaned requests.
        time.sleep(10)

    def test_memory_leak(self):
        run_and_check_memory_leak(
            self.workload_func,
            disable_radix_cache=False,
            enable_mixed_chunk=False,
            disable_overlap=False,
            chunked_prefill_size=8192,
            assert_has_abort=True,
        )


if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment