Commit a2f5e755 (unverified)
Authored Oct 25, 2024 by Liangsheng Yin; committed by GitHub on Oct 25, 2024.
Fix memory leak when doing chunked prefill (#1787)
Parent: 2148914e

Showing 7 changed files with 184 additions and 69 deletions:

- python/sglang/global_config.py (+11, −1)
- python/sglang/srt/managers/schedule_batch.py (+3, −4)
- python/sglang/srt/managers/schedule_policy.py (+18, −7)
- python/sglang/srt/managers/scheduler.py (+38, −57)
- python/sglang/test/test_utils.py (+1, −0)
- test/srt/run_suite.py (+1, −0)
- test/srt/test_radix_attention.py (+112, −0)
python/sglang/global_config.py

```diff
@@ -15,7 +15,7 @@ class GlobalConfig:
         # Runtime constants: New generation token ratio estimation
         self.init_new_token_ratio = 0.7
-        self.base_min_new_token_ratio = 0.1
+        self.min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
 
         # Runtime constants: others
@@ -32,5 +32,15 @@ class GlobalConfig:
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
 
+    def adjust_new_token_ratio(self, schedule_conservativeness=1):
+        assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
+        min_new_token_ratio = min(
+            self.min_new_token_ratio * schedule_conservativeness,
+            1.0,
+        )
+        init_new_token_ratio = max(self.init_new_token_ratio, min_new_token_ratio)
+        return min_new_token_ratio, init_new_token_ratio
+
 
 global_config = GlobalConfig()
```
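The rename to `min_new_token_ratio` and the new `adjust_new_token_ratio` helper move the clamping that `Scheduler.__init__` previously did inline into `GlobalConfig` (see the scheduler.py diff below). A minimal standalone sketch of the helper's behavior for a few `schedule_conservativeness` values; the trimmed-down class here is illustrative, not the full `GlobalConfig`:

```python
# Illustrative reduction of GlobalConfig to the two fields the helper reads.
class GlobalConfigSketch:
    def __init__(self):
        self.init_new_token_ratio = 0.7
        self.min_new_token_ratio = 0.1

    def adjust_new_token_ratio(self, schedule_conservativeness=1):
        assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
        # Scale the floor by the conservativeness knob, capped at 1.0,
        min_new_token_ratio = min(
            self.min_new_token_ratio * schedule_conservativeness, 1.0
        )
        # and never let the starting ratio fall below that floor.
        init_new_token_ratio = max(self.init_new_token_ratio, min_new_token_ratio)
        return min_new_token_ratio, init_new_token_ratio


cfg = GlobalConfigSketch()
for c in (0.5, 1, 10):
    print(c, cfg.adjust_new_token_ratio(c))
# 0.5 -> (0.05, 0.7)   1 -> (0.1, 0.7)   10 -> (1.0, 1.0)
```

The scheduler consumes the returned pair directly, as shown in the scheduler.py hunk below: `self.min_new_token_ratio, self.init_new_token_ratio = global_config.adjust_new_token_ratio(...)`.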
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -222,7 +222,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.is_inflight_req = 0
+        self.is_being_chunked = False
 
         # Logprobs (arguments)
         self.return_logprob = False
@@ -906,15 +906,14 @@ class ScheduleBatch:
     def filter_batch(
         self,
-        current_inflight_req: Optional[Req] = None,
+        being_chunked_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
-                if not self.reqs[i].finished()
-                and self.reqs[i] is not current_inflight_req
+                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
             ]
 
         if keep_indices is None or len(keep_indices) == 0:
```
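`filter_batch` only renames the parameter; the selection rule (keep requests that are neither finished nor the one still being chunked) is unchanged. A self-contained sketch of that rule, using a hypothetical `FakeReq` stand-in for `Req`:

```python
# Hypothetical stand-in for Req; only finished() and object identity matter here.
class FakeReq:
    def __init__(self, name, done=False):
        self.name, self.done = name, done

    def finished(self):
        return self.done


reqs = [FakeReq("a"), FakeReq("b", done=True), FakeReq("c")]
being_chunked_req = reqs[2]  # still mid-prefill: exclude it from the kept set

keep_indices = [
    i
    for i in range(len(reqs))
    if not reqs[i].finished() and reqs[i] is not being_chunked_req
]
print([reqs[i].name for i in keep_indices])  # -> ['a']
```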
python/sglang/srt/managers/schedule_policy.py

```diff
@@ -136,7 +136,7 @@ class PrefillAdder:
         self.req_states = None
         self.can_run_list = []
-        self.new_inflight_req = None
+        self.new_chunked_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
@@ -176,7 +176,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req(self, req: Req):
+    def add_being_chunked_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -192,8 +192,13 @@
             ),
         )
 
-        # Return if chunked prefill not finished
-        return req if truncated else None
+        if truncated:
+            # Continue to chunk the request
+            assert req.is_being_chunked
+            self.new_chunked_req = req
+        else:
+            # Release the being chunked status
+            req.is_being_chunked = False
 
     @contextmanager
     def _lock_node(self, last_node: TreeNode):
@@ -262,11 +267,14 @@
                 )
             else:
                 # Chunked prefill
+                assert self.new_chunked_req is None
                 trunc_len = self.rem_chunk_tokens
                 req.extend_input_len = trunc_len
+                req.is_being_chunked = True
                 req.fill_ids = req.fill_ids[:trunc_len]
                 self.can_run_list.append(req)
-                self.new_inflight_req = req
+                self.new_chunked_req = req
                 self._prefill_one_req(0, trunc_len, 0)
 
         return self.budget_state()
@@ -305,15 +313,18 @@
                     min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
                 )
             else:
-                # Chunked prefill
                 trunc_len = self.rem_chunk_tokens
                 if trunc_len == 0:
                     return AddReqResult.OTHER
 
+                # Chunked prefill
+                assert self.new_chunked_req is None
                 req.extend_input_len = trunc_len
                 req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
+                req.is_being_chunked = True
                 self.can_run_list.append(req)
-                self.new_inflight_req = req
+                self.new_chunked_req = req
                 self.tree_cache.inc_lock_ref(req.last_node)
                 self._prefill_one_req(prefix_len, trunc_len, 0)
```
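`add_being_chunked_req` truncates the request to the remaining chunk budget and, via the new `truncated` branch, either re-registers it as `new_chunked_req` or clears its `is_being_chunked` flag. A numeric sketch of that bookkeeping under assumed sizes; the `SimpleNamespace` request is a stand-in, not a real `Req`:

```python
from types import SimpleNamespace

# Assumed sizes: 100 prefix tokens already cached, 300 tokens still to prefill,
# and a remaining chunk budget of 128 tokens.
req = SimpleNamespace(
    prefix_indices=list(range(100)),
    extend_input_len=300,
    fill_ids=list(range(400)),  # prefix tokens + not-yet-prefilled input
    is_being_chunked=True,
)
rem_chunk_tokens = 128

truncated = req.extend_input_len > rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]

print(truncated, req.extend_input_len, len(req.fill_ids))  # True 128 228
# truncated=True -> the adder keeps req as new_chunked_req, so the scheduler
# re-adds the same request first in the next round; on the final chunk,
# truncated=False and is_being_chunked is reset instead.
```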
python/sglang/srt/managers/scheduler.py

```diff
@@ -219,35 +219,28 @@
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
+        self.being_chunked_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
         # Init the FSM cache for constrained generation
-        if not server_args.skip_tokenizer_init:
-            self.regex_fsm_cache = FSMCache(
-                server_args.tokenizer_path,
-                {
-                    "tokenizer_mode": server_args.tokenizer_mode,
-                    "trust_remote_code": server_args.trust_remote_code,
-                },
-                skip_tokenizer_init=server_args.skip_tokenizer_init,
-                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
-            )
+        self.regex_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+            skip_tokenizer_init=server_args.skip_tokenizer_init,
+            constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+        )
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
-        assert (
-            server_args.schedule_conservativeness >= 0
-        ), "Invalid schedule_conservativeness"
-        self.min_new_token_ratio = min(
-            global_config.base_min_new_token_ratio
-            * server_args.schedule_conservativeness,
-            1.0,
-        )
-        self.new_token_ratio = self.min_new_token_ratio
-        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.min_new_token_ratio, self.init_new_token_ratio = (
+            global_config.adjust_new_token_ratio(server_args.schedule_conservativeness)
+        )
+        self.new_token_ratio = self.init_new_token_ratio
 
         self.batch_is_full = False
 
         # Init profiler
@@ -294,7 +287,7 @@
                 self.process_batch_result(batch, result)
             else:
                 self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+                self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
@@ -321,7 +314,7 @@
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+                self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
@@ -499,20 +492,18 @@
         )
         exit(1) if crash_on_warning else None
 
-    def get_next_batch_to_run(self):
+    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         if (
             self.last_batch
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.current_inflight_req:
-                self.last_batch.filter_batch(
-                    current_inflight_req=self.current_inflight_req
-                )
-                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx.
-                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+            if self.being_chunked_req:
+                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
+                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
+                # Being chunked request keeps its rid but will get a new req_pool_idx.
+                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
@@ -543,7 +534,7 @@
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.current_inflight_req is None:
+        ) and self.being_chunked_req is None:
             return None
 
         running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -566,15 +557,6 @@
             num_mixed_running,
         )
 
-        has_inflight = self.current_inflight_req is not None
-        if has_inflight:
-            self.current_inflight_req.init_next_round_input(
-                None if prefix_computed else self.tree_cache
-            )
-            self.current_inflight_req = adder.add_inflight_req(
-                self.current_inflight_req
-            )
-
         if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
@@ -582,6 +564,13 @@
                 else set([])
             )
 
+        # NOTE: if there is request being chunked, we always add it first
+        has_being_chunked = self.being_chunked_req is not None
+        if has_being_chunked:
+            # NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
+            self.being_chunked_req.init_next_round_input()
+            adder.add_being_chunked_req(self.being_chunked_req)
+
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -615,12 +604,8 @@
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
 
-        if adder.new_inflight_req is not None:
-            assert self.current_inflight_req is None
-            self.current_inflight_req = adder.new_inflight_req
-
-        if self.current_inflight_req:
-            self.current_inflight_req.is_inflight_req += 1
+        # Update new round being chunked request
+        self.being_chunked_req = adder.new_chunked_req
 
         # Print stats
         if self.tp_rank == 0:
@@ -649,7 +634,7 @@
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+                f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
             )
         else:
             logger.info(
@@ -660,7 +645,7 @@
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+                f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
             )
 
         # Create a new batch
@@ -709,7 +694,7 @@
             self.waiting_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_decay,
+                self.new_token_ratio - global_config.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )
@@ -783,10 +768,8 @@
         # Check finish conditions
         logprob_pt = 0
         for i, req in enumerate(batch.reqs):
-            if req.is_inflight_req > 0:
-                req.is_inflight_req -= 1
-            else:
-                # Inflight reqs' prefill is not finished
+            if not req.is_being_chunked:
+                # Being chunked reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()
@@ -812,10 +795,8 @@
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
             req.embedding = embeddings[i]
-            if req.is_inflight_req > 0:
-                req.is_inflight_req -= 1
-            else:
-                # Inflight reqs' prefill is not finished
+            if not req.is_being_chunked:
+                # Being chunked reqs' prefill is not finished
                 # dummy output token for embedding models
                 req.output_ids.append(0)
                 req.check_finished()
```
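The scheduler changes replace the `current_inflight_req` reference and the per-request `is_inflight_req` counter with a single `being_chunked_req` reference plus the boolean `is_being_chunked` flag. The key sequence for the leak is in `get_next_batch_to_run`: the half-prefilled request is filtered out of the last batch, its computed prefix is cached, and its `req_to_token_pool` slot is freed before the request is re-added and assigned a new `req_pool_idx`. A mock-based sketch of that sequence; every collaborator below is a stand-in, not the real class:

```python
from unittest import mock

# Stand-ins for the request, ScheduleBatch, the radix tree cache, and the
# req-to-token pool; only the call sequence is meaningful.
being_chunked_req = mock.Mock(req_pool_idx=7)
last_batch = mock.Mock()
tree_cache = mock.Mock()
req_to_token_pool = mock.Mock()

if being_chunked_req:
    # Take the half-prefilled request out of the previous prefill batch,
    last_batch.filter_batch(being_chunked_req=being_chunked_req)
    # record its already-computed prefix so the next chunk hits the cache,
    tree_cache.cache_unfinished_req(being_chunked_req)
    # and free the old slot: per the diff's comment, the request keeps its rid
    # but will get a new req_pool_idx, so a skipped free here would strand a
    # pool slot on every chunk.
    req_to_token_pool.free(being_chunked_req.req_pool_idx)

req_to_token_pool.free.assert_called_once_with(7)
```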
python/sglang/test/test_utils.py

```diff
@@ -660,6 +660,7 @@ def run_mmlu_test(
     chunked_prefill_size=32,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+    other_args += ["--mem-fraction-static", "0.85"]
     if disable_radix_cache:
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
```
test/srt/run_suite.py

```diff
@@ -5,6 +5,7 @@ from sglang.test.test_utils import run_unittest_files
 
 suites = {
     "minimal": [
+        "test_radix_attention.py",
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
         "models/test_lora.py",
```
test/srt/test_radix_attention.py (new file, mode 100644)

```python
import os
import random
import unittest

import requests

from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    kill_child_process,
    popen_launch_server,
)


def gen_radix_tree(num_nodes=400, chunk_len=256):
    num0 = num_nodes // 2
    num1 = num_nodes - num0
    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]

    for _ in range(num0):
        parent = random.choice(nodes)
        unique_len = random.randint(0, chunk_len)
        decode_len = random.randint(0, chunk_len)
        token_id = random.randint(0, 32000)
        child = {
            "input_ids": parent["input_ids"] + [token_id] * unique_len,
            "decode_len": decode_len,
        }
        nodes.append(child)

    while num1 > 0:
        num_branch = random.randint(1, min(num1, 10))
        parent = random.choice(nodes)
        for _ in range(num_branch):
            unique_len = random.randint(0, chunk_len)
            decode_len = random.randint(0, chunk_len)
            token_id = random.randint(0, 32000)
            child = {
                "input_ids": parent["input_ids"] + [token_id] * unique_len,
                "decode_len": decode_len,
            }
            nodes.append(child)
        num1 -= num_branch

    random.shuffle(nodes)
    return nodes


def run_test(base_url, nodes):
    data = {
        "input_ids": [node["input_ids"] for node in nodes],
        "sampling_params": [
            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
        ],
    }

    res = requests.post(base_url + "/generate", json=data)
    assert res.status_code == 200


class TestRadixCacheFCFS(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "fcfs",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def test_radix_attention(self):
        nodes = gen_radix_tree()
        run_test(self.base_url, nodes)


class TestRadixCacheLPM(TestRadixCacheFCFS):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "lpm",
            ],
        )


if __name__ == "__main__":
    os.environ["SGLANG_TEST_RETRACT"] = "true"
    unittest.main()
```
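The new test exercises chunked prefill (`--chunked-prefill-size 128`) against batches of randomly branching shared prefixes, under both FCFS and LPM scheduling, with `SGLANG_TEST_RETRACT=true` set in `__main__` (presumably to exercise retraction paths as well). For reference, a sketch of the payload `run_test` posts; the base URL is a placeholder for whatever `DEFAULT_URL_FOR_TEST` resolves to:

```python
import requests

# Two requests sharing a 117-token prefix, in the shape gen_radix_tree produces.
base_url = "http://127.0.0.1:30000"  # placeholder for DEFAULT_URL_FOR_TEST
data = {
    "input_ids": [[37] * 117, [37] * 117 + [1234] * 50],
    "sampling_params": [
        {"max_new_tokens": 217, "temperature": 0},
        {"max_new_tokens": 32, "temperature": 0},
    ],
}
res = requests.post(base_url + "/generate", json=data)
assert res.status_code == 200
```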