Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c555ce2c
Unverified
Commit
c555ce2c
authored
Oct 25, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 25, 2024
Browse files
Revert "Fix memory leak when doing chunked prefill" (#1797)
parent
40900bae
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
69 additions
and
183 deletions
+69
-183
python/sglang/global_config.py
python/sglang/global_config.py
+1
-11
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+4
-3
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+7
-18
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+57
-38
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+0
-1
test/srt/test_radix_attention.py
test/srt/test_radix_attention.py
+0
-112
No files found.
python/sglang/global_config.py
View file @
c555ce2c
...
...
@@ -15,7 +15,7 @@ class GlobalConfig:
# Runtime constants: New generation token ratio estimation
self
.
init_new_token_ratio
=
0.7
self
.
min_new_token_ratio
=
0.1
self
.
base_
min_new_token_ratio
=
0.1
self
.
new_token_ratio_decay
=
0.001
# Runtime constants: others
...
...
@@ -32,15 +32,5 @@ class GlobalConfig:
self
.
enable_precache_with_tracing
=
True
self
.
enable_parallel_encoding
=
True
def
adjust_new_token_ratio
(
self
,
schedule_conservativeness
=
1
):
assert
schedule_conservativeness
>=
0
,
"Invalid schedule_conservativeness"
min_new_token_ratio
=
min
(
self
.
min_new_token_ratio
*
schedule_conservativeness
,
1.0
,
)
init_new_token_ratio
=
max
(
self
.
init_new_token_ratio
,
min_new_token_ratio
)
return
min_new_token_ratio
,
init_new_token_ratio
global_config
=
GlobalConfig
()
python/sglang/srt/managers/schedule_batch.py
View file @
c555ce2c
...
...
@@ -222,7 +222,7 @@ class Req:
self
.
prefix_indices
=
[]
self
.
extend_input_len
=
0
self
.
last_node
=
None
self
.
is_
being_chunked
=
False
self
.
is_
inflight_req
=
0
# Logprobs (arguments)
self
.
return_logprob
=
False
...
...
@@ -906,14 +906,15 @@ class ScheduleBatch:
def
filter_batch
(
self
,
being_chunked
_req
:
Optional
[
Req
]
=
None
,
current_inflight
_req
:
Optional
[
Req
]
=
None
,
keep_indices
:
Optional
[
List
[
int
]]
=
None
,
):
if
keep_indices
is
None
:
keep_indices
=
[
i
for
i
in
range
(
len
(
self
.
reqs
))
if
not
self
.
reqs
[
i
].
finished
()
and
self
.
reqs
[
i
]
is
not
being_chunked_req
if
not
self
.
reqs
[
i
].
finished
()
and
self
.
reqs
[
i
]
is
not
current_inflight_req
]
if
keep_indices
is
None
or
len
(
keep_indices
)
==
0
:
...
...
python/sglang/srt/managers/schedule_policy.py
View file @
c555ce2c
...
...
@@ -136,7 +136,7 @@ class PrefillAdder:
self
.
req_states
=
None
self
.
can_run_list
=
[]
self
.
new_
chunked
_req
=
None
self
.
new_
inflight
_req
=
None
self
.
log_hit_tokens
=
0
self
.
log_input_tokens
=
0
...
...
@@ -176,7 +176,7 @@ class PrefillAdder:
self
.
log_hit_tokens
+=
prefix_len
self
.
log_input_tokens
+=
extend_input_len
def
add_
being_chunked
_req
(
self
,
req
:
Req
):
def
add_
inflight
_req
(
self
,
req
:
Req
):
truncated
=
req
.
extend_input_len
>
self
.
rem_chunk_tokens
req
.
extend_input_len
=
min
(
req
.
extend_input_len
,
self
.
rem_chunk_tokens
)
req
.
fill_ids
=
req
.
fill_ids
[:
len
(
req
.
prefix_indices
)
+
req
.
extend_input_len
]
...
...
@@ -192,13 +192,8 @@ class PrefillAdder:
),
)
if
truncated
:
# Continue to chunk the request
assert
req
.
is_being_chunked
self
.
new_chunked_req
=
req
else
:
# Release the being chunked status
req
.
is_being_chunked
=
False
# Return if chunked prefill not finished
return
req
if
truncated
else
None
@
contextmanager
def
_lock_node
(
self
,
last_node
:
TreeNode
):
...
...
@@ -267,14 +262,11 @@ class PrefillAdder:
)
else
:
# Chunked prefill
assert
self
.
new_chunked_req
is
None
trunc_len
=
self
.
rem_chunk_tokens
req
.
extend_input_len
=
trunc_len
req
.
is_being_chunked
=
True
req
.
fill_ids
=
req
.
fill_ids
[:
trunc_len
]
self
.
can_run_list
.
append
(
req
)
self
.
new_
chunked
_req
=
req
self
.
new_
inflight
_req
=
req
self
.
_prefill_one_req
(
0
,
trunc_len
,
0
)
return
self
.
budget_state
()
...
...
@@ -313,18 +305,15 @@ class PrefillAdder:
min
(
req
.
sampling_params
.
max_new_tokens
,
CLIP_MAX_NEW_TOKENS
),
)
else
:
# Chunked prefill
trunc_len
=
self
.
rem_chunk_tokens
if
trunc_len
==
0
:
return
AddReqResult
.
OTHER
# Chunked prefill
assert
self
.
new_chunked_req
is
None
req
.
extend_input_len
=
trunc_len
req
.
fill_ids
=
req
.
fill_ids
[:
len
(
req
.
prefix_indices
)
+
trunc_len
]
req
.
is_being_chunked
=
True
self
.
can_run_list
.
append
(
req
)
self
.
new_
chunked
_req
=
req
self
.
new_
inflight
_req
=
req
self
.
tree_cache
.
inc_lock_ref
(
req
.
last_node
)
self
.
_prefill_one_req
(
prefix_len
,
trunc_len
,
0
)
...
...
python/sglang/srt/managers/scheduler.py
View file @
c555ce2c
...
...
@@ -219,28 +219,35 @@ class Scheduler:
# Init chunked prefill
self
.
chunked_prefill_size
=
server_args
.
chunked_prefill_size
self
.
being_chunked
_req
=
None
self
.
current_inflight
_req
=
None
self
.
is_mixed_chunk
=
(
self
.
chunked_prefill_size
is
not
None
and
server_args
.
enable_mixed_chunk
)
# Init the FSM cache for constrained generation
self
.
regex_fsm_cache
=
FSMCache
(
server_args
.
tokenizer_path
,
{
"tokenizer_mode"
:
server_args
.
tokenizer_mode
,
"trust_remote_code"
:
server_args
.
trust_remote_code
,
},
skip_tokenizer_init
=
server_args
.
skip_tokenizer_init
,
constrained_json_whitespace_pattern
=
server_args
.
constrained_json_whitespace_pattern
,
)
if
not
server_args
.
skip_tokenizer_init
:
self
.
regex_fsm_cache
=
FSMCache
(
server_args
.
tokenizer_path
,
{
"tokenizer_mode"
:
server_args
.
tokenizer_mode
,
"trust_remote_code"
:
server_args
.
trust_remote_code
,
},
skip_tokenizer_init
=
server_args
.
skip_tokenizer_init
,
constrained_json_whitespace_pattern
=
server_args
.
constrained_json_whitespace_pattern
,
)
self
.
jump_forward_cache
=
JumpForwardCache
()
# Init new token estimation
self
.
min_new_token_ratio
,
self
.
init_new_token_ratio
=
(
global_config
.
adjust_new_token_ratio
(
server_args
.
schedule_conservativeness
)
assert
(
server_args
.
schedule_conservativeness
>=
0
),
"Invalid schedule_conservativeness"
self
.
min_new_token_ratio
=
min
(
global_config
.
base_min_new_token_ratio
*
server_args
.
schedule_conservativeness
,
1.0
,
)
self
.
new_token_ratio
=
self
.
init_new_token_ratio
self
.
new_token_ratio
=
self
.
min_new_token_ratio
self
.
new_token_ratio_decay
=
global_config
.
new_token_ratio_decay
self
.
batch_is_full
=
False
# Init profiler
...
...
@@ -287,7 +294,7 @@ class Scheduler:
self
.
process_batch_result
(
batch
,
result
)
else
:
self
.
check_memory
()
self
.
new_token_ratio
=
self
.
init_new_token_ratio
self
.
new_token_ratio
=
global_config
.
init_new_token_ratio
self
.
last_batch
=
batch
...
...
@@ -314,7 +321,7 @@ class Scheduler:
self
.
process_batch_result
(
tmp_batch
,
tmp_result
)
elif
batch
is
None
:
self
.
check_memory
()
self
.
new_token_ratio
=
self
.
init_new_token_ratio
self
.
new_token_ratio
=
global_config
.
init_new_token_ratio
self
.
last_batch
=
batch
...
...
@@ -492,18 +499,20 @@ class Scheduler:
)
exit
(
1
)
if
crash_on_warning
else
None
def
get_next_batch_to_run
(
self
)
->
Optional
[
ScheduleBatch
]
:
def
get_next_batch_to_run
(
self
):
# Merge the prefill batch into the running batch
if
(
self
.
last_batch
and
not
self
.
last_batch
.
forward_mode
.
is_decode
()
and
not
self
.
last_batch
.
is_empty
()
):
if
self
.
being_chunked_req
:
self
.
last_batch
.
filter_batch
(
being_chunked_req
=
self
.
being_chunked_req
)
self
.
tree_cache
.
cache_unfinished_req
(
self
.
being_chunked_req
)
# Being chunked request keeps its rid but will get a new req_pool_idx.
self
.
req_to_token_pool
.
free
(
self
.
being_chunked_req
.
req_pool_idx
)
if
self
.
current_inflight_req
:
self
.
last_batch
.
filter_batch
(
current_inflight_req
=
self
.
current_inflight_req
)
self
.
tree_cache
.
cache_unfinished_req
(
self
.
current_inflight_req
)
# Inflight request keeps its rid but will get a new req_pool_idx.
self
.
req_to_token_pool
.
free
(
self
.
current_inflight_req
.
req_pool_idx
)
self
.
batch_is_full
=
False
if
not
self
.
last_batch
.
is_empty
():
if
self
.
running_batch
is
None
:
...
...
@@ -534,7 +543,7 @@ class Scheduler:
# Handle the cases where prefill is not allowed
if
(
self
.
batch_is_full
or
len
(
self
.
waiting_queue
)
==
0
)
and
self
.
being_chunked
_req
is
None
:
)
and
self
.
current_inflight
_req
is
None
:
return
None
running_bs
=
len
(
self
.
running_batch
.
reqs
)
if
self
.
running_batch
else
0
...
...
@@ -557,6 +566,15 @@ class Scheduler:
num_mixed_running
,
)
has_inflight
=
self
.
current_inflight_req
is
not
None
if
has_inflight
:
self
.
current_inflight_req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
)
if
self
.
lora_paths
:
lora_set
=
(
set
([
req
.
lora_path
for
req
in
self
.
running_batch
.
reqs
])
...
...
@@ -564,13 +582,6 @@ class Scheduler:
else
set
([])
)
# NOTE: if there is request being chunked, we always add it first
has_being_chunked
=
self
.
being_chunked_req
is
not
None
if
has_being_chunked
:
# NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
self
.
being_chunked_req
.
init_next_round_input
()
adder
.
add_being_chunked_req
(
self
.
being_chunked_req
)
# Get requests from the waiting queue to a new prefill batch
for
req
in
self
.
waiting_queue
:
if
(
...
...
@@ -604,8 +615,12 @@ class Scheduler:
x
for
x
in
self
.
waiting_queue
if
x
not
in
set
(
can_run_list
)
]
# Update new round being chunked request
self
.
being_chunked_req
=
adder
.
new_chunked_req
if
adder
.
new_inflight_req
is
not
None
:
assert
self
.
current_inflight_req
is
None
self
.
current_inflight_req
=
adder
.
new_inflight_req
if
self
.
current_inflight_req
:
self
.
current_inflight_req
.
is_inflight_req
+=
1
# Print stats
if
self
.
tp_rank
==
0
:
...
...
@@ -634,7 +649,7 @@ class Scheduler:
f
"#cached-token:
{
adder
.
log_hit_tokens
}
, "
f
"cache hit rate:
{
100.0
*
tree_cache_hit_rate
:.
2
f
}
%, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
+
has_
being_chunked
}
"
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
+
has_
inflight
}
"
)
else
:
logger
.
info
(
...
...
@@ -645,7 +660,7 @@ class Scheduler:
f
"cache hit rate:
{
100.0
*
tree_cache_hit_rate
:.
2
f
}
%, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"#running-req:
{
running_bs
}
, "
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
+
has_
being_chunked
}
"
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
+
has_
inflight
}
"
)
# Create a new batch
...
...
@@ -694,7 +709,7 @@ class Scheduler:
self
.
waiting_queue
.
extend
(
retracted_reqs
)
else
:
self
.
new_token_ratio
=
max
(
self
.
new_token_ratio
-
global_config
.
new_token_ratio_decay
,
self
.
new_token_ratio
-
self
.
new_token_ratio_decay
,
self
.
min_new_token_ratio
,
)
...
...
@@ -768,8 +783,10 @@ class Scheduler:
# Check finish conditions
logprob_pt
=
0
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
not
req
.
is_being_chunked
:
# Being chunked reqs' prefill is not finished
if
req
.
is_inflight_req
>
0
:
req
.
is_inflight_req
-=
1
else
:
# Inflight reqs' prefill is not finished
req
.
completion_tokens_wo_jump_forward
+=
1
req
.
output_ids
.
append
(
next_token_ids
[
i
])
req
.
check_finished
()
...
...
@@ -795,8 +812,10 @@ class Scheduler:
# Check finish conditions
for
i
,
req
in
enumerate
(
batch
.
reqs
):
req
.
embedding
=
embeddings
[
i
]
if
not
req
.
is_being_chunked
:
# Being chunked reqs' prefill is not finished
if
req
.
is_inflight_req
>
0
:
req
.
is_inflight_req
-=
1
else
:
# Inflight reqs' prefill is not finished
# dummy output token for embedding models
req
.
output_ids
.
append
(
0
)
req
.
check_finished
()
...
...
python/sglang/test/test_utils.py
View file @
c555ce2c
...
...
@@ -663,7 +663,6 @@ def run_mmlu_test(
chunked_prefill_size
=
32
,
):
other_args
=
[
"--chunked-prefill-size"
,
str
(
chunked_prefill_size
)]
other_args
+=
[
"--mem-fraction-static"
,
"0.85"
]
if
disable_radix_cache
:
other_args
+=
[
"--disable-radix-cache"
]
if
enable_mixed_chunk
:
...
...
test/srt/test_radix_attention.py
deleted
100644 → 0
View file @
40900bae
import
os
import
random
import
unittest
import
requests
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
kill_child_process
,
popen_launch_server
,
)
def gen_radix_tree(num_nodes=400, chunk_len=256):
    """Build a synthetic prefix-sharing request tree for radix-cache tests.

    Starts from one fixed root request, then grows the tree in two phases:
    first adding one child at a time under a random node, then adding random
    branches of up to 10 siblings under a shared parent. Each node is a dict
    with "input_ids" (token list extending its parent's prefix) and
    "decode_len" (requested output length). The node list is shuffled before
    being returned so arrival order is randomized.
    """

    def spawn_child(parent):
        # NOTE: the three randint calls happen in this exact order so the
        # generated tree is reproducible under a fixed random seed.
        unique_len = random.randint(0, chunk_len)
        decode_len = random.randint(0, chunk_len)
        token_id = random.randint(0, 32000)
        return {
            "input_ids": parent["input_ids"] + [token_id] * unique_len,
            "decode_len": decode_len,
        }

    first_phase = num_nodes // 2
    second_phase = num_nodes - first_phase
    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]

    # Phase 1: attach one child at a time to a randomly chosen existing node.
    for _ in range(first_phase):
        nodes.append(spawn_child(random.choice(nodes)))

    # Phase 2: attach branches of 1..10 siblings that share one parent.
    remaining = second_phase
    while remaining > 0:
        num_branch = random.randint(1, min(remaining, 10))
        parent = random.choice(nodes)
        for _ in range(num_branch):
            nodes.append(spawn_child(parent))
        remaining -= num_branch

    random.shuffle(nodes)
    return nodes
def run_test(base_url, nodes):
    """POST all generated requests to the server as one batch /generate call.

    Sends every node's prompt together with per-request sampling params
    (greedy decoding via temperature 0; max_new_tokens taken from the node's
    "decode_len") and asserts the server answered with HTTP 200.
    """
    prompts = [node["input_ids"] for node in nodes]
    sampling = [
        {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
    ]
    payload = {"input_ids": prompts, "sampling_params": sampling}

    res = requests.post(base_url + "/generate", json=payload)
    assert res.status_code == 200
class TestRadixCacheFCFS(unittest.TestCase):
    """End-to-end radix attention test under first-come-first-served scheduling.

    Launches a real server with a small chunked-prefill size (128) and a
    capped token pool (20000), then replays the randomly generated request
    tree against it via /generate.
    """

    @classmethod
    def setUpClass(cls):
        # Launch the server once for the whole test class; individual tests
        # talk to it over HTTP at cls.base_url.
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "fcfs",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        # Kill the server process (and its children) after all tests ran.
        kill_child_process(cls.process.pid)

    def test_radix_attention(self):
        # Generate the prefix-sharing request tree and fire it at the server.
        nodes = gen_radix_tree()
        run_test(self.base_url, nodes)
class TestRadixCacheLPM(TestRadixCacheFCFS):
    """Same end-to-end radix attention test, with longest-prefix-match policy.

    Inherits test_radix_attention and tearDownClass from TestRadixCacheFCFS;
    only the server's --schedule-policy flag differs ("lpm" instead of "fcfs").
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "lpm",
            ],
        )
if __name__ == "__main__":
    # NOTE(review): presumably enables a retraction-testing code path in the
    # server/test utilities — confirm against where SGLANG_TEST_RETRACT is read.
    os.environ["SGLANG_TEST_RETRACT"] = "true"
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment