change / sglang · Commits

Commit c555ce2c (unverified)
Revert "Fix memory leak when doing chunked prefill" (#1797)

Authored Oct 25, 2024 by Lianmin Zheng; committed by GitHub on Oct 25, 2024.
Parent: 40900bae

Showing 6 changed files, with 69 additions and 183 deletions (+69 −183):
  python/sglang/global_config.py                  +1   −11
  python/sglang/srt/managers/schedule_batch.py    +4   −3
  python/sglang/srt/managers/schedule_policy.py   +7   −18
  python/sglang/srt/managers/scheduler.py         +57  −38
  python/sglang/test/test_utils.py                +0   −1
  test/srt/test_radix_attention.py                +0   −112
python/sglang/global_config.py

@@ -15,7 +15,7 @@ class GlobalConfig:
         # Runtime constants: New generation token ratio estimation
         self.init_new_token_ratio = 0.7
-        self.min_new_token_ratio = 0.1
+        self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001

         # Runtime constants: others
@@ -32,15 +32,5 @@ class GlobalConfig:
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True

-    def adjust_new_token_ratio(self, schedule_conservativeness=1):
-        assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
-        min_new_token_ratio = min(
-            self.min_new_token_ratio * schedule_conservativeness,
-            1.0,
-        )
-        init_new_token_ratio = max(self.init_new_token_ratio, min_new_token_ratio)
-        return min_new_token_ratio, init_new_token_ratio
-

 global_config = GlobalConfig()
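For orientation, a minimal standalone sketch (not part of the commit) of how the restored base_min_new_token_ratio constant is consumed: the revert removes the adjust_new_token_ratio helper and inlines the scaling into Scheduler.__init__ (see scheduler.py below). The schedule_conservativeness value here is hypothetical.

# Sketch only: mirrors the computation inlined into Scheduler.__init__ below.
base_min_new_token_ratio = 0.1    # from GlobalConfig above
schedule_conservativeness = 2.0   # hypothetical --schedule-conservativeness value

# Scale the retraction floor by the conservativeness knob, clamped to 1.0.
min_new_token_ratio = min(base_min_new_token_ratio * schedule_conservativeness, 1.0)
print(min_new_token_ratio)  # 0.2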
python/sglang/srt/managers/schedule_batch.py

@@ -222,7 +222,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.is_being_chunked = False
+        self.is_inflight_req = 0

         # Logprobs (arguments)
         self.return_logprob = False
@@ -906,14 +906,15 @@ class ScheduleBatch:
     def filter_batch(
         self,
-        being_chunked_req: Optional[Req] = None,
+        current_inflight_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
-                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not current_inflight_req
             ]

         if keep_indices is None or len(keep_indices) == 0:
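To make the filter predicate above concrete, here is a toy, self-contained illustration (names invented, not sglang code): filter_batch keeps every unfinished request except the one still being chunk-prefilled, compared by identity.

from dataclasses import dataclass

@dataclass
class ToyReq:  # stand-in for sglang's Req, for illustration only
    rid: str
    done: bool

    def finished(self) -> bool:
        return self.done

reqs = [ToyReq("a", False), ToyReq("b", True), ToyReq("c", False)]
current_inflight_req = reqs[2]  # pretend "c" still has prefill chunks left

keep_indices = [
    i
    for i in range(len(reqs))
    if not reqs[i].finished() and reqs[i] is not current_inflight_req
]
print(keep_indices)  # [0] -- "b" is finished, "c" is the inflight request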
python/sglang/srt/managers/schedule_policy.py

@@ -136,7 +136,7 @@ class PrefillAdder:
         self.req_states = None
         self.can_run_list = []
-        self.new_chunked_req = None
+        self.new_inflight_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
@@ -176,7 +176,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len

-    def add_being_chunked_req(self, req: Req):
+    def add_inflight_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -192,13 +192,8 @@ class PrefillAdder:
             ),
         )

-        if truncated:
-            # Continue to chunk the request
-            assert req.is_being_chunked
-            self.new_chunked_req = req
-        else:
-            # Release the being chunked status
-            req.is_being_chunked = False
+        # Return if chunked prefill not finished
+        return req if truncated else None

     @contextmanager
     def _lock_node(self, last_node: TreeNode):
@@ -267,14 +262,11 @@ class PrefillAdder:
             )
         else:
             # Chunked prefill
-            assert self.new_chunked_req is None
-
             trunc_len = self.rem_chunk_tokens
             req.extend_input_len = trunc_len
-            req.is_being_chunked = True
             req.fill_ids = req.fill_ids[:trunc_len]
             self.can_run_list.append(req)
-            self.new_chunked_req = req
+            self.new_inflight_req = req
             self._prefill_one_req(0, trunc_len, 0)

         return self.budget_state()
@@ -313,18 +305,15 @@ class PrefillAdder:
                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
             )
         else:
+            # Chunked prefill
             trunc_len = self.rem_chunk_tokens
             if trunc_len == 0:
                 return AddReqResult.OTHER

-            # Chunked prefill
-            assert self.new_chunked_req is None
-
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
-            req.is_being_chunked = True
             self.can_run_list.append(req)
-            self.new_chunked_req = req
+            self.new_inflight_req = req
             self.tree_cache.inc_lock_ref(req.last_node)
             self._prefill_one_req(prefix_len, trunc_len, 0)
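The arithmetic in add_inflight_req above is easy to see with made-up numbers: the request's remaining prefill is clamped to the chunk budget, and the request is handed back (kept inflight) only if it was truncated.

# Illustrative numbers only, mirroring the truncation logic above.
extend_input_len = 700   # tokens this request still needs to prefill
rem_chunk_tokens = 512   # prefill-chunk budget left in this batch

truncated = extend_input_len > rem_chunk_tokens
extend_input_len = min(extend_input_len, rem_chunk_tokens)

# "return req if truncated else None": a truncated request becomes the
# scheduler's current_inflight_req for the next round.
print(truncated, extend_input_len)  # True 512 -- 188 tokens remain for later chunks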
python/sglang/srt/managers/scheduler.py

@@ -219,12 +219,13 @@ class Scheduler:
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.being_chunked_req = None
+        self.current_inflight_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )

         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
@@ -237,10 +238,16 @@ class Scheduler:
         self.jump_forward_cache = JumpForwardCache()

         # Init new token estimation
-        self.min_new_token_ratio, self.init_new_token_ratio = (
-            global_config.adjust_new_token_ratio(server_args.schedule_conservativeness)
-        )
-        self.new_token_ratio = self.init_new_token_ratio
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio = self.min_new_token_ratio
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.batch_is_full = False

         # Init profiler
@@ -287,7 +294,7 @@ class Scheduler:
             self.process_batch_result(batch, result)
         else:
             self.check_memory()
-            self.new_token_ratio = self.init_new_token_ratio
+            self.new_token_ratio = global_config.init_new_token_ratio

         self.last_batch = batch
@@ -314,7 +321,7 @@ class Scheduler:
             self.process_batch_result(tmp_batch, tmp_result)
         elif batch is None:
             self.check_memory()
-            self.new_token_ratio = self.init_new_token_ratio
+            self.new_token_ratio = global_config.init_new_token_ratio

         self.last_batch = batch
@@ -492,18 +499,20 @@ class Scheduler:
             )
             exit(1) if crash_on_warning else None

-    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
+    def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
         if (
             self.last_batch
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.being_chunked_req:
-                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
-                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-                # Being chunked request keeps its rid but will get a new req_pool_idx.
-                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
+            if self.current_inflight_req:
+                self.last_batch.filter_batch(
+                    current_inflight_req=self.current_inflight_req
+                )
+                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
+                # Inflight request keeps its rid but will get a new req_pool_idx.
+                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
                 self.batch_is_full = False
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
@@ -534,7 +543,7 @@ class Scheduler:
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.being_chunked_req is None:
+        ) and self.current_inflight_req is None:
             return None

         running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -557,6 +566,15 @@ class Scheduler:
             num_mixed_running,
         )

+        has_inflight = self.current_inflight_req is not None
+        if has_inflight:
+            self.current_inflight_req.init_next_round_input(
+                None if prefix_computed else self.tree_cache
+            )
+            self.current_inflight_req = adder.add_inflight_req(
+                self.current_inflight_req
+            )
+
         if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
@@ -564,13 +582,6 @@ class Scheduler:
             else set([])
         )

-        # NOTE: if there is request being chunked, we always add it first
-        has_being_chunked = self.being_chunked_req is not None
-        if has_being_chunked:
-            # NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
-            self.being_chunked_req.init_next_round_input()
-            adder.add_being_chunked_req(self.being_chunked_req)
-
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -604,8 +615,12 @@ class Scheduler:
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]

-        # Update new round being chunked request
-        self.being_chunked_req = adder.new_chunked_req
+        if adder.new_inflight_req is not None:
+            assert self.current_inflight_req is None
+            self.current_inflight_req = adder.new_inflight_req
+
+        if self.current_inflight_req:
+            self.current_inflight_req.is_inflight_req += 1

         # Print stats
         if self.tp_rank == 0:
@@ -634,7 +649,7 @@ class Scheduler:
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
+                f"#queue-req: {len(self.waiting_queue) + has_inflight}"
             )
         else:
             logger.info(
@@ -645,7 +660,7 @@ class Scheduler:
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
+                f"#queue-req: {len(self.waiting_queue) + has_inflight}"
             )

         # Create a new batch
@@ -694,7 +709,7 @@ class Scheduler:
             self.waiting_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - global_config.new_token_ratio_decay,
+                self.new_token_ratio - self.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )
@@ -768,8 +783,10 @@ class Scheduler:
         # Check finish conditions
         logprob_pt = 0
         for i, req in enumerate(batch.reqs):
-            if not req.is_being_chunked:
-                # Being chunked reqs' prefill is not finished
+            if req.is_inflight_req > 0:
+                req.is_inflight_req -= 1
+            else:
+                # Inflight reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()
@@ -795,8 +812,10 @@ class Scheduler:
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
             req.embedding = embeddings[i]
-            if not req.is_being_chunked:
-                # Being chunked reqs' prefill is not finished
+            if req.is_inflight_req > 0:
+                req.is_inflight_req -= 1
+            else:
+                # Inflight reqs' prefill is not finished
                 # dummy output token for embedding models
                 req.output_ids.append(0)
                 req.check_finished()
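Putting the restored pieces together, a self-contained sketch (not sglang code) of the new_token_ratio schedule: it starts at min_new_token_ratio, is reset to global_config.init_new_token_ratio when the scheduler goes idle, and after each decode round without retraction it decays toward the floor. Constants come from global_config.py above; the step count is illustrative.

init_new_token_ratio = 0.7      # global_config.init_new_token_ratio
new_token_ratio_decay = 0.001   # global_config.new_token_ratio_decay
min_new_token_ratio = 0.1       # base 0.1 * schedule_conservativeness of 1.0

new_token_ratio = init_new_token_ratio  # value right after an idle reset
for step in range(1000):
    # Decay applied whenever a decode batch completes without retraction.
    new_token_ratio = max(new_token_ratio - new_token_ratio_decay, min_new_token_ratio)

print(new_token_ratio)  # reaches the 0.1 floor after 600 steps and stays clamped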
python/sglang/test/test_utils.py

@@ -663,7 +663,6 @@ def run_mmlu_test(
     chunked_prefill_size=32,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
-    other_args += ["--mem-fraction-static", "0.85"]
     if disable_radix_cache:
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
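With the line above reverted, run_mmlu_test no longer pins the static memory fraction; under the default arguments the launch flags reduce to the following (a restatement of the code above, for illustration only):

chunked_prefill_size = 32
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
print(other_args)  # ['--chunked-prefill-size', '32'], plus any radix/mixed-chunk flags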
test/srt/test_radix_attention.py  (deleted, 100644 → 0, −112 lines)

The entire file below was removed:

import os
import random
import unittest

import requests

from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    kill_child_process,
    popen_launch_server,
)


def gen_radix_tree(num_nodes=400, chunk_len=256):
    num0 = num_nodes // 2
    num1 = num_nodes - num0
    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
    for _ in range(num0):
        parent = random.choice(nodes)
        unique_len = random.randint(0, chunk_len)
        decode_len = random.randint(0, chunk_len)
        token_id = random.randint(0, 32000)
        child = {
            "input_ids": parent["input_ids"] + [token_id] * unique_len,
            "decode_len": decode_len,
        }
        nodes.append(child)

    while num1 > 0:
        num_branch = random.randint(1, min(num1, 10))
        parent = random.choice(nodes)
        for _ in range(num_branch):
            unique_len = random.randint(0, chunk_len)
            decode_len = random.randint(0, chunk_len)
            token_id = random.randint(0, 32000)
            child = {
                "input_ids": parent["input_ids"] + [token_id] * unique_len,
                "decode_len": decode_len,
            }
            nodes.append(child)

        num1 -= num_branch

    random.shuffle(nodes)
    return nodes


def run_test(base_url, nodes):
    data = {
        "input_ids": [node["input_ids"] for node in nodes],
        "sampling_params": [
            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
        ],
    }

    res = requests.post(base_url + "/generate", json=data)
    assert res.status_code == 200


class TestRadixCacheFCFS(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "fcfs",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def test_radix_attention(self):
        nodes = gen_radix_tree()
        run_test(self.base_url, nodes)


class TestRadixCacheLPM(TestRadixCacheFCFS):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--chunked-prefill-size",
                "128",
                "--max-total-tokens",
                "20000",
                "--schedule-policy",
                "lpm",
            ],
        )


if __name__ == "__main__":
    os.environ["SGLANG_TEST_RETRACT"] = "true"
    unittest.main()
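For reference, the deleted test stressed chunked prefill against the radix cache by POSTing a whole batch of prefix-sharing requests to /generate. A toy payload in the same shape (two invented nodes, mirroring gen_radix_tree's output format):

# Two toy nodes: a root prefix and a child that extends it.
payload = {
    "input_ids": [
        [37] * 117,              # root node
        [37] * 117 + [5] * 30,   # child sharing the 117-token prefix
    ],
    "sampling_params": [
        {"max_new_tokens": 217, "temperature": 0},
        {"max_new_tokens": 12, "temperature": 0},
    ],
}
# run_test would then call: requests.post(base_url + "/generate", json=payload)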