Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
31b49c0b
"tests/vscode:/vscode.git/clone" did not exist on "32ea2142c056fae722b0cabaa799697a861cd039"
Unverified
Commit
31b49c0b
authored
Oct 05, 2025
by
Ke Bao
Committed by
GitHub
Oct 04, 2025
Browse files
EAGLE cache fix for HiCache (#11215)
parent
d736e0b6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
90 additions
and
0 deletions
+90
-0
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+1
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+1
-0
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+9
-0
test/srt/hicache/test_hicache_eagle.py
test/srt/hicache/test_hicache_eagle.py
+78
-0
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
No files found.
python/sglang/srt/managers/schedule_policy.py
View file @
31b49c0b
...
@@ -583,6 +583,7 @@ class PrefillAdder:
...
@@ -583,6 +583,7 @@ class PrefillAdder:
req
.
prefix_indices
=
torch
.
cat
([
req
.
prefix_indices
,
new_indices
])
req
.
prefix_indices
=
torch
.
cat
([
req
.
prefix_indices
,
new_indices
])
req
.
extend_input_len
=
len
(
req
.
fill_ids
)
-
len
(
req
.
prefix_indices
)
req
.
extend_input_len
=
len
(
req
.
fill_ids
)
-
len
(
req
.
prefix_indices
)
prefix_len
=
len
(
req
.
prefix_indices
)
prefix_len
=
len
(
req
.
prefix_indices
)
req
.
last_matched_prefix_len
=
prefix_len
input_tokens
=
self
.
ceil_paged_tokens
(
req
.
extend_input_len
)
input_tokens
=
self
.
ceil_paged_tokens
(
req
.
extend_input_len
)
...
...
python/sglang/srt/managers/scheduler.py
View file @
31b49c0b
...
@@ -762,6 +762,7 @@ class Scheduler(
...
@@ -762,6 +762,7 @@ class Scheduler(
hicache_storage_prefetch_policy
=
server_args
.
hicache_storage_prefetch_policy
,
hicache_storage_prefetch_policy
=
server_args
.
hicache_storage_prefetch_policy
,
model_name
=
server_args
.
served_model_name
,
model_name
=
server_args
.
served_model_name
,
storage_backend_extra_config
=
server_args
.
hicache_storage_backend_extra_config
,
storage_backend_extra_config
=
server_args
.
hicache_storage_backend_extra_config
,
is_eagle
=
self
.
spec_algorithm
.
is_eagle
(),
)
)
self
.
tp_worker
.
register_hicache_layer_transfer_counter
(
self
.
tp_worker
.
register_hicache_layer_transfer_counter
(
self
.
tree_cache
.
cache_controller
.
layer_done_counter
self
.
tree_cache
.
cache_controller
.
layer_done_counter
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
31b49c0b
...
@@ -44,6 +44,7 @@ class HiRadixCache(RadixCache):
...
@@ -44,6 +44,7 @@ class HiRadixCache(RadixCache):
hicache_storage_prefetch_policy
:
Optional
[
str
]
=
"best_effort"
,
hicache_storage_prefetch_policy
:
Optional
[
str
]
=
"best_effort"
,
model_name
:
Optional
[
str
]
=
None
,
model_name
:
Optional
[
str
]
=
None
,
storage_backend_extra_config
:
Optional
[
str
]
=
None
,
storage_backend_extra_config
:
Optional
[
str
]
=
None
,
is_eagle
:
bool
=
False
,
):
):
if
hicache_io_backend
==
"direct"
:
if
hicache_io_backend
==
"direct"
:
...
@@ -135,6 +136,7 @@ class HiRadixCache(RadixCache):
...
@@ -135,6 +136,7 @@ class HiRadixCache(RadixCache):
page_size
,
page_size
,
disable
=
False
,
disable
=
False
,
eviction_policy
=
eviction_policy
,
eviction_policy
=
eviction_policy
,
is_eagle
=
is_eagle
,
)
)
def
_parse_storage_backend_extra_config
(
def
_parse_storage_backend_extra_config
(
...
@@ -658,6 +660,7 @@ class HiRadixCache(RadixCache):
...
@@ -658,6 +660,7 @@ class HiRadixCache(RadixCache):
def
match_prefix
(
self
,
key
:
RadixKey
,
**
kwargs
):
def
match_prefix
(
self
,
key
:
RadixKey
,
**
kwargs
):
empty_value
=
torch
.
empty
((
0
,),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
empty_value
=
torch
.
empty
((
0
,),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
key
.
token_ids
=
self
.
key_convert_fn
(
key
.
token_ids
)
if
self
.
disable
or
len
(
key
)
==
0
:
if
self
.
disable
or
len
(
key
)
==
0
:
return
MatchResult
(
return
MatchResult
(
device_indices
=
empty_value
,
device_indices
=
empty_value
,
...
@@ -820,9 +823,15 @@ class HiRadixCache(RadixCache):
...
@@ -820,9 +823,15 @@ class HiRadixCache(RadixCache):
return
new_node
return
new_node
def
insert
(
self
,
key
:
RadixKey
,
value
=
None
,
chunked
=
False
):
def
insert
(
self
,
key
:
RadixKey
,
value
=
None
,
chunked
=
False
):
key
.
token_ids
=
self
.
key_convert_fn
(
key
.
token_ids
)
if
len
(
key
)
==
0
:
if
len
(
key
)
==
0
:
return
0
return
0
if
self
.
is_eagle
and
value
is
not
None
:
# Make sure the value len equal to the EAGLE bigram key len
value
=
value
[:
len
(
key
)]
node
=
self
.
root_node
node
=
self
.
root_node
child_key
=
self
.
get_child_key_fn
(
key
)
child_key
=
self
.
get_child_key_fn
(
key
)
total_prefix_length
=
0
total_prefix_length
=
0
...
...
test/srt/hicache/test_hicache_eagle.py
0 → 100644
View file @
31b49c0b
import
unittest
from
types
import
SimpleNamespace
import
requests
from
sglang.bench_serving
import
get_tokenizer
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3
,
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
popen_launch_server
,
)
class
TestHiCacheEagle
(
CustomTestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
tokenizer
=
get_tokenizer
(
cls
.
model
)
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
"--enable-hierarchical-cache"
,
"--hicache-ratio"
,
1.2
,
"--mem-fraction-static"
,
0.7
,
"--speculative-algorithm"
,
"EAGLE3"
,
"--speculative-draft-model-path"
,
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3
,
"--speculative-num-steps"
,
2
,
"--speculative-eagle-topk"
,
1
,
"--speculative-num-draft-tokens"
,
3
,
"--dtype"
,
"float16"
,
"--chunked-prefill-size"
,
1024
,
],
)
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
base_url
=
self
.
base_url
,
model
=
self
.
model
,
eval_name
=
"mmlu"
,
num_examples
=
64
,
num_threads
=
32
,
)
metrics
=
run_eval
(
args
)
self
.
assertGreaterEqual
(
metrics
[
"score"
],
0.72
)
server_info
=
requests
.
get
(
self
.
base_url
+
"/get_server_info"
)
print
(
f
"
{
server_info
=
}
"
)
avg_spec_accept_length
=
server_info
.
json
()[
"internal_states"
][
0
][
"avg_spec_accept_length"
]
print
(
f
"
{
avg_spec_accept_length
=
}
"
)
self
.
assertGreater
(
avg_spec_accept_length
,
2.26
)
if
__name__
==
"__main__"
:
unittest
.
main
()
test/srt/run_suite.py
View file @
31b49c0b
...
@@ -17,6 +17,7 @@ suites = {
...
@@ -17,6 +17,7 @@ suites = {
TestFile
(
"hicache/test_hicache.py"
,
116
),
TestFile
(
"hicache/test_hicache.py"
,
116
),
TestFile
(
"hicache/test_hicache_mla.py"
,
127
),
TestFile
(
"hicache/test_hicache_mla.py"
,
127
),
TestFile
(
"hicache/test_hicache_storage.py"
,
127
),
TestFile
(
"hicache/test_hicache_storage.py"
,
127
),
TestFile
(
"hicache/test_hicache_eagle.py"
,
150
),
TestFile
(
"lora/test_lora.py"
,
200
),
TestFile
(
"lora/test_lora.py"
,
200
),
TestFile
(
"lora/test_lora_eviction.py"
,
200
),
TestFile
(
"lora/test_lora_eviction.py"
,
200
),
TestFile
(
"lora/test_lora_backend.py"
,
99
),
TestFile
(
"lora/test_lora_backend.py"
,
99
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment