sglang / Commits / a2f5e755

Commit a2f5e755 (unverified), authored Oct 25, 2024 by Liangsheng Yin, committed by GitHub on Oct 25, 2024.
Fix memory leak when doing chunked prefill (#1787)
Parent: 2148914e

Showing 7 changed files with 184 additions and 69 deletions (+184, -69).
Changed files:

  python/sglang/global_config.py                 (+11, -1)
  python/sglang/srt/managers/schedule_batch.py   (+3, -4)
  python/sglang/srt/managers/schedule_policy.py  (+18, -7)
  python/sglang/srt/managers/scheduler.py        (+38, -57)
  python/sglang/test/test_utils.py               (+1, -0)
  test/srt/run_suite.py                          (+1, -0)
  test/srt/test_radix_attention.py               (+112, -0)
python/sglang/global_config.py

@@ -15,7 +15,7 @@ class GlobalConfig:
         # Runtime constants: New generation token ratio estimation
         self.init_new_token_ratio = 0.7
-        self.base_min_new_token_ratio = 0.1
+        self.min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001

         # Runtime constants: others
@@ -32,5 +32,15 @@ class GlobalConfig:
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True

+    def adjust_new_token_ratio(self, schedule_conservativeness=1):
+        assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
+        min_new_token_ratio = min(
+            self.min_new_token_ratio * schedule_conservativeness,
+            1.0,
+        )
+        init_new_token_ratio = max(self.init_new_token_ratio, min_new_token_ratio)
+        return min_new_token_ratio, init_new_token_ratio
+

 global_config = GlobalConfig()
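Taken on its own, the new adjust_new_token_ratio helper is just the clamp-and-floor arithmetic that scheduler.py used to inline (see the scheduler diff below). A minimal standalone sketch of the same computation, a plain function rather than the GlobalConfig class, with defaults mirroring the constants in this file:

def adjust_new_token_ratio(
    init_new_token_ratio=0.7,
    min_new_token_ratio=0.1,
    schedule_conservativeness=1,
):
    # Standalone sketch, not the sglang class itself.
    assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
    # A more conservative schedule raises the floor, capped at 1.0.
    floor = min(min_new_token_ratio * schedule_conservativeness, 1.0)
    # The initial estimate must never sit below the floor.
    init = max(init_new_token_ratio, floor)
    return floor, init

print(adjust_new_token_ratio())                              # (0.1, 0.7)
print(adjust_new_token_ratio(schedule_conservativeness=10))  # (1.0, 1.0)

The scheduler stores both values and, as the scheduler diff below shows, decays new_token_ratio from the init value down toward the floor as batches are processed.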
python/sglang/srt/managers/schedule_batch.py

@@ -222,7 +222,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.is_inflight_req = 0
+        self.is_being_chunked = False

         # Logprobs (arguments)
         self.return_logprob = False
@@ -906,15 +906,14 @@ class ScheduleBatch:
     def filter_batch(
         self,
-        current_inflight_req: Optional[Req] = None,
+        being_chunked_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not current_inflight_req
+                and self.reqs[i] is not being_chunked_req
             ]

         if keep_indices is None or len(keep_indices) == 0:
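The filter_batch change only swaps which request identity gets excluded; the selection rule itself is an ordinary list comprehension over the batch. A toy illustration, with plain dicts standing in for Req objects (the names here are illustrative, not sglang APIs):

# Toy version of the keep_indices selection in filter_batch above.
reqs = [
    {"rid": "a", "finished": False},
    {"rid": "b", "finished": True},   # finished requests drop out
    {"rid": "c", "finished": False},  # pretend this one is being chunked
]
being_chunked_req = reqs[2]

keep_indices = [
    i
    for i in range(len(reqs))
    if not reqs[i]["finished"] and reqs[i] is not being_chunked_req
]
print(keep_indices)  # [0]

Excluding the being-chunked request by identity (is, not ==) matters because the scheduler re-adds the very same request object to the next prefill round.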
python/sglang/srt/managers/schedule_policy.py

@@ -136,7 +136,7 @@ class PrefillAdder:
         self.req_states = None
         self.can_run_list = []
-        self.new_inflight_req = None
+        self.new_chunked_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
@@ -176,7 +176,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len

-    def add_inflight_req(self, req: Req):
+    def add_being_chunked_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -192,8 +192,13 @@ class PrefillAdder:
             ),
         )

-        # Return if chunked prefill not finished
-        return req if truncated else None
+        if truncated:
+            # Continue to chunk the request
+            assert req.is_being_chunked
+            self.new_chunked_req = req
+        else:
+            # Release the being chunked status
+            req.is_being_chunked = False

     @contextmanager
     def _lock_node(self, last_node: TreeNode):
@@ -262,11 +267,14 @@ class PrefillAdder:
                 )
             else:
                 # Chunked prefill
+                assert self.new_chunked_req is None
                 trunc_len = self.rem_chunk_tokens
                 req.extend_input_len = trunc_len
+                req.is_being_chunked = True
                 req.fill_ids = req.fill_ids[:trunc_len]
                 self.can_run_list.append(req)
-                self.new_inflight_req = req
+                self.new_chunked_req = req
                 self._prefill_one_req(0, trunc_len, 0)

         return self.budget_state()
@@ -305,15 +313,18 @@ class PrefillAdder:
                     min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
                 )
             else:
-                # Chunked prefill
                 trunc_len = self.rem_chunk_tokens
                 if trunc_len == 0:
                     return AddReqResult.OTHER

+                # Chunked prefill
+                assert self.new_chunked_req is None
                 req.extend_input_len = trunc_len
                 req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
+                req.is_being_chunked = True
+
                 self.can_run_list.append(req)
-                self.new_inflight_req = req
+                self.new_chunked_req = req
                 self.tree_cache.inc_lock_ref(req.last_node)
                 self._prefill_one_req(prefix_len, trunc_len, 0)
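Taken together, add_being_chunked_req and the is_being_chunked flag implement a small state machine: while the remaining input exceeds the per-round chunk budget, the request is truncated and stays marked as being chunked; the round that finally fits releases the flag. A self-contained sketch of that loop (a toy function, not the PrefillAdder class, assuming a fixed chunk budget per round):

def chunk_rounds(input_len, chunk_size):
    # Sketch of the chunked-prefill rounds a single request goes through.
    rounds, done = [], 0
    while done < input_len:
        extend_len = input_len - done
        truncated = extend_len > chunk_size
        if truncated:
            extend_len = chunk_size  # keep chunking next round
        done += extend_len
        # (tokens prefilled this round, still being chunked afterwards)
        rounds.append((extend_len, truncated))
    return rounds

print(chunk_rounds(300, 128))  # [(128, True), (128, True), (44, False)]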
python/sglang/srt/managers/scheduler.py

@@ -219,13 +219,12 @@ class Scheduler:
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
+        self.being_chunked_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )

         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
                 server_args.tokenizer_path,
                 {
@@ -238,16 +237,10 @@ class Scheduler:
         self.jump_forward_cache = JumpForwardCache()

         # Init new token estimation
-        assert (
-            server_args.schedule_conservativeness >= 0
-        ), "Invalid schedule_conservativeness"
-        self.min_new_token_ratio = min(
-            global_config.base_min_new_token_ratio
-            * server_args.schedule_conservativeness,
-            1.0,
-        )
-        self.new_token_ratio = self.min_new_token_ratio
-        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.min_new_token_ratio, self.init_new_token_ratio = (
+            global_config.adjust_new_token_ratio(server_args.schedule_conservativeness)
+        )
+        self.new_token_ratio = self.init_new_token_ratio
         self.batch_is_full = False

         # Init profiler
@@ -294,7 +287,7 @@ class Scheduler:
             self.process_batch_result(batch, result)
         else:
             self.check_memory()
-            self.new_token_ratio = global_config.init_new_token_ratio
+            self.new_token_ratio = self.init_new_token_ratio

         self.last_batch = batch
@@ -321,7 +314,7 @@ class Scheduler:
             self.process_batch_result(tmp_batch, tmp_result)
         elif batch is None:
             self.check_memory()
-            self.new_token_ratio = global_config.init_new_token_ratio
+            self.new_token_ratio = self.init_new_token_ratio

         self.last_batch = batch
@@ -499,20 +492,18 @@ class Scheduler:
             )
             exit(1) if crash_on_warning else None

-    def get_next_batch_to_run(self):
+    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         if (
             self.last_batch
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.current_inflight_req:
-                self.last_batch.filter_batch(
-                    current_inflight_req=self.current_inflight_req
-                )
-                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx.
-                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+            if self.being_chunked_req:
+                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
+                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
+                # Being chunked request keeps its rid but will get a new req_pool_idx.
+                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
@@ -543,7 +534,7 @@ class Scheduler:
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.current_inflight_req is None:
+        ) and self.being_chunked_req is None:
             return None

         running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -566,15 +557,6 @@ class Scheduler:
             num_mixed_running,
         )

-        has_inflight = self.current_inflight_req is not None
-        if has_inflight:
-            self.current_inflight_req.init_next_round_input(
-                None if prefix_computed else self.tree_cache
-            )
-            self.current_inflight_req = adder.add_inflight_req(
-                self.current_inflight_req
-            )
-
         if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
@@ -582,6 +564,13 @@ class Scheduler:
                 else set([])
             )

+        # NOTE: if there is request being chunked, we always add it first
+        has_being_chunked = self.being_chunked_req is not None
+        if has_being_chunked:
+            # NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
+            self.being_chunked_req.init_next_round_input()
+            adder.add_being_chunked_req(self.being_chunked_req)
+
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -615,12 +604,8 @@ class Scheduler:
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]

-        if adder.new_inflight_req is not None:
-            assert self.current_inflight_req is None
-            self.current_inflight_req = adder.new_inflight_req
-
-        if self.current_inflight_req:
-            self.current_inflight_req.is_inflight_req += 1
+        # Update new round being chunked request
+        self.being_chunked_req = adder.new_chunked_req

         # Print stats
         if self.tp_rank == 0:
@@ -649,7 +634,7 @@ class Scheduler:
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+                f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
             )
         else:
             logger.info(
@@ -660,7 +645,7 @@ class Scheduler:
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+                f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
             )

         # Create a new batch
@@ -709,7 +694,7 @@ class Scheduler:
             self.waiting_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_decay,
+                self.new_token_ratio - global_config.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )
@@ -783,10 +768,8 @@ class Scheduler:
         # Check finish conditions
         logprob_pt = 0
         for i, req in enumerate(batch.reqs):
-            if req.is_inflight_req > 0:
-                req.is_inflight_req -= 1
-            else:
-                # Inflight reqs' prefill is not finished
+            if not req.is_being_chunked:
+                # Being chunked reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()
@@ -812,10 +795,8 @@ class Scheduler:
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
             req.embedding = embeddings[i]
-            if req.is_inflight_req > 0:
-                req.is_inflight_req -= 1
-            else:
-                # Inflight reqs' prefill is not finished
+            if not req.is_being_chunked:
+                # Being chunked reqs' prefill is not finished
                 # dummy output token for embedding models
                 req.output_ids.append(0)
                 req.check_finished()
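The heart of the leak fix is the req_to_token_pool.free(...) call in get_next_batch_to_run(): each prefill round of a chunked request gets a fresh req_pool_idx, so the slot used by the previous round must be released when the last prefill batch is merged away, otherwise one slot leaks per chunk. A toy pool illustrating the pattern (an illustrative class, not the sglang ReqToTokenPool API):

class ToySlotPool:
    # Toy slot pool: free the previous round's slot before taking a new one.
    def __init__(self, size):
        self.free_slots = list(range(size))

    def alloc(self):
        return self.free_slots.pop()

    def free(self, idx):
        self.free_slots.append(idx)

pool = ToySlotPool(4)
slot = pool.alloc()          # first prefill round
for _ in range(3):           # three more chunked-prefill rounds
    pool.free(slot)          # the fix: release the old round's slot
    slot = pool.alloc()      # the request keeps its rid but gets a new slot
pool.free(slot)              # request finishes
assert len(pool.free_slots) == 4  # nothing leaked; skip free() and slots run out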
python/sglang/test/test_utils.py

@@ -660,6 +660,7 @@ def run_mmlu_test(
     chunked_prefill_size=32,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+    other_args += ["--mem-fraction-static", "0.85"]
     if disable_radix_cache:
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
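For reference, the argument list this function now builds for the launched test server, reproduced standalone (--mem-fraction-static is the existing sglang server flag that bounds the fraction of GPU memory reserved for static allocations such as model weights and the KV cache pool):

chunked_prefill_size = 32
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
other_args += ["--mem-fraction-static", "0.85"]
print(other_args)
# ['--chunked-prefill-size', '32', '--mem-fraction-static', '0.85']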
test/srt/run_suite.py

@@ -5,6 +5,7 @@ from sglang.test.test_utils import run_unittest_files
 suites = {
     "minimal": [
+        "test_radix_attention.py",
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
         "models/test_lora.py",
test/srt/test_radix_attention.py (new file, mode 100644)

import os
import random
import unittest

import requests

from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    kill_child_process,
    popen_launch_server,
)


def gen_radix_tree(num_nodes=400, chunk_len=256):
    num0 = num_nodes // 2
    num1 = num_nodes - num0
    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
    for _ in range(num0):
        parent = random.choice(nodes)
        unique_len = random.randint(0, chunk_len)
        decode_len = random.randint(0, chunk_len)
        token_id = random.randint(0, 32000)
        child = {
            "input_ids": parent["input_ids"] + [token_id] * unique_len,
            "decode_len": decode_len,
        }
        nodes.append(child)

    while num1 > 0:
        num_branch = random.randint(1, min(num1, 10))
        parent = random.choice(nodes)
        for _ in range(num_branch):
            unique_len = random.randint(0, chunk_len)
            decode_len = random.randint(0, chunk_len)
            token_id = random.randint(0, 32000)
            child = {
                "input_ids": parent["input_ids"] + [token_id] * unique_len,
                "decode_len": decode_len,
            }
            nodes.append(child)

        num1 -= num_branch

    random.shuffle(nodes)
    return nodes


def run_test(base_url, nodes):
    data = {
        "input_ids": [node["input_ids"] for node in nodes],
        "sampling_params": [
            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
        ],
    }

    res = requests.post(base_url + "/generate", json=data)
    assert res.status_code == 200


class TestRadixCacheFCFS(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "fcfs",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def test_radix_attention(self):
        nodes = gen_radix_tree()
        run_test(self.base_url, nodes)


class TestRadixCacheLPM(TestRadixCacheFCFS):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "lpm",
            ],
        )


if __name__ == "__main__":
    os.environ["SGLANG_TEST_RETRACT"] = "true"
    unittest.main()
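A quick way to inspect the tree shape the generator produces, reusing gen_radix_tree from the file above: every child's input_ids extend some parent's prompt, which is exactly the prefix-sharing pattern the radix cache and chunked prefill have to handle together.

import random

random.seed(0)  # make the sampled tree reproducible
nodes = gen_radix_tree(num_nodes=8, chunk_len=16)
for node in nodes[:3]:
    # prompt length and requested decode length for each request
    print(len(node["input_ids"]), node["decode_len"])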