Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
3d40794f
Unverified
Commit
3d40794f
authored
Sep 24, 2025
by
Zhiqiang Xie
Committed by
GitHub
Sep 25, 2025
Browse files
[HiCache] Cleaning the deprecated host memory state (#10778)
parent
c1f39013
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
130 deletions
+15
-130
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+5
-22
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+1
-5
python/sglang/srt/mem_cache/memory_pool_host.py
python/sglang/srt/mem_cache/memory_pool_host.py
+9
-103
No files found.
python/sglang/srt/managers/cache_controller.py
View file @
3d40794f
...
@@ -462,7 +462,6 @@ class HiCacheController:
...
@@ -462,7 +462,6 @@ class HiCacheController:
host_indices
=
self
.
mem_pool_host
.
alloc
(
len
(
device_indices
))
host_indices
=
self
.
mem_pool_host
.
alloc
(
len
(
device_indices
))
if
host_indices
is
None
:
if
host_indices
is
None
:
return
None
return
None
self
.
mem_pool_host
.
protect_write
(
host_indices
)
self
.
write_queue
.
append
(
self
.
write_queue
.
append
(
CacheOperation
(
host_indices
,
device_indices
,
node_id
,
priority
)
CacheOperation
(
host_indices
,
device_indices
,
node_id
,
priority
)
)
)
...
@@ -486,7 +485,6 @@ class HiCacheController:
...
@@ -486,7 +485,6 @@ class HiCacheController:
self
.
mem_pool_host
.
backup_from_device_all_layer
(
self
.
mem_pool_host
.
backup_from_device_all_layer
(
self
.
mem_pool_device
,
host_indices
,
device_indices
,
self
.
io_backend
self
.
mem_pool_device
,
host_indices
,
device_indices
,
self
.
io_backend
)
)
self
.
mem_pool_host
.
complete_io
(
op
.
host_indices
)
finish_event
.
record
()
finish_event
.
record
()
# NOTE: We must save the host indices and device indices here,
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
# this is because we need to guarantee that these tensors are
...
@@ -510,7 +508,6 @@ class HiCacheController:
...
@@ -510,7 +508,6 @@ class HiCacheController:
device_indices
=
self
.
mem_pool_device_allocator
.
alloc
(
len
(
host_indices
))
device_indices
=
self
.
mem_pool_device_allocator
.
alloc
(
len
(
host_indices
))
if
device_indices
is
None
:
if
device_indices
is
None
:
return
None
return
None
self
.
mem_pool_host
.
protect_load
(
host_indices
)
self
.
load_queue
.
append
(
self
.
load_queue
.
append
(
CacheOperation
(
host_indices
,
device_indices
,
node_id
,
priority
)
CacheOperation
(
host_indices
,
device_indices
,
node_id
,
priority
)
)
)
...
@@ -555,7 +552,6 @@ class HiCacheController:
...
@@ -555,7 +552,6 @@ class HiCacheController:
self
.
io_backend
,
self
.
io_backend
,
)
)
producer_event
.
complete
(
i
)
producer_event
.
complete
(
i
)
self
.
mem_pool_host
.
complete_io
(
op
.
host_indices
)
# NOTE: We must save the host indices and device indices here,
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
# this is because we need to guarantee that these tensors are
# still alive when the load stream is executing.
# still alive when the load stream is executing.
...
@@ -573,29 +569,16 @@ class HiCacheController:
...
@@ -573,29 +569,16 @@ class HiCacheController:
)
)
return
producer_id
return
producer_id
def
evict_device
(
def
evict_device
(
self
,
device_indices
:
torch
.
Tensor
)
->
int
:
self
,
device_indices
:
torch
.
Tensor
,
host_indices
:
torch
.
Tensor
self
.
mem_pool_device_allocator
.
free
(
device_indices
)
)
->
int
:
return
len
(
device_indices
)
if
self
.
mem_pool_host
.
is_synced
(
host_indices
):
self
.
mem_pool_device_allocator
.
free
(
device_indices
)
self
.
mem_pool_host
.
update_backup
(
host_indices
)
return
len
(
device_indices
)
else
:
raise
ValueError
(
f
"Inconsistent states:
{
self
.
mem_pool_host
.
get_state
(
host_indices
)
}
"
)
def
evict_host
(
self
,
host_indices
:
torch
.
Tensor
,
backup_only
:
bool
=
True
)
->
int
:
def
evict_host
(
self
,
host_indices
:
torch
.
Tensor
,
backup_only
:
bool
=
True
)
->
int
:
if
not
backup_only
:
if
not
backup_only
:
raise
ValueError
(
"Other eviction policies are not supported yet."
)
raise
ValueError
(
"Other eviction policies are not supported yet."
)
if
self
.
mem_pool_host
.
is_backup
(
host_indices
):
self
.
mem_pool_host
.
free
(
host_indices
)
self
.
mem_pool_host
.
free
(
host_indices
)
return
len
(
host_indices
)
return
len
(
host_indices
)
else
:
raise
ValueError
(
f
"Inconsistent states:
{
self
.
mem_pool_host
.
get_state
(
host_indices
)
}
"
)
def
prefetch
(
def
prefetch
(
self
,
self
,
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
3d40794f
...
@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache):
...
@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache):
def
_evict_backuped
(
self
,
node
:
TreeNode
):
def
_evict_backuped
(
self
,
node
:
TreeNode
):
# evict a node already written to host
# evict a node already written to host
num_evicted
=
self
.
cache_controller
.
evict_device
(
node
.
value
,
node
.
host_value
)
num_evicted
=
self
.
cache_controller
.
evict_device
(
node
.
value
)
assert
num_evicted
>
0
assert
num_evicted
>
0
self
.
evictable_size_
-=
num_evicted
self
.
evictable_size_
-=
num_evicted
node
.
value
=
None
node
.
value
=
None
...
@@ -576,8 +576,6 @@ class HiRadixCache(RadixCache):
...
@@ -576,8 +576,6 @@ class HiRadixCache(RadixCache):
written_indices
,
written_indices
,
hash_value
[:
min_completed_tokens
//
self
.
page_size
],
hash_value
[:
min_completed_tokens
//
self
.
page_size
],
)
)
if
len
(
written_indices
):
self
.
cache_controller
.
mem_pool_host
.
update_prefetch
(
written_indices
)
self
.
cache_controller
.
mem_pool_host
.
free
(
host_indices
[:
matched_length
])
self
.
cache_controller
.
mem_pool_host
.
free
(
host_indices
[:
matched_length
])
self
.
cache_controller
.
append_host_mem_release
(
self
.
cache_controller
.
append_host_mem_release
(
...
@@ -775,7 +773,6 @@ class HiRadixCache(RadixCache):
...
@@ -775,7 +773,6 @@ class HiRadixCache(RadixCache):
# change the reference if the node is evicted
# change the reference if the node is evicted
# this often happens in the case of KV cache recomputation
# this often happens in the case of KV cache recomputation
node
.
value
=
value
[:
prefix_len
]
node
.
value
=
value
[:
prefix_len
]
self
.
token_to_kv_pool_host
.
update_synced
(
node
.
host_value
)
self
.
evictable_size_
+=
len
(
node
.
value
)
self
.
evictable_size_
+=
len
(
node
.
value
)
else
:
else
:
self
.
_inc_hit_count
(
node
,
chunked
)
self
.
_inc_hit_count
(
node
,
chunked
)
...
@@ -785,7 +782,6 @@ class HiRadixCache(RadixCache):
...
@@ -785,7 +782,6 @@ class HiRadixCache(RadixCache):
new_node
=
self
.
_split_node
(
node
.
key
,
node
,
prefix_len
)
new_node
=
self
.
_split_node
(
node
.
key
,
node
,
prefix_len
)
if
new_node
.
evicted
:
if
new_node
.
evicted
:
new_node
.
value
=
value
[:
prefix_len
]
new_node
.
value
=
value
[:
prefix_len
]
self
.
token_to_kv_pool_host
.
update_synced
(
new_node
.
host_value
)
self
.
evictable_size_
+=
len
(
new_node
.
value
)
self
.
evictable_size_
+=
len
(
new_node
.
value
)
else
:
else
:
self
.
_inc_hit_count
(
new_node
,
chunked
)
self
.
_inc_hit_count
(
new_node
,
chunked
)
...
...
python/sglang/srt/mem_cache/memory_pool_host.py
View file @
3d40794f
...
@@ -31,27 +31,13 @@ if not (_is_npu or _is_xpu):
...
@@ -31,27 +31,13 @@ if not (_is_npu or _is_xpu):
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
class
MemoryStateInt
(
IntEnum
):
def
synchronized
(
func
):
IDLE
=
0
@
wraps
(
func
)
RESERVED
=
1
def
wrapper
(
self
,
*
args
,
**
kwargs
):
PROTECTED
=
2
with
self
.
lock
:
SYNCED
=
3
return
func
(
self
,
*
args
,
**
kwargs
)
BACKUP
=
4
def
synchronized
(
debug_only
=
False
):
def
_decorator
(
func
):
@
wraps
(
func
)
def
wrapper
(
self
,
*
args
,
**
kwargs
):
if
(
not
debug_only
)
or
self
.
debug
:
with
self
.
lock
:
return
func
(
self
,
*
args
,
**
kwargs
)
else
:
return
True
return
wrapper
return
_decorato
r
return
wrappe
r
class
HostKVCache
(
abc
.
ABC
):
class
HostKVCache
(
abc
.
ABC
):
...
@@ -110,7 +96,6 @@ class HostKVCache(abc.ABC):
...
@@ -110,7 +96,6 @@ class HostKVCache(abc.ABC):
# A lock for synchronized operations on memory allocation and state transitions.
# A lock for synchronized operations on memory allocation and state transitions.
self
.
lock
=
threading
.
RLock
()
self
.
lock
=
threading
.
RLock
()
self
.
debug
=
logger
.
isEnabledFor
(
logging
.
DEBUG
)
self
.
clear
()
self
.
clear
()
@
abc
.
abstractmethod
@
abc
.
abstractmethod
...
@@ -161,7 +146,7 @@ class HostKVCache(abc.ABC):
...
@@ -161,7 +146,7 @@ class HostKVCache(abc.ABC):
"""
"""
raise
NotImplementedError
()
raise
NotImplementedError
()
@
synchronized
()
@
synchronized
def
clear
(
self
):
def
clear
(
self
):
# Initialize memory states and tracking structures.
# Initialize memory states and tracking structures.
self
.
mem_state
=
torch
.
zeros
(
self
.
mem_state
=
torch
.
zeros
(
...
@@ -172,7 +157,7 @@ class HostKVCache(abc.ABC):
...
@@ -172,7 +157,7 @@ class HostKVCache(abc.ABC):
def
available_size
(
self
):
def
available_size
(
self
):
return
len
(
self
.
free_slots
)
return
len
(
self
.
free_slots
)
@
synchronized
()
@
synchronized
def
alloc
(
self
,
need_size
:
int
)
->
Optional
[
torch
.
Tensor
]:
def
alloc
(
self
,
need_size
:
int
)
->
Optional
[
torch
.
Tensor
]:
assert
(
assert
(
need_size
%
self
.
page_size
==
0
need_size
%
self
.
page_size
==
0
...
@@ -183,92 +168,13 @@ class HostKVCache(abc.ABC):
...
@@ -183,92 +168,13 @@ class HostKVCache(abc.ABC):
select_index
=
self
.
free_slots
[:
need_size
]
select_index
=
self
.
free_slots
[:
need_size
]
self
.
free_slots
=
self
.
free_slots
[
need_size
:]
self
.
free_slots
=
self
.
free_slots
[
need_size
:]
if
self
.
debug
:
self
.
mem_state
[
select_index
]
=
MemoryStateInt
.
RESERVED
return
select_index
return
select_index
@
synchronized
()
@
synchronized
def
free
(
self
,
indices
:
torch
.
Tensor
)
->
int
:
def
free
(
self
,
indices
:
torch
.
Tensor
)
->
int
:
self
.
free_slots
=
torch
.
cat
([
self
.
free_slots
,
indices
])
self
.
free_slots
=
torch
.
cat
([
self
.
free_slots
,
indices
])
if
self
.
debug
:
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
IDLE
return
len
(
indices
)
return
len
(
indices
)
@
synchronized
(
debug_only
=
True
)
def
get_state
(
self
,
indices
:
torch
.
Tensor
)
->
MemoryStateInt
:
assert
len
(
indices
)
>
0
,
"The indices should not be empty"
states
=
self
.
mem_state
[
indices
]
assert
(
states
==
states
[
0
]
).
all
(),
"The memory slots should have the same state {}"
.
format
(
states
)
return
MemoryStateInt
(
states
[
0
].
item
())
@
synchronized
(
debug_only
=
True
)
def
is_reserved
(
self
,
indices
:
torch
.
Tensor
)
->
bool
:
return
self
.
get_state
(
indices
)
==
MemoryStateInt
.
RESERVED
@
synchronized
(
debug_only
=
True
)
def
is_protected
(
self
,
indices
:
torch
.
Tensor
)
->
bool
:
return
self
.
get_state
(
indices
)
==
MemoryStateInt
.
PROTECTED
@
synchronized
(
debug_only
=
True
)
def
is_synced
(
self
,
indices
:
torch
.
Tensor
)
->
bool
:
return
self
.
get_state
(
indices
)
==
MemoryStateInt
.
SYNCED
@
synchronized
(
debug_only
=
True
)
def
is_backup
(
self
,
indices
:
torch
.
Tensor
)
->
bool
:
return
self
.
get_state
(
indices
)
==
MemoryStateInt
.
BACKUP
@
synchronized
(
debug_only
=
True
)
def
update_backup
(
self
,
indices
:
torch
.
Tensor
):
if
not
self
.
is_synced
(
indices
):
raise
ValueError
(
f
"The host memory slots should be in SYNCED state before turning into BACKUP. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
BACKUP
@
synchronized
(
debug_only
=
True
)
def
update_prefetch
(
self
,
indices
:
torch
.
Tensor
):
if
not
self
.
is_reserved
(
indices
):
raise
ValueError
(
f
"The host memory slots should be in RESERVED state before turning into BACKUP. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
BACKUP
@
synchronized
(
debug_only
=
True
)
def
update_synced
(
self
,
indices
:
torch
.
Tensor
):
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
SYNCED
@
synchronized
(
debug_only
=
True
)
def
protect_write
(
self
,
indices
:
torch
.
Tensor
):
if
not
self
.
is_reserved
(
indices
):
raise
ValueError
(
f
"The host memory slots should be RESERVED before write operations. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
PROTECTED
@
synchronized
(
debug_only
=
True
)
def
protect_load
(
self
,
indices
:
torch
.
Tensor
):
if
not
self
.
is_backup
(
indices
):
raise
ValueError
(
f
"The host memory slots should be in BACKUP state before load operations. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
PROTECTED
@
synchronized
(
debug_only
=
True
)
def
complete_io
(
self
,
indices
:
torch
.
Tensor
):
if
not
self
.
is_protected
(
indices
):
raise
ValueError
(
f
"The host memory slots should be PROTECTED during I/O operations. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
SYNCED
class
MHATokenToKVPoolHost
(
HostKVCache
):
class
MHATokenToKVPoolHost
(
HostKVCache
):
device_pool
:
MHATokenToKVPool
device_pool
:
MHATokenToKVPool
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment