sglang · Commits · b0b4f716

[Fix] memory leak by overlap + retract (#11981)

Commit b0b4f716, authored Oct 23, 2025 by cctry, committed via GitHub on Oct 23, 2025.
Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
Parent: 6c18addb

Showing 9 changed files with 132 additions and 25 deletions (+132 / -25):
python/sglang/srt/environ.py                                       +5   -1
python/sglang/srt/managers/schedule_batch.py                       +5   -2
python/sglang/srt/managers/schedule_policy.py                      +2   -1
python/sglang/srt/managers/scheduler.py                            +6   -2
python/sglang/srt/managers/scheduler_output_processor_mixin.py     +19  -5
python/sglang/srt/managers/scheduler_runtime_checker_mixin.py      +53  -0
python/sglang/srt/mem_cache/chunk_cache.py                         +7   -0
test/srt/run_suite.py                                              +1   -1
test/srt/test_retract_decode.py                                    +34  -13
python/sglang/srt/environ.py

@@ -114,7 +114,6 @@ class Envs:
     # Test & Debug
     SGLANG_IS_IN_CI = EnvBool(False)
     SGLANG_IS_IN_CI_AMD = EnvBool(False)
-    SGLANG_TEST_RETRACT = EnvBool(False)
     SGLANG_SET_CPU_AFFINITY = EnvBool(False)
     SGLANG_PROFILE_WITH_STACK = EnvBool(True)
     SGLANG_RECORD_STEP_TIME = EnvBool(False)
@@ -128,6 +127,11 @@ class Envs:
     SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
     SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")
 
+    # Scheduler: memory leak test
+    SGLANG_TEST_RETRACT = EnvBool(False)
+    SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3)
+    SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False)
+
     # Scheduler: new token ratio hyperparameters
     SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)
     SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR = EnvFloat(0.14)
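The commit moves SGLANG_TEST_RETRACT out of the generic "Test & Debug" block and groups it with two new knobs: SGLANG_TEST_RETRACT_INTERVAL (how often a retraction is forced in test mode) and SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK (enables the per-step accounting check added below). The sketch below is not sglang's actual EnvBool/EnvInt implementation; it is a minimal stand-in that only illustrates how `envs.SGLANG_TEST_RETRACT.get()` and the `.override(True)` context manager used by the updated tests behave.

```python
# Hypothetical stand-ins for sglang's EnvBool/EnvInt descriptors, only to
# illustrate how envs.SGLANG_TEST_RETRACT.get() / .override(True) are used.
import os
from contextlib import contextmanager


class _EnvVar:
    def __init__(self, name: str, default, parse):
        self.name, self.default, self.parse = name, default, parse

    def get(self):
        # Fall back to the declared default when the variable is unset.
        raw = os.environ.get(self.name)
        return self.default if raw is None else self.parse(raw)

    @contextmanager
    def override(self, value):
        # Temporarily set the variable, then restore the previous state.
        old = os.environ.get(self.name)
        os.environ[self.name] = str(value)
        try:
            yield
        finally:
            if old is None:
                del os.environ[self.name]
            else:
                os.environ[self.name] = old


SGLANG_TEST_RETRACT = _EnvVar(
    "SGLANG_TEST_RETRACT", False, lambda s: s.lower() in ("1", "true")
)
SGLANG_TEST_RETRACT_INTERVAL = _EnvVar("SGLANG_TEST_RETRACT_INTERVAL", 3, int)

print(SGLANG_TEST_RETRACT.get())          # False unless exported
with SGLANG_TEST_RETRACT.override(True):
    print(SGLANG_TEST_RETRACT.get())      # True inside the block
```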
python/sglang/srt/managers/schedule_batch.py

@@ -885,7 +885,6 @@ class Req:
         self.temp_input_top_logprobs_idx = None
         self.extend_logprob_start_len = 0
         self.is_chunked = 0
-        self.req_pool_idx = None
         self.mamba_pool_idx = None
         self.already_computed = 0
@@ -1482,7 +1481,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         new_estimate_ratio = (
             total_decoded_tokens
             + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
-        ) / total_max_new_tokens
+        ) / (total_max_new_tokens + 1)  # avoid zero division
         new_estimate_ratio = min(1.0, new_estimate_ratio)
 
         return retracted_reqs, new_estimate_ratio, []
@@ -1780,6 +1781,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            req_to_token_pool=self.req_to_token_pool,
+            req_pool_indices=self.req_pool_indices,
             model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
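Two details here: the new-token-ratio estimate after a retraction now divides by `total_max_new_tokens + 1` so an all-zero budget cannot raise ZeroDivisionError, and the lightweight ScheduleBatch copy handed to process_batch_result now carries `req_to_token_pool` and `req_pool_indices`, which the output processor needs to locate and free the delayed KV slot of a retracted request. A minimal arithmetic sketch of the ratio update follows; `retract_decode_steps` stands in for `envs.SGLANG_RETRACT_DECODE_STEPS.get()`, whose default is not shown in this diff, so the value 20 is purely illustrative.

```python
# Sketch of the new_estimate_ratio update after a retraction; the real code
# lives in ScheduleBatch.retract_decode.
def estimate_new_token_ratio(total_decoded_tokens: int,
                             total_max_new_tokens: int,
                             num_reqs: int,
                             retract_decode_steps: int = 20) -> float:
    ratio = (
        total_decoded_tokens + retract_decode_steps * num_reqs
    ) / (total_max_new_tokens + 1)  # +1 avoids a zero division
    return min(1.0, ratio)


# Example: 128 tokens decoded so far, 4 running requests, 4096 tokens still allowed.
print(estimate_new_token_ratio(128, 4096, 4))
```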
python/sglang/srt/managers/schedule_policy.py

@@ -569,7 +569,8 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req, has_chunked_req)
 
         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            max(req.sampling_params.max_new_tokens - len(req.output_ids), 0),
+            CLIP_MAX_NEW_TOKENS,
         )
 
         # adjusting the input_tokens based on host_hit_length and page_size
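PrefillAdder previously budgeted the full `max_new_tokens` even when a retracted request was being re-added after it had already produced some output; the change subtracts `len(req.output_ids)` (clamped at zero) so the remaining decode length is not double-counted. A quick numeric sketch with made-up lengths:

```python
# Token-budget sketch for a retracted request being re-added for prefill.
# CLIP_MAX_NEW_TOKENS mirrors sglang's clamp on the planned decode length;
# the concrete numbers below are invented for illustration.
CLIP_MAX_NEW_TOKENS = 4096

extend_input_len = 1024       # prompt (plus already-generated) tokens to re-prefill
max_new_tokens = 512          # sampling limit for the request
already_generated = 200       # len(req.output_ids) before it was retracted

old_budget = extend_input_len + min(max_new_tokens, CLIP_MAX_NEW_TOKENS)
new_budget = extend_input_len + min(
    max(max_new_tokens - already_generated, 0),
    CLIP_MAX_NEW_TOKENS,
)
print(old_budget, new_budget)  # 1536 vs 1336: the remaining decode is 312, not 512
```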
python/sglang/srt/managers/scheduler.py

@@ -194,7 +194,8 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 logger = logging.getLogger(__name__)
 
 # Test retract decode for debugging purposes
-TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
+TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
+TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
@@ -1017,6 +1018,9 @@ class Scheduler(
             self.launch_batch_sample_if_needed(batch_result)
 
         self.last_batch = batch
+        if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
+            self._check_runtime_mem_leak()
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1833,7 +1837,7 @@ class Scheduler(
         # Check if decode out of memory
         if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
-            TEST_RETRACT and batch.batch_size() > 10
+            TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
         ):
             old_ratio = self.new_token_ratio
 
             retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
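With SGLANG_TEST_RETRACT enabled, retraction used to be forced only when the running batch held more than 10 requests; it is now forced periodically, every SGLANG_TEST_RETRACT_INTERVAL forward steps, so the retract path is exercised regardless of batch size. A condensed sketch of the trigger (names mirror the diff; `check_decode_mem` and `retract_decode` are the real scheduler methods, the rest is a toy):

```python
# Sketch of the forced-retract trigger inside the scheduler's decode check.
TEST_RETRACT = True            # envs.SGLANG_TEST_RETRACT.get()
TEST_RETRACT_INTERVAL = 3      # envs.SGLANG_TEST_RETRACT_INTERVAL.get()


def should_retract(forward_ct: int, decode_mem_ok: bool) -> bool:
    # Retract when decode memory is short, or periodically in test mode.
    return (not decode_mem_ok) or (
        TEST_RETRACT and forward_ct % TEST_RETRACT_INTERVAL == 0
    )


# Every third forward pass retracts, even if decode memory is fine.
print([should_retract(ct, True) for ct in range(1, 7)])
# [False, False, True, False, False, True]
```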
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -77,15 +77,28 @@ class SchedulerOutputProcessorMixin:
             logprob_pt = 0
             for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-                if req.is_retracted:
+                if self.enable_overlap and req.is_retracted and len(req.output_ids) > 0:
+                    req_idx = batch.req_pool_indices[i]
+                    seq_len = len(req.origin_input_ids) + len(req.output_ids)
+                    pos = batch.req_to_token_pool.req_to_token[req_idx][
+                        seq_len - 1 : seq_len
+                    ]
+                    self.token_to_kv_pool_allocator.free(pos)
                     continue
 
-                if self.is_mixed_chunk and self.enable_overlap and req.finished():
+                if (
+                    self.is_mixed_chunk
+                    and self.enable_overlap
+                    and (req.finished() or req.is_retracted)
+                ):
                     # Free the one delayed token for the mixed decode batch
                     j = len(batch.out_cache_loc) - len(batch.reqs) + i
                     self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
                     continue
 
+                if req.is_retracted:
+                    continue
+
                 if req.is_chunked <= 0:
                     # req output_ids are set here
                     req.output_ids.append(next_token_id)
@@ -269,10 +282,8 @@ class SchedulerOutputProcessorMixin:
         # We should ignore using next_token_ids for spec decoding cases.
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req: Req
-            if req.is_retracted:
-                continue
 
-            if self.enable_overlap and req.finished():
+            if self.enable_overlap and (req.finished() or req.is_retracted):
                 indices_to_free = None
                 if batch.spec_algorithm.is_eagle():
                     from sglang.srt.speculative.eagle_info import EagleDraftInput
@@ -301,6 +312,9 @@ class SchedulerOutputProcessorMixin:
                     self.token_to_kv_pool_allocator.free(indices_to_free)
                 continue
 
+            if req.is_retracted:
+                continue
+
             new_accepted_len = 1
             if batch.spec_algorithm.is_none():
                 req.output_ids.append(next_token_id)
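This is the core of the leak fix. With overlap scheduling, a batch result is processed one step late, so a request retracted in the meantime still has one freshly written KV slot (the delayed token) that the earlier `is_retracted: continue` shortcut simply skipped over. The processor now looks up that slot through `req_to_token_pool` (which is why the minimal ScheduleBatch copy above now carries `req_to_token_pool` and `req_pool_indices`) and frees it before skipping the request. The toy below is not sglang's code; it uses a plain dict and set in place of the real pools, with invented slot numbers, purely to illustrate the accounting.

```python
# Toy illustration of why one KV slot per retracted request leaks under
# overlap scheduling, and how freeing req_to_token[req_idx][seq_len - 1]
# fixes it.
req_to_token = {7: [100, 101, 102, 103]}   # req_pool_idx -> KV slots of its tokens
free_slots = set(range(104, 110))          # allocator's free list

# Overlap step N wrote the newest token's KV into slot 103; at step N+1 the
# request turns out to be retracted. The slots the scheduler knew about were
# already released at retract time, but slot 103 was written afterwards, so
# the output processor must free it explicitly:
req_idx, seq_len = 7, 4
delayed_slot = req_to_token[req_idx][seq_len - 1]
free_slots.add(delayed_slot)               # token_to_kv_pool_allocator.free(pos)
print(sorted(free_slots))                  # slot 103 is back in the pool
```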
python/sglang/srt/managers/scheduler_runtime_checker_mixin.py

@@ -4,6 +4,7 @@ import time
 from typing import TYPE_CHECKING
 
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
@@ -65,6 +66,58 @@ class SchedulerRuntimeCheckerMixin:
         token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
         return memory_leak, token_msg
 
+    def _check_runtime_mem_leak(self: Scheduler):
+        current_batch: ScheduleBatch = self.last_batch
+        if current_batch is None:
+            return
+
+        _, _, available_size, evictable_size = self._get_token_info()
+        protected_size = self.tree_cache.protected_size()
+
+        extend_size = 0
+        for i, req in enumerate(current_batch.reqs):
+            seq_len = len(req.origin_input_ids) + len(req.output_ids)
+            fill_len = len(req.fill_ids) if req.fill_ids is not None else 0
+            prefix_len = (
+                len(req.prefix_indices) if req.prefix_indices is not None else 0
+            )
+            if current_batch.forward_mode.is_decode():
+                if req.finished():
+                    unreleased_len = 1
+                else:
+                    unreleased_len = seq_len - prefix_len
+            else:
+                unreleased_len = fill_len - prefix_len
+            extend_size += unreleased_len
+
+        if (
+            current_batch.forward_mode.is_extend()
+            and self.running_batch is not None
+            and not self.running_batch.is_empty()
+            and self.running_batch.forward_mode.is_decode()
+        ):
+            for i, req in enumerate(self.running_batch.reqs):
+                seq_len = len(req.origin_input_ids) + len(req.output_ids)
+                prefix_len = (
+                    len(req.prefix_indices) if req.prefix_indices is not None else 0
+                )
+                if req.finished():
+                    unreleased_len = 0
+                else:
+                    unreleased_len = seq_len - prefix_len - 1
+                extend_size += unreleased_len
+
+        total_tokens = available_size + evictable_size + protected_size + extend_size
+        assert (
+            total_tokens == self.max_total_num_tokens
+        ), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"
+
     def _check_req_pool(self: Scheduler):
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             req_total_size = (
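The new `_check_runtime_mem_leak` runs after every scheduler step when SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK is set. It asserts that every KV slot is accounted for: slots still free in the allocator, slots evictable or protected in the tree cache, and slots held by in-flight requests (the "extend" share derived from each request's sequence, fill, and prefix lengths) must add up to `max_total_num_tokens`. The sketch below just writes that invariant out as plain arithmetic with invented numbers; the real counting of `extend_size` is the method above.

```python
# The invariant checked by _check_runtime_mem_leak, written out as plain
# arithmetic. All four terms are counted in KV-cache token slots.
def assert_no_leak(available_size: int, evictable_size: int,
                   protected_size: int, extend_size: int,
                   max_total_num_tokens: int) -> None:
    total_tokens = available_size + evictable_size + protected_size + extend_size
    assert total_tokens == max_total_num_tokens, (
        f"Mem Leak Detected! {total_tokens=} vs {max_total_num_tokens=}"
    )


# Example: a 10_000-slot pool with 8_000 free, 1_200 evictable in the radix
# cache, 300 protected, and 500 still owned by the running batch.
assert_no_leak(8_000, 1_200, 300, 500, 10_000)
```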
python/sglang/srt/mem_cache/chunk_cache.py

@@ -32,6 +32,8 @@ class ChunkCache(BasePrefixCache):
         else:
             self.device = torch.device("cpu")
 
+        self.protected_size_ = 0
+
         # NOTE (csy): this is to determine if a cache has prefix matching feature.
         # Chunk cache always return True to indicate no prefix matching.
         # TODO (csy): Using a prefix cache trait to replace this
@@ -57,11 +59,13 @@ class ChunkCache(BasePrefixCache):
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
+        self.protected_size_ -= len(req.prefix_indices)
 
     def cache_unfinished_req(self, req: Req, chunked=False):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(req.fill_ids)
         ]
+        self.protected_size_ += len(kv_indices) - len(req.prefix_indices)
 
         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
@@ -75,6 +79,9 @@ class ChunkCache(BasePrefixCache):
     def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         return 0
 
+    def protected_size(self):
+        return self.protected_size_
+
     def pretty_print(self):
         return ""
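ChunkCache previously reported no protected tokens at all, so tokens pinned by an in-flight chunked prefill would look like a leak to the new checker. The cache now tracks `protected_size_`: it grows by the newly cached slice when an unfinished chunked request is cached and shrinks by the request's prefix length when the request is released, and `protected_size()` exposes the counter. A short lifecycle sketch with made-up lengths (the real updates are the two lines in the diff above):

```python
# Lifecycle of ChunkCache.protected_size_ for one chunked request.
protected_size = 0

# Chunk 1: 128 tokens cached, no previous prefix.
kv_len, prefix_len = 128, 0
protected_size += kv_len - prefix_len   # += len(kv_indices) - len(req.prefix_indices)

# Chunk 2: the request now covers 256 tokens, 128 of them already counted.
kv_len, prefix_len = 256, 128
protected_size += kv_len - prefix_len

# Request finishes: its whole 256-token prefix is released.
protected_size -= 256                   # -= len(req.prefix_indices)
print(protected_size)                   # 0 -> the checker sees no leak
```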
test/srt/run_suite.py

@@ -112,7 +112,7 @@ suites = {
         TestFile("test_reasoning_parser.py", 5),
         TestFile("test_regex_constrained.py", 64),
         TestFile("test_request_queue_validation.py", 30),
-        TestFile("test_retract_decode.py", 54),
+        TestFile("test_retract_decode.py", 90),
         TestFile("test_score_api.py", 310),
         TestFile("test_server_args.py", 1),
         TestFile("test_skip_tokenizer_init.py", 117),
test/srt/test_retract_decode.py

-import os
+import time
 import unittest
 from types import SimpleNamespace
 
+from sglang.srt.environ import envs
 from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
@@ -16,13 +17,12 @@ from sglang.test.test_utils import (
 class TestRetractDecode(CustomTestCase):
     @classmethod
     def setUpClass(cls):
-        os.environ["SGLANG_TEST_RETRACT"] = "1"
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.process = popen_launch_server(
-            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
-        )
+        with envs.SGLANG_TEST_RETRACT.override(True):
+            cls.process = popen_launch_server(
+                cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
+            )
 
     @classmethod
     def tearDownClass(cls):
@@ -39,22 +39,43 @@ class TestRetractDecode(CustomTestCase):
         metrics = run_eval(args)
         self.assertGreaterEqual(metrics["score"], 0.65)
+        time.sleep(1)  # wait for mem check
+        assert self.process.poll() is None, "Server crashed during test"
 
 
 class TestRetractDecodeChunkCache(CustomTestCase):
     @classmethod
     def setUpClass(cls):
-        os.environ["SGLANG_TEST_RETRACT"] = "1"
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--disable-radix-cache", "--chunked-prefill-size", 128],
-        )
+        with envs.SGLANG_TEST_RETRACT.override(True):
+            cls.process = popen_launch_server(
+                cls.model,
+                cls.base_url,
+                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                other_args=["--disable-radix-cache", "--chunked-prefill-size", 128],
+            )
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+        time.sleep(1)  # wait for mem check
+        assert self.process.poll() is None, "Server crashed during test"
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
 
 
 if __name__ == "__main__":
     unittest.main()
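Both test classes now launch the server under `envs.SGLANG_TEST_RETRACT.override(True)` instead of mutating `os.environ` for the whole process, and each test ends by asserting `self.process.poll() is None`. That assertion relies on `subprocess.Popen.poll()` returning `None` only while the child is still alive, so if the runtime memory-leak check fires and the server process dies, the test fails with "Server crashed during test" instead of passing silently. A minimal, self-contained sketch of that liveness check using a throwaway child process (not the sglang server):

```python
# Why `assert process.poll() is None` detects a crashed server:
# Popen.poll() returns None while the child runs and its exit code afterwards.
import subprocess
import sys

proc = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(2)"])
assert proc.poll() is None, "Server crashed during test"   # still running

proc.wait()            # let the child exit
print(proc.poll())     # 0: a finished (or crashed) process no longer returns None
```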