Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a2f5e755
Unverified
Commit
a2f5e755
authored
Oct 25, 2024
by
Liangsheng Yin
Committed by
GitHub
Oct 25, 2024
Browse files
Fix memory leak when doing chunked prefill (#1787)
parent
2148914e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
184 additions
and
69 deletions
+184
-69
python/sglang/global_config.py
python/sglang/global_config.py
+11
-1
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+3
-4
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+18
-7
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+38
-57
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+1
-0
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_radix_attention.py
test/srt/test_radix_attention.py
+112
-0
No files found.
python/sglang/global_config.py
View file @
a2f5e755
...
@@ -15,7 +15,7 @@ class GlobalConfig:
...
@@ -15,7 +15,7 @@ class GlobalConfig:
# Runtime constants: New generation token ratio estimation
# Runtime constants: New generation token ratio estimation
self
.
init_new_token_ratio
=
0.7
self
.
init_new_token_ratio
=
0.7
self
.
base_
min_new_token_ratio
=
0.1
self
.
min_new_token_ratio
=
0.1
self
.
new_token_ratio_decay
=
0.001
self
.
new_token_ratio_decay
=
0.001
# Runtime constants: others
# Runtime constants: others
...
@@ -32,5 +32,15 @@ class GlobalConfig:
...
@@ -32,5 +32,15 @@ class GlobalConfig:
self
.
enable_precache_with_tracing
=
True
self
.
enable_precache_with_tracing
=
True
self
.
enable_parallel_encoding
=
True
self
.
enable_parallel_encoding
=
True
def adjust_new_token_ratio(self, schedule_conservativeness=1):
    """Scale the new-token-ratio bounds by a conservativeness factor.

    Args:
        schedule_conservativeness: Non-negative multiplier applied to
            ``self.min_new_token_ratio``.

    Returns:
        A ``(min_new_token_ratio, init_new_token_ratio)`` tuple where the
        minimum is ``self.min_new_token_ratio * schedule_conservativeness``
        capped at ``1.0`` and the initial ratio is
        ``self.init_new_token_ratio`` raised to at least that minimum.

    Raises:
        AssertionError: If ``schedule_conservativeness`` is not ``>= 0``.
    """
    # Raised explicitly (rather than via an `assert` statement) so the check
    # still runs under `python -O`; the exception type and message are the
    # same ones the `assert` form produced.
    if not schedule_conservativeness >= 0:
        raise AssertionError("Invalid schedule_conservativeness")

    min_new_token_ratio = min(
        self.min_new_token_ratio * schedule_conservativeness,
        1.0,
    )
    # The starting ratio must never be below the floor it decays towards.
    init_new_token_ratio = max(self.init_new_token_ratio, min_new_token_ratio)

    return min_new_token_ratio, init_new_token_ratio
global_config
=
GlobalConfig
()
global_config
=
GlobalConfig
()
python/sglang/srt/managers/schedule_batch.py
View file @
a2f5e755
...
@@ -222,7 +222,7 @@ class Req:
...
@@ -222,7 +222,7 @@ class Req:
self
.
prefix_indices
=
[]
self
.
prefix_indices
=
[]
self
.
extend_input_len
=
0
self
.
extend_input_len
=
0
self
.
last_node
=
None
self
.
last_node
=
None
self
.
is_
inflight_req
=
0
self
.
is_
being_chunked
=
False
# Logprobs (arguments)
# Logprobs (arguments)
self
.
return_logprob
=
False
self
.
return_logprob
=
False
...
@@ -906,15 +906,14 @@ class ScheduleBatch:
...
@@ -906,15 +906,14 @@ class ScheduleBatch:
def
filter_batch
(
def
filter_batch
(
self
,
self
,
current_inflight
_req
:
Optional
[
Req
]
=
None
,
being_chunked
_req
:
Optional
[
Req
]
=
None
,
keep_indices
:
Optional
[
List
[
int
]]
=
None
,
keep_indices
:
Optional
[
List
[
int
]]
=
None
,
):
):
if
keep_indices
is
None
:
if
keep_indices
is
None
:
keep_indices
=
[
keep_indices
=
[
i
i
for
i
in
range
(
len
(
self
.
reqs
))
for
i
in
range
(
len
(
self
.
reqs
))
if
not
self
.
reqs
[
i
].
finished
()
if
not
self
.
reqs
[
i
].
finished
()
and
self
.
reqs
[
i
]
is
not
being_chunked_req
and
self
.
reqs
[
i
]
is
not
current_inflight_req
]
]
if
keep_indices
is
None
or
len
(
keep_indices
)
==
0
:
if
keep_indices
is
None
or
len
(
keep_indices
)
==
0
:
...
...
python/sglang/srt/managers/schedule_policy.py
View file @
a2f5e755
...
@@ -136,7 +136,7 @@ class PrefillAdder:
...
@@ -136,7 +136,7 @@ class PrefillAdder:
self
.
req_states
=
None
self
.
req_states
=
None
self
.
can_run_list
=
[]
self
.
can_run_list
=
[]
self
.
new_
inflight
_req
=
None
self
.
new_
chunked
_req
=
None
self
.
log_hit_tokens
=
0
self
.
log_hit_tokens
=
0
self
.
log_input_tokens
=
0
self
.
log_input_tokens
=
0
...
@@ -176,7 +176,7 @@ class PrefillAdder:
...
@@ -176,7 +176,7 @@ class PrefillAdder:
self
.
log_hit_tokens
+=
prefix_len
self
.
log_hit_tokens
+=
prefix_len
self
.
log_input_tokens
+=
extend_input_len
self
.
log_input_tokens
+=
extend_input_len
def
add_
inflight
_req
(
self
,
req
:
Req
):
def
add_
being_chunked
_req
(
self
,
req
:
Req
):
truncated
=
req
.
extend_input_len
>
self
.
rem_chunk_tokens
truncated
=
req
.
extend_input_len
>
self
.
rem_chunk_tokens
req
.
extend_input_len
=
min
(
req
.
extend_input_len
,
self
.
rem_chunk_tokens
)
req
.
extend_input_len
=
min
(
req
.
extend_input_len
,
self
.
rem_chunk_tokens
)
req
.
fill_ids
=
req
.
fill_ids
[:
len
(
req
.
prefix_indices
)
+
req
.
extend_input_len
]
req
.
fill_ids
=
req
.
fill_ids
[:
len
(
req
.
prefix_indices
)
+
req
.
extend_input_len
]
...
@@ -192,8 +192,13 @@ class PrefillAdder:
...
@@ -192,8 +192,13 @@ class PrefillAdder:
),
),
)
)
# Return if chunked prefill not finished
if
truncated
:
return
req
if
truncated
else
None
# Continue to chunk the request
assert
req
.
is_being_chunked
self
.
new_chunked_req
=
req
else
:
# Release the being chunked status
req
.
is_being_chunked
=
False
@
contextmanager
@
contextmanager
def
_lock_node
(
self
,
last_node
:
TreeNode
):
def
_lock_node
(
self
,
last_node
:
TreeNode
):
...
@@ -262,11 +267,14 @@ class PrefillAdder:
...
@@ -262,11 +267,14 @@ class PrefillAdder:
)
)
else
:
else
:
# Chunked prefill
# Chunked prefill
assert
self
.
new_chunked_req
is
None
trunc_len
=
self
.
rem_chunk_tokens
trunc_len
=
self
.
rem_chunk_tokens
req
.
extend_input_len
=
trunc_len
req
.
extend_input_len
=
trunc_len
req
.
is_being_chunked
=
True
req
.
fill_ids
=
req
.
fill_ids
[:
trunc_len
]
req
.
fill_ids
=
req
.
fill_ids
[:
trunc_len
]
self
.
can_run_list
.
append
(
req
)
self
.
can_run_list
.
append
(
req
)
self
.
new_
inflight
_req
=
req
self
.
new_
chunked
_req
=
req
self
.
_prefill_one_req
(
0
,
trunc_len
,
0
)
self
.
_prefill_one_req
(
0
,
trunc_len
,
0
)
return
self
.
budget_state
()
return
self
.
budget_state
()
...
@@ -305,15 +313,18 @@ class PrefillAdder:
...
@@ -305,15 +313,18 @@ class PrefillAdder:
min
(
req
.
sampling_params
.
max_new_tokens
,
CLIP_MAX_NEW_TOKENS
),
min
(
req
.
sampling_params
.
max_new_tokens
,
CLIP_MAX_NEW_TOKENS
),
)
)
else
:
else
:
# Chunked prefill
trunc_len
=
self
.
rem_chunk_tokens
trunc_len
=
self
.
rem_chunk_tokens
if
trunc_len
==
0
:
if
trunc_len
==
0
:
return
AddReqResult
.
OTHER
return
AddReqResult
.
OTHER
# Chunked prefill
assert
self
.
new_chunked_req
is
None
req
.
extend_input_len
=
trunc_len
req
.
extend_input_len
=
trunc_len
req
.
fill_ids
=
req
.
fill_ids
[:
len
(
req
.
prefix_indices
)
+
trunc_len
]
req
.
fill_ids
=
req
.
fill_ids
[:
len
(
req
.
prefix_indices
)
+
trunc_len
]
req
.
is_being_chunked
=
True
self
.
can_run_list
.
append
(
req
)
self
.
can_run_list
.
append
(
req
)
self
.
new_
inflight
_req
=
req
self
.
new_
chunked
_req
=
req
self
.
tree_cache
.
inc_lock_ref
(
req
.
last_node
)
self
.
tree_cache
.
inc_lock_ref
(
req
.
last_node
)
self
.
_prefill_one_req
(
prefix_len
,
trunc_len
,
0
)
self
.
_prefill_one_req
(
prefix_len
,
trunc_len
,
0
)
...
...
python/sglang/srt/managers/scheduler.py
View file @
a2f5e755
...
@@ -219,35 +219,28 @@ class Scheduler:
...
@@ -219,35 +219,28 @@ class Scheduler:
# Init chunked prefill
# Init chunked prefill
self
.
chunked_prefill_size
=
server_args
.
chunked_prefill_size
self
.
chunked_prefill_size
=
server_args
.
chunked_prefill_size
self
.
current_inflight
_req
=
None
self
.
being_chunked
_req
=
None
self
.
is_mixed_chunk
=
(
self
.
is_mixed_chunk
=
(
self
.
chunked_prefill_size
is
not
None
and
server_args
.
enable_mixed_chunk
self
.
chunked_prefill_size
is
not
None
and
server_args
.
enable_mixed_chunk
)
)
# Init the FSM cache for constrained generation
# Init the FSM cache for constrained generation
if
not
server_args
.
skip_tokenizer_init
:
self
.
regex_fsm_cache
=
FSMCache
(
self
.
regex_fsm_cache
=
FSMCache
(
server_args
.
tokenizer_path
,
server_args
.
tokenizer_path
,
{
{
"tokenizer_mode"
:
server_args
.
tokenizer_mode
,
"tokenizer_mode"
:
server_args
.
tokenizer_mode
,
"trust_remote_code"
:
server_args
.
trust_remote_code
,
"trust_remote_code"
:
server_args
.
trust_remote_code
,
},
},
skip_tokenizer_init
=
server_args
.
skip_tokenizer_init
,
skip_tokenizer_init
=
server_args
.
skip_tokenizer_init
,
constrained_json_whitespace_pattern
=
server_args
.
constrained_json_whitespace_pattern
,
constrained_json_whitespace_pattern
=
server_args
.
constrained_json_whitespace_pattern
,
)
)
self
.
jump_forward_cache
=
JumpForwardCache
()
self
.
jump_forward_cache
=
JumpForwardCache
()
# Init new token estimation
# Init new token estimation
assert
(
self
.
min_new_token_ratio
,
self
.
init_new_token_ratio
=
(
server_args
.
schedule_conservativeness
>=
0
global_config
.
adjust_new_token_ratio
(
server_args
.
schedule_conservativeness
)
),
"Invalid schedule_conservativeness"
self
.
min_new_token_ratio
=
min
(
global_config
.
base_min_new_token_ratio
*
server_args
.
schedule_conservativeness
,
1.0
,
)
)
self
.
new_token_ratio
=
self
.
min_new_token_ratio
self
.
new_token_ratio
=
self
.
init_new_token_ratio
self
.
new_token_ratio_decay
=
global_config
.
new_token_ratio_decay
self
.
batch_is_full
=
False
self
.
batch_is_full
=
False
# Init profiler
# Init profiler
...
@@ -294,7 +287,7 @@ class Scheduler:
...
@@ -294,7 +287,7 @@ class Scheduler:
self
.
process_batch_result
(
batch
,
result
)
self
.
process_batch_result
(
batch
,
result
)
else
:
else
:
self
.
check_memory
()
self
.
check_memory
()
self
.
new_token_ratio
=
global_config
.
init_new_token_ratio
self
.
new_token_ratio
=
self
.
init_new_token_ratio
self
.
last_batch
=
batch
self
.
last_batch
=
batch
...
@@ -321,7 +314,7 @@ class Scheduler:
...
@@ -321,7 +314,7 @@ class Scheduler:
self
.
process_batch_result
(
tmp_batch
,
tmp_result
)
self
.
process_batch_result
(
tmp_batch
,
tmp_result
)
elif
batch
is
None
:
elif
batch
is
None
:
self
.
check_memory
()
self
.
check_memory
()
self
.
new_token_ratio
=
global_config
.
init_new_token_ratio
self
.
new_token_ratio
=
self
.
init_new_token_ratio
self
.
last_batch
=
batch
self
.
last_batch
=
batch
...
@@ -499,20 +492,18 @@ class Scheduler:
...
@@ -499,20 +492,18 @@ class Scheduler:
)
)
exit
(
1
)
if
crash_on_warning
else
None
exit
(
1
)
if
crash_on_warning
else
None
def
get_next_batch_to_run
(
self
):
def
get_next_batch_to_run
(
self
)
->
Optional
[
ScheduleBatch
]
:
# Merge the prefill batch into the running batch
# Merge the prefill batch into the running batch
if
(
if
(
self
.
last_batch
self
.
last_batch
and
not
self
.
last_batch
.
forward_mode
.
is_decode
()
and
not
self
.
last_batch
.
forward_mode
.
is_decode
()
and
not
self
.
last_batch
.
is_empty
()
and
not
self
.
last_batch
.
is_empty
()
):
):
if
self
.
current_inflight_req
:
if
self
.
being_chunked_req
:
self
.
last_batch
.
filter_batch
(
self
.
last_batch
.
filter_batch
(
being_chunked_req
=
self
.
being_chunked_req
)
current_inflight_req
=
self
.
current_inflight_req
self
.
tree_cache
.
cache_unfinished_req
(
self
.
being_chunked_req
)
)
# Being chunked request keeps its rid but will get a new req_pool_idx.
self
.
tree_cache
.
cache_unfinished_req
(
self
.
current_inflight_req
)
self
.
req_to_token_pool
.
free
(
self
.
being_chunked_req
.
req_pool_idx
)
# Inflight request keeps its rid but will get a new req_pool_idx.
self
.
req_to_token_pool
.
free
(
self
.
current_inflight_req
.
req_pool_idx
)
self
.
batch_is_full
=
False
self
.
batch_is_full
=
False
if
not
self
.
last_batch
.
is_empty
():
if
not
self
.
last_batch
.
is_empty
():
if
self
.
running_batch
is
None
:
if
self
.
running_batch
is
None
:
...
@@ -543,7 +534,7 @@ class Scheduler:
...
@@ -543,7 +534,7 @@ class Scheduler:
# Handle the cases where prefill is not allowed
# Handle the cases where prefill is not allowed
if
(
if
(
self
.
batch_is_full
or
len
(
self
.
waiting_queue
)
==
0
self
.
batch_is_full
or
len
(
self
.
waiting_queue
)
==
0
)
and
self
.
current_inflight
_req
is
None
:
)
and
self
.
being_chunked
_req
is
None
:
return
None
return
None
running_bs
=
len
(
self
.
running_batch
.
reqs
)
if
self
.
running_batch
else
0
running_bs
=
len
(
self
.
running_batch
.
reqs
)
if
self
.
running_batch
else
0
...
@@ -566,15 +557,6 @@ class Scheduler:
...
@@ -566,15 +557,6 @@ class Scheduler:
num_mixed_running
,
num_mixed_running
,
)
)
has_inflight
=
self
.
current_inflight_req
is
not
None
if
has_inflight
:
self
.
current_inflight_req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
)
if
self
.
lora_paths
:
if
self
.
lora_paths
:
lora_set
=
(
lora_set
=
(
set
([
req
.
lora_path
for
req
in
self
.
running_batch
.
reqs
])
set
([
req
.
lora_path
for
req
in
self
.
running_batch
.
reqs
])
...
@@ -582,6 +564,13 @@ class Scheduler:
...
@@ -582,6 +564,13 @@ class Scheduler:
else
set
([])
else
set
([])
)
)
# NOTE: if there is request being chunked, we always add it first
has_being_chunked
=
self
.
being_chunked_req
is
not
None
if
has_being_chunked
:
# NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
self
.
being_chunked_req
.
init_next_round_input
()
adder
.
add_being_chunked_req
(
self
.
being_chunked_req
)
# Get requests from the waiting queue to a new prefill batch
# Get requests from the waiting queue to a new prefill batch
for
req
in
self
.
waiting_queue
:
for
req
in
self
.
waiting_queue
:
if
(
if
(
...
@@ -615,12 +604,8 @@ class Scheduler:
...
@@ -615,12 +604,8 @@ class Scheduler:
x
for
x
in
self
.
waiting_queue
if
x
not
in
set
(
can_run_list
)
x
for
x
in
self
.
waiting_queue
if
x
not
in
set
(
can_run_list
)
]
]
if
adder
.
new_inflight_req
is
not
None
:
# Update new round being chunked request
assert
self
.
current_inflight_req
is
None
self
.
being_chunked_req
=
adder
.
new_chunked_req
self
.
current_inflight_req
=
adder
.
new_inflight_req
if
self
.
current_inflight_req
:
self
.
current_inflight_req
.
is_inflight_req
+=
1
# Print stats
# Print stats
if
self
.
tp_rank
==
0
:
if
self
.
tp_rank
==
0
:
...
@@ -649,7 +634,7 @@ class Scheduler:
...
@@ -649,7 +634,7 @@ class Scheduler:
f
"#cached-token:
{
adder
.
log_hit_tokens
}
, "
f
"#cached-token:
{
adder
.
log_hit_tokens
}
, "
f
"cache hit rate:
{
100.0
*
tree_cache_hit_rate
:.
2
f
}
%, "
f
"cache hit rate:
{
100.0
*
tree_cache_hit_rate
:.
2
f
}
%, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
+
has_
inflight
}
"
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
+
has_
being_chunked
}
"
)
)
else
:
else
:
logger
.
info
(
logger
.
info
(
...
@@ -660,7 +645,7 @@ class Scheduler:
...
@@ -660,7 +645,7 @@ class Scheduler:
f
"cache hit rate:
{
100.0
*
tree_cache_hit_rate
:.
2
f
}
%, "
f
"cache hit rate:
{
100.0
*
tree_cache_hit_rate
:.
2
f
}
%, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"token usage:
{
num_used
/
self
.
max_total_num_tokens
:.
2
f
}
, "
f
"#running-req:
{
running_bs
}
, "
f
"#running-req:
{
running_bs
}
, "
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
+
has_
inflight
}
"
f
"#queue-req:
{
len
(
self
.
waiting_queue
)
+
has_
being_chunked
}
"
)
)
# Create a new batch
# Create a new batch
...
@@ -709,7 +694,7 @@ class Scheduler:
...
@@ -709,7 +694,7 @@ class Scheduler:
self
.
waiting_queue
.
extend
(
retracted_reqs
)
self
.
waiting_queue
.
extend
(
retracted_reqs
)
else
:
else
:
self
.
new_token_ratio
=
max
(
self
.
new_token_ratio
=
max
(
self
.
new_token_ratio
-
self
.
new_token_ratio_decay
,
self
.
new_token_ratio
-
global_config
.
new_token_ratio_decay
,
self
.
min_new_token_ratio
,
self
.
min_new_token_ratio
,
)
)
...
@@ -783,10 +768,8 @@ class Scheduler:
...
@@ -783,10 +768,8 @@ class Scheduler:
# Check finish conditions
# Check finish conditions
logprob_pt
=
0
logprob_pt
=
0
for
i
,
req
in
enumerate
(
batch
.
reqs
):
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
req
.
is_inflight_req
>
0
:
if
not
req
.
is_being_chunked
:
req
.
is_inflight_req
-=
1
# Being chunked reqs' prefill is not finished
else
:
# Inflight reqs' prefill is not finished
req
.
completion_tokens_wo_jump_forward
+=
1
req
.
completion_tokens_wo_jump_forward
+=
1
req
.
output_ids
.
append
(
next_token_ids
[
i
])
req
.
output_ids
.
append
(
next_token_ids
[
i
])
req
.
check_finished
()
req
.
check_finished
()
...
@@ -812,10 +795,8 @@ class Scheduler:
...
@@ -812,10 +795,8 @@ class Scheduler:
# Check finish conditions
# Check finish conditions
for
i
,
req
in
enumerate
(
batch
.
reqs
):
for
i
,
req
in
enumerate
(
batch
.
reqs
):
req
.
embedding
=
embeddings
[
i
]
req
.
embedding
=
embeddings
[
i
]
if
req
.
is_inflight_req
>
0
:
if
not
req
.
is_being_chunked
:
req
.
is_inflight_req
-=
1
# Being chunked reqs' prefill is not finished
else
:
# Inflight reqs' prefill is not finished
# dummy output token for embedding models
# dummy output token for embedding models
req
.
output_ids
.
append
(
0
)
req
.
output_ids
.
append
(
0
)
req
.
check_finished
()
req
.
check_finished
()
...
...
python/sglang/test/test_utils.py
View file @
a2f5e755
...
@@ -660,6 +660,7 @@ def run_mmlu_test(
...
@@ -660,6 +660,7 @@ def run_mmlu_test(
chunked_prefill_size
=
32
,
chunked_prefill_size
=
32
,
):
):
other_args
=
[
"--chunked-prefill-size"
,
str
(
chunked_prefill_size
)]
other_args
=
[
"--chunked-prefill-size"
,
str
(
chunked_prefill_size
)]
other_args
+=
[
"--mem-fraction-static"
,
"0.85"
]
if
disable_radix_cache
:
if
disable_radix_cache
:
other_args
+=
[
"--disable-radix-cache"
]
other_args
+=
[
"--disable-radix-cache"
]
if
enable_mixed_chunk
:
if
enable_mixed_chunk
:
...
...
test/srt/run_suite.py
View file @
a2f5e755
...
@@ -5,6 +5,7 @@ from sglang.test.test_utils import run_unittest_files
...
@@ -5,6 +5,7 @@ from sglang.test.test_utils import run_unittest_files
suites
=
{
suites
=
{
"minimal"
:
[
"minimal"
:
[
"test_radix_attention.py"
,
"models/test_embedding_models.py"
,
"models/test_embedding_models.py"
,
"models/test_generation_models.py"
,
"models/test_generation_models.py"
,
"models/test_lora.py"
,
"models/test_lora.py"
,
...
...
test/srt/test_radix_attention.py
0 → 100644
View file @
a2f5e755
import
os
import
random
import
unittest
import
requests
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
kill_child_process
,
popen_launch_server
,
)
def gen_radix_tree(num_nodes=400, chunk_len=256, rng=None):
    """Generate a shuffled list of requests whose prompts share prefixes.

    The list starts with one root request (117 copies of token ``37``,
    ``decode_len`` 217). Each further request copies the ``input_ids`` of an
    already generated request and appends ``unique_len`` copies of one random
    token id, so the prompts form a prefix tree.

    Args:
        num_nodes: Number of requests to generate in addition to the root.
        chunk_len: Inclusive upper bound for both the number of appended
            tokens and the ``decode_len`` of each generated request.
        rng: Optional object providing ``choice``, ``randint`` and
            ``shuffle`` (for example a ``random.Random`` instance). Defaults
            to the module-level ``random`` functions.

    Returns:
        A list of ``num_nodes + 1`` dicts with keys ``"input_ids"`` and
        ``"decode_len"``, in shuffled order.
    """
    # Default to the global generator so existing callers (and random.seed)
    # see exactly the same sequence of draws as before.
    rng = random if rng is None else rng

    def make_child(parent):
        # Draw order (unique_len, decode_len, token_id) is fixed so that a
        # seeded generator always yields the same tree.
        unique_len = rng.randint(0, chunk_len)
        decode_len = rng.randint(0, chunk_len)
        token_id = rng.randint(0, 32000)
        return {
            "input_ids": parent["input_ids"] + [token_id] * unique_len,
            "decode_len": decode_len,
        }

    single_count = num_nodes // 2
    branch_budget = num_nodes - single_count
    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]

    # First half: each new request extends one randomly chosen existing one.
    for _ in range(single_count):
        nodes.append(make_child(rng.choice(nodes)))

    # Second half: groups of 1..10 siblings hang off the same parent.
    while branch_budget > 0:
        num_branch = rng.randint(1, min(branch_budget, 10))
        parent = rng.choice(nodes)
        for _ in range(num_branch):
            nodes.append(make_child(parent))
        branch_budget -= num_branch

    rng.shuffle(nodes)
    return nodes
def run_test(base_url, nodes):
    """Send all generated requests to the server in one ``/generate`` call.

    Args:
        base_url: Base URL of the server; ``"/generate"`` is appended to it.
        nodes: List of dicts with ``"input_ids"`` and ``"decode_len"`` keys,
            as produced by ``gen_radix_tree``.

    Raises:
        AssertionError: If the response status code is not 200.
    """
    data = {
        "input_ids": [node["input_ids"] for node in nodes],
        "sampling_params": [
            {"max_new_tokens": node["decode_len"], "temperature": 0}
            for node in nodes
        ],
    }

    res = requests.post(base_url + "/generate", json=data)
    # Include the status and a bounded excerpt of the body so a failing run
    # shows what the server actually answered.
    assert res.status_code == 200, (
        f"/generate returned HTTP {res.status_code}: {res.text[:500]}"
    )
class TestRadixCacheFCFS(unittest.TestCase):
    """Run the shared-prefix workload against a server using the FCFS policy.

    The server is launched once per class with ``--chunked-prefill-size 128``
    and ``--max-total-tokens 20000``.
    """

    # Value passed to --schedule-policy; subclasses override only this.
    schedule_policy = "fcfs"

    @classmethod
    def setUpClass(cls):
        """Launch the server process shared by all tests in the class."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                cls.schedule_policy,
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the server process started in ``setUpClass``."""
        kill_child_process(cls.process.pid)

    def test_radix_attention(self):
        """Generate a prefix tree of requests and expect HTTP 200."""
        nodes = gen_radix_tree()
        run_test(self.base_url, nodes)


class TestRadixCacheLPM(TestRadixCacheFCFS):
    """Same workload as ``TestRadixCacheFCFS`` with ``--schedule-policy lpm``."""

    schedule_policy = "lpm"
if __name__ == "__main__":
    # Set before unittest.main() so the flag is already in the environment
    # when the test classes launch their server processes.
    os.environ["SGLANG_TEST_RETRACT"] = "true"

    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment