Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
585e1223
Unverified
Commit
585e1223
authored
Oct 18, 2025
by
Teng Ma
Committed by
GitHub
Oct 18, 2025
Browse files
[HiCache] feat: add more eviction policy (#11506)
parent
a7043c6f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
2 deletions
+31
-2
python/sglang/srt/mem_cache/evict_policy.py
python/sglang/srt/mem_cache/evict_policy.py
+15
-0
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+16
-2
No files found.
python/sglang/srt/mem_cache/evict_policy.py
View file @
585e1223
...
...
@@ -21,3 +21,18 @@ class LRUStrategy(EvictionStrategy):
class
LFUStrategy
(
EvictionStrategy
):
def
get_priority
(
self
,
node
:
"TreeNode"
)
->
Tuple
[
int
,
float
]:
return
(
node
.
hit_count
,
node
.
last_access_time
)
class
FIFOStrategy
(
EvictionStrategy
):
def
get_priority
(
self
,
node
:
"TreeNode"
)
->
float
:
return
node
.
creation_time
class
MRUStrategy
(
EvictionStrategy
):
def
get_priority
(
self
,
node
:
"TreeNode"
)
->
float
:
return
-
node
.
last_access_time
class
FILOStrategy
(
EvictionStrategy
):
def
get_priority
(
self
,
node
:
"TreeNode"
)
->
float
:
return
-
node
.
creation_time
python/sglang/srt/mem_cache/radix_cache.py
View file @
585e1223
...
...
@@ -34,7 +34,14 @@ from sglang.srt.disaggregation.kv_events import (
)
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
,
MatchResult
from
sglang.srt.mem_cache.evict_policy
import
EvictionStrategy
,
LFUStrategy
,
LRUStrategy
from
sglang.srt.mem_cache.evict_policy
import
(
EvictionStrategy
,
FIFOStrategy
,
FILOStrategy
,
LFUStrategy
,
LRUStrategy
,
MRUStrategy
,
)
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
if
TYPE_CHECKING
:
...
...
@@ -76,6 +83,7 @@ class TreeNode:
self
.
value
:
Optional
[
torch
.
Tensor
]
=
None
self
.
lock_ref
=
0
self
.
last_access_time
=
time
.
monotonic
()
self
.
creation_time
=
time
.
monotonic
()
self
.
hit_count
=
0
# indicating the node is locked to protect from eviction
...
...
@@ -216,9 +224,15 @@ class RadixCache(BasePrefixCache):
self
.
eviction_strategy
:
EvictionStrategy
=
LRUStrategy
()
elif
eviction_policy
.
lower
()
==
"lfu"
:
self
.
eviction_strategy
:
EvictionStrategy
=
LFUStrategy
()
elif
eviction_policy
.
lower
()
==
"fifo"
:
self
.
eviction_strategy
:
EvictionStrategy
=
FIFOStrategy
()
elif
eviction_policy
.
lower
()
==
"mru"
:
self
.
eviction_strategy
:
EvictionStrategy
=
MRUStrategy
()
elif
eviction_policy
.
lower
()
==
"filo"
:
self
.
eviction_strategy
:
EvictionStrategy
=
FILOStrategy
()
else
:
raise
ValueError
(
f
"Unknown eviction policy:
{
eviction_policy
}
. Supported policies: 'lru', 'lfu'."
f
"Unknown eviction policy:
{
eviction_policy
}
. Supported policies: 'lru', 'lfu'
, 'fifo', 'mru', 'filo'
."
)
self
.
reset
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment