Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a023856b
Unverified
Commit
a023856b
authored
Jun 14, 2025
by
Lianmin Zheng
Committed by
GitHub
Jun 14, 2025
Browse files
Move host memory pools into a separate file (#7200)
parent
db0cc57e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
388 additions
and
381 deletions
+388
-381
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+2
-1
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+4
-2
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+1
-377
python/sglang/srt/mem_cache/memory_pool_host.py
python/sglang/srt/mem_cache/memory_pool_host.py
+380
-0
test/srt/test_hicache_page.py
test/srt/test_hicache_page.py
+1
-1
No files found.
python/sglang/srt/managers/cache_controller.py
View file @
a023856b
...
@@ -22,7 +22,8 @@ from typing import List, Optional
...
@@ -22,7 +22,8 @@ from typing import List, Optional
import
torch
import
torch
from
sglang.srt.mem_cache.memory_pool
import
HostKVCache
,
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.memory_pool_host
import
HostKVCache
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
a023856b
...
@@ -9,12 +9,14 @@ import torch
...
@@ -9,12 +9,14 @@ import torch
from
sglang.srt.managers.cache_controller
import
HiCacheController
from
sglang.srt.managers.cache_controller
import
HiCacheController
from
sglang.srt.mem_cache.memory_pool
import
(
from
sglang.srt.mem_cache.memory_pool
import
(
MHATokenToKVPool
,
MHATokenToKVPool
,
MHATokenToKVPoolHost
,
MLATokenToKVPool
,
MLATokenToKVPool
,
MLATokenToKVPoolHost
,
ReqToTokenPool
,
ReqToTokenPool
,
TokenToKVPoolAllocator
,
TokenToKVPoolAllocator
,
)
)
from
sglang.srt.mem_cache.memory_pool_host
import
(
MHATokenToKVPoolHost
,
MLATokenToKVPoolHost
,
)
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
,
TreeNode
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
,
TreeNode
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
a023856b
...
@@ -26,24 +26,15 @@ KVCache actually holds the physical kv cache.
...
@@ -26,24 +26,15 @@ KVCache actually holds the physical kv cache.
import
abc
import
abc
import
logging
import
logging
import
threading
from
enum
import
IntEnum
from
functools
import
wraps
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
psutil
import
torch
import
torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
debug_timing
,
get_compiler_backend
,
is_cuda
debug_timing
,
get_compiler_backend
,
is_cuda
,
next_power_of_2
,
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -772,370 +763,3 @@ class DoubleSparseTokenToKVPool(KVCache):
...
@@ -772,370 +763,3 @@ class DoubleSparseTokenToKVPool(KVCache):
def
transfer_per_layer
(
self
,
indices
,
flat_data
,
layer_id
):
def
transfer_per_layer
(
self
,
indices
,
flat_data
,
layer_id
):
pass
pass
class MemoryStateInt(IntEnum):
    """Integer tags stored per host memory slot in ``HostKVCache.mem_state``."""

    # Written by ``HostKVCache.free``; also the zero fill produced by ``clear``.
    IDLE = 0
    # Written by ``HostKVCache.alloc``.
    RESERVED = 1
    # Written by ``protect_write`` and ``protect_load``.
    PROTECTED = 2
    # Written by ``update_synced`` and ``complete_io``.
    SYNCED = 3
    # Written by ``update_backup``.
    BACKUP = 4
def synchronized(debug_only=False):
    """Build a decorator that runs a method while holding ``self.lock``.

    The instance the decorated method is bound to must expose ``lock`` (a
    context manager) and, when ``debug_only`` is true, a boolean ``debug``.

    Args:
        debug_only: When ``False`` the method always runs under the lock.
            When ``True`` the method runs (under the lock) only if
            ``self.debug`` is truthy; otherwise the method body is skipped
            and the wrapper returns ``True``.

    Returns:
        A decorator that preserves the wrapped method's metadata.
    """

    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if debug_only and not self.debug:
                # Debug-only bookkeeping is skipped outside debug mode. Returning
                # True keeps predicate-style callers (``if not self.is_x(...)``)
                # on their success path.
                return True
            # Previously the call was returned *before* the ``with`` statement,
            # which made the lock acquisition unreachable.
            with self.lock:
                return func(self, *args, **kwargs)

        return wrapper

    return _decorator
class HostKVCache(abc.ABC):
    """Abstract base for a host-memory KV cache that backs a device ``KVCache``.

    The base class sizes the pool in tokens, verifies that enough host memory
    is available, asks the subclass to allocate ``kv_buffer``, and manages the
    tensor of free slot indices. When the module logger is enabled for DEBUG
    it additionally records a ``MemoryStateInt`` per slot in ``mem_state``.

    Subclasses supply the buffer layout and the data-movement primitives
    through the abstract methods below.
    """

    def __init__(
        self,
        device_pool: KVCache,
        host_to_device_ratio: float,
        host_size: int,
        pin_memory: bool,
        device: str,
        page_size: int,
    ):
        """Size, validate and allocate the host pool.

        Args:
            device_pool: Device-side pool; its ``store_dtype``, ``size``,
                ``start_layer`` and ``end_layer`` are read here.
            host_to_device_ratio: Multiplier applied to ``device_pool.size``
                when ``host_size`` is not positive.
            host_size: If positive, the pool budget in units of 1e9 bytes;
                takes precedence over ``host_to_device_ratio``.
            pin_memory: Stored for the subclass's ``init_kv_buffer``.
            device: Device on which ``mem_state`` is created; also stored for
                the subclass.
            page_size: The token capacity is rounded down to a multiple of
                this value.

        Raises:
            AssertionError: If the aligned host capacity does not exceed
                ``device_pool.size``.
            ValueError: If the request would leave less than 10 GiB of the
                currently available host memory.
        """
        self.device_pool = device_pool
        self.dtype = device_pool.store_dtype
        self.pin_memory = pin_memory
        self.device = device
        self.page_size = page_size
        # The subclass hook reads ``device_pool`` and ``dtype``, so it must run
        # after the assignments above.
        self.size_per_token = self.get_size_per_token()

        token_capacity = (
            int(host_size * 1e9 // self.size_per_token)
            if host_size > 0
            else int(device_pool.size * host_to_device_ratio)
        )
        # Align the host memory pool size to the page size
        self.size = token_capacity - (token_capacity % self.page_size)
        self.start_layer = device_pool.start_layer
        self.end_layer = device_pool.end_layer

        assert (
            self.size > device_pool.size
        ), "The host memory should be larger than the device memory with the current protocol"

        # Verify there is enough available host memory.
        host_mem = psutil.virtual_memory()
        requested_bytes = self.size * self.size_per_token
        # preserve at least 10GB for other usage
        reserved_bytes = 10 * (1024**3)
        if requested_bytes > host_mem.available - reserved_bytes:
            raise ValueError(
                f"Not enough host memory available. Requesting "
                f"{requested_bytes / 1e9:.2f} GB but only have "
                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
                f"size of the hierarchical cache."
            )
        logger.info(
            f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
        )

        self.kv_buffer = self.init_kv_buffer()

        # A lock for synchronized operations on memory allocation and state transitions.
        self.lock = threading.RLock()
        # Per-slot state tracking is only active when DEBUG logging is enabled.
        self.debug = logger.isEnabledFor(logging.DEBUG)
        self.clear()

    @abc.abstractmethod
    def get_size_per_token(self):
        """Return the number of bytes one token occupies across all layers."""
        raise NotImplementedError()

    @abc.abstractmethod
    def init_kv_buffer(self):
        """Allocate and return the host buffer stored as ``kv_buffer``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def transfer(self, indices, flat_data):
        """Store ``flat_data`` into the host buffer at ``indices``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_flat_data(self, indices):
        """Return the host data at ``indices`` for all layers."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_flat_data_by_layer(self, indices, layer_id):
        """Return the host data at ``indices`` for a single layer."""
        raise NotImplementedError()

    @abc.abstractmethod
    def assign_flat_data(self, indices, flat_data):
        """Write ``flat_data`` into the host buffer at ``indices``."""
        raise NotImplementedError()

    @synchronized()
    def clear(self):
        """Reset every slot to state 0 and mark all ``size`` slots as free."""
        # Initialize memory states and tracking structures.
        self.mem_state = torch.zeros(
            (self.size,), dtype=torch.uint8, device=self.device
        )
        self.free_slots = torch.arange(self.size, dtype=torch.int64)

    def available_size(self):
        """Return the number of currently free slots."""
        return len(self.free_slots)

    @synchronized()
    def alloc(self, need_size: int) -> torch.Tensor:
        """Take ``need_size`` slots from the front of the free list.

        Returns:
            The selected slot indices, or ``None`` when fewer than
            ``need_size`` slots are free.
        """
        if need_size > self.available_size():
            return None

        select_index, self.free_slots = (
            self.free_slots[:need_size],
            self.free_slots[need_size:],
        )

        if self.debug:
            self.mem_state[select_index] = MemoryStateInt.RESERVED

        return select_index

    @synchronized()
    def free(self, indices: torch.Tensor) -> int:
        """Append ``indices`` to the free list and return how many were freed."""
        self.free_slots = torch.cat([self.free_slots, indices])
        if self.debug:
            self.mem_state[indices] = MemoryStateInt.IDLE
        return len(indices)

    # NOTE: every method below is wrapped with ``synchronized(debug_only=True)``;
    # when ``self.debug`` is false the wrapper returns ``True`` without running
    # the method body.

    @synchronized(debug_only=True)
    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
        """Return the single state shared by all slots in ``indices``.

        Raises:
            AssertionError: If ``indices`` is empty or the slots disagree.
        """
        assert len(indices) > 0, "The indices should not be empty"
        states = self.mem_state[indices]
        assert (
            states == states[0]
        ).all(), "The memory slots should have the same state {}".format(states)
        return MemoryStateInt(states[0].item())

    @synchronized(debug_only=True)
    def is_reserved(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the RESERVED state."""
        return self.get_state(indices) == MemoryStateInt.RESERVED

    @synchronized(debug_only=True)
    def is_protected(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the PROTECTED state."""
        return self.get_state(indices) == MemoryStateInt.PROTECTED

    @synchronized(debug_only=True)
    def is_synced(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the SYNCED state."""
        return self.get_state(indices) == MemoryStateInt.SYNCED

    @synchronized(debug_only=True)
    def is_backup(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the BACKUP state."""
        return self.get_state(indices) == MemoryStateInt.BACKUP

    @synchronized(debug_only=True)
    def update_backup(self, indices: torch.Tensor):
        """Move SYNCED slots to BACKUP; raise ``ValueError`` if not SYNCED."""
        if not self.is_synced(indices):
            raise ValueError(
                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.BACKUP

    @synchronized(debug_only=True)
    def update_synced(self, indices: torch.Tensor):
        """Mark the slots as SYNCED without checking their previous state."""
        self.mem_state[indices] = MemoryStateInt.SYNCED

    @synchronized(debug_only=True)
    def protect_write(self, indices: torch.Tensor):
        """Move RESERVED slots to PROTECTED; raise ``ValueError`` otherwise."""
        if not self.is_reserved(indices):
            raise ValueError(
                f"The host memory slots should be RESERVED before write operations. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized(debug_only=True)
    def protect_load(self, indices: torch.Tensor):
        """Move BACKUP slots to PROTECTED; raise ``ValueError`` otherwise."""
        if not self.is_backup(indices):
            raise ValueError(
                f"The host memory slots should be in BACKUP state before load operations. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized(debug_only=True)
    def complete_io(self, indices: torch.Tensor):
        """Move PROTECTED slots to SYNCED; raise ``ValueError`` otherwise."""
        if not self.is_protected(indices):
            raise ValueError(
                f"The host memory slots should be PROTECTED during I/O operations. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.SYNCED
class MHATokenToKVPoolHost(HostKVCache):
    """Host KV pool paired with an ``MHATokenToKVPool`` device pool.

    ``kv_buffer`` has shape ``(2, layer_num, size, head_num, head_dim)``.
    Index 0 of the first axis is copied from/to ``device_pool.k_buffer`` and
    index 1 from/to ``device_pool.v_buffer``.
    """

    device_pool: MHATokenToKVPool

    def __init__(
        self,
        device_pool: MHATokenToKVPool,
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        pin_memory: bool = True,
        device: str = "cpu",
    ):
        """Forward the arguments to ``HostKVCache.__init__`` (note its different order)."""
        super().__init__(
            device_pool=device_pool,
            host_to_device_ratio=host_to_device_ratio,
            host_size=host_size,
            pin_memory=pin_memory,
            device=device,
            page_size=page_size,
        )

    def get_size_per_token(self):
        """Cache the head/layer geometry and return the bytes used per token."""
        pool = self.device_pool
        self.head_num = pool.head_num
        self.head_dim = pool.head_dim
        self.layer_num = pool.layer_num

        # Trailing factor 2 matches the leading axis of size 2 in ``kv_buffer``.
        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2

    def init_kv_buffer(self):
        """Allocate the uninitialized host buffer."""
        shape = (2, self.layer_num, self.size, self.head_num, self.head_dim)
        return torch.empty(
            shape,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        """Copy ``flat_data`` to the host device and store it at ``indices``."""
        # backup prepared data from device to host
        host_copy = flat_data.to(device=self.device, non_blocking=False)
        self.kv_buffer[:, :, indices] = host_copy

    def get_flat_data(self, indices):
        """Return the data at ``indices`` for all layers."""
        return self.kv_buffer[:, :, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        """Return the data at ``indices`` for ``layer_id`` (offset by ``start_layer``)."""
        layer = layer_id - self.start_layer
        return self.kv_buffer[:, layer, indices]

    def assign_flat_data(self, indices, flat_data):
        """Write ``flat_data`` at ``indices`` for all layers."""
        self.kv_buffer[:, :, indices] = flat_data

    def write_page_all_layers(self, host_indices, device_indices, device_pool):
        """Copy whole pages from ``device_pool`` into the host buffer, all layers.

        Every ``page_size``-th entry of the index tensors is used as the start
        of a ``page_size``-long contiguous range.
        NOTE(review): assumes both index tensors are laid out page by page
        with contiguous slots inside a page -- confirm against callers.
        """
        page = self.page_size
        page_starts_cpu = device_indices[::page].cpu()
        for page_no, d_start in enumerate(page_starts_cpu):
            h_start = host_indices[page_no * page]
            for layer in range(self.layer_num):
                self.kv_buffer[0, layer, h_start : h_start + page].copy_(
                    device_pool.k_buffer[layer][d_start : d_start + page],
                    non_blocking=True,
                )
                self.kv_buffer[1, layer, h_start : h_start + page].copy_(
                    device_pool.v_buffer[layer][d_start : d_start + page],
                    non_blocking=True,
                )

    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
        """Copy whole pages of one layer from the host buffer into ``device_pool``.

        Uses the same page-start convention as ``write_page_all_layers``.
        """
        page = self.page_size
        layer = layer_id - self.start_layer
        page_starts_cpu = device_indices[::page].cpu()
        for page_no, d_start in enumerate(page_starts_cpu):
            h_start = host_indices[page_no * page]
            device_pool.k_buffer[layer][d_start : d_start + page].copy_(
                self.kv_buffer[0, layer, h_start : h_start + page],
                non_blocking=True,
            )
            device_pool.v_buffer[layer][d_start : d_start + page].copy_(
                self.kv_buffer[1, layer, h_start : h_start + page],
                non_blocking=True,
            )
class MLATokenToKVPoolHost(HostKVCache):
    """Host KV pool paired with an ``MLATokenToKVPool`` device pool.

    ``kv_buffer`` has shape
    ``(layer_num, size, 1, kv_lora_rank + qk_rope_head_dim)`` and is copied
    from/to ``device_pool.kv_buffer``.
    """

    device_pool: MLATokenToKVPool

    def __init__(
        self,
        device_pool: MLATokenToKVPool,
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        pin_memory: bool = True,
        device: str = "cpu",
    ):
        """Forward the arguments to ``HostKVCache.__init__`` (note its different order)."""
        super().__init__(
            device_pool=device_pool,
            host_to_device_ratio=host_to_device_ratio,
            host_size=host_size,
            pin_memory=pin_memory,
            device=device,
            page_size=page_size,
        )

    def get_size_per_token(self):
        """Cache the latent geometry and return the bytes used per token."""
        pool = self.device_pool
        self.kv_lora_rank = pool.kv_lora_rank
        self.qk_rope_head_dim = pool.qk_rope_head_dim
        self.layer_num = pool.layer_num

        token_width = self.kv_lora_rank + self.qk_rope_head_dim
        # The factor 1 matches the singleton third axis of ``kv_buffer``.
        return token_width * 1 * self.dtype.itemsize * self.layer_num

    def init_kv_buffer(self):
        """Allocate the uninitialized host buffer."""
        shape = (
            self.layer_num,
            self.size,
            1,
            self.kv_lora_rank + self.qk_rope_head_dim,
        )
        return torch.empty(
            shape,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        """Copy ``flat_data`` to the host device and store it at ``indices``."""
        # backup prepared data from device to host
        host_copy = flat_data.to(device=self.device, non_blocking=False)
        self.kv_buffer[:, indices] = host_copy

    def get_flat_data(self, indices):
        """Return the data at ``indices`` for all layers."""
        return self.kv_buffer[:, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        """Return the data at ``indices`` for ``layer_id`` (offset by ``start_layer``)."""
        layer = layer_id - self.start_layer
        return self.kv_buffer[layer, indices]

    def assign_flat_data(self, indices, flat_data):
        """Write ``flat_data`` at ``indices`` for all layers."""
        self.kv_buffer[:, indices] = flat_data

    def write_page_all_layers(self, host_indices, device_indices, device_pool):
        """Copy whole pages from ``device_pool`` into the host buffer, all layers.

        Every ``page_size``-th entry of the index tensors is used as the start
        of a ``page_size``-long contiguous range.
        NOTE(review): assumes both index tensors are laid out page by page
        with contiguous slots inside a page -- confirm against callers.
        """
        page = self.page_size
        page_starts_cpu = device_indices[::page].cpu()
        for page_no, d_start in enumerate(page_starts_cpu):
            h_start = host_indices[page_no * page]
            for layer in range(self.layer_num):
                self.kv_buffer[layer, h_start : h_start + page].copy_(
                    device_pool.kv_buffer[layer][d_start : d_start + page],
                    non_blocking=True,
                )

    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
        """Copy whole pages of one layer from the host buffer into ``device_pool``.

        Uses the same page-start convention as ``write_page_all_layers``.
        """
        page = self.page_size
        layer = layer_id - self.start_layer
        page_starts_cpu = device_indices[::page].cpu()
        for page_no, d_start in enumerate(page_starts_cpu):
            h_start = host_indices[page_no * page]
            device_pool.kv_buffer[layer][d_start : d_start + page].copy_(
                self.kv_buffer[layer, h_start : h_start + page],
                non_blocking=True,
            )
python/sglang/srt/mem_cache/memory_pool_host.py
0 → 100644
View file @
a023856b
import
abc
import
logging
import
threading
from
enum
import
IntEnum
from
functools
import
wraps
import
psutil
import
torch
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
MHATokenToKVPool
,
MLATokenToKVPool
from
sglang.srt.utils
import
debug_timing
logger
=
logging
.
getLogger
(
__name__
)
class MemoryStateInt(IntEnum):
    """Integer tags stored per host memory slot in ``HostKVCache.mem_state``."""

    # Written by ``HostKVCache.free``; also the zero fill produced by ``clear``.
    IDLE = 0
    # Written by ``HostKVCache.alloc``.
    RESERVED = 1
    # Written by ``protect_write`` and ``protect_load``.
    PROTECTED = 2
    # Written by ``update_synced`` and ``complete_io``.
    SYNCED = 3
    # Written by ``update_backup``.
    BACKUP = 4
def synchronized(debug_only=False):
    """Build a decorator that runs a method while holding ``self.lock``.

    The instance the decorated method is bound to must expose ``lock`` (a
    context manager) and, when ``debug_only`` is true, a boolean ``debug``.

    Args:
        debug_only: When ``False`` the method always runs under the lock.
            When ``True`` the method runs (under the lock) only if
            ``self.debug`` is truthy; otherwise the method body is skipped
            and the wrapper returns ``True``.

    Returns:
        A decorator that preserves the wrapped method's metadata.
    """

    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if debug_only and not self.debug:
                # Debug-only bookkeeping is skipped outside debug mode. Returning
                # True keeps predicate-style callers (``if not self.is_x(...)``)
                # on their success path.
                return True
            # Previously the call was returned *before* the ``with`` statement,
            # which made the lock acquisition unreachable.
            with self.lock:
                return func(self, *args, **kwargs)

        return wrapper

    return _decorator
class HostKVCache(abc.ABC):
    """Abstract base for a host-memory KV cache that backs a device ``KVCache``.

    The base class sizes the pool in tokens, verifies that enough host memory
    is available, asks the subclass to allocate ``kv_buffer``, and manages the
    tensor of free slot indices. When the module logger is enabled for DEBUG
    it additionally records a ``MemoryStateInt`` per slot in ``mem_state``.

    Subclasses supply the buffer layout and the data-movement primitives
    through the abstract methods below.
    """

    def __init__(
        self,
        device_pool: KVCache,
        host_to_device_ratio: float,
        host_size: int,
        pin_memory: bool,
        device: str,
        page_size: int,
    ):
        """Size, validate and allocate the host pool.

        Args:
            device_pool: Device-side pool; its ``store_dtype``, ``size``,
                ``start_layer`` and ``end_layer`` are read here.
            host_to_device_ratio: Multiplier applied to ``device_pool.size``
                when ``host_size`` is not positive.
            host_size: If positive, the pool budget in units of 1e9 bytes;
                takes precedence over ``host_to_device_ratio``.
            pin_memory: Stored for the subclass's ``init_kv_buffer``.
            device: Device on which ``mem_state`` is created; also stored for
                the subclass.
            page_size: The token capacity is rounded down to a multiple of
                this value.

        Raises:
            AssertionError: If the aligned host capacity does not exceed
                ``device_pool.size``.
            ValueError: If the request would leave less than 10 GiB of the
                currently available host memory.
        """
        self.device_pool = device_pool
        self.dtype = device_pool.store_dtype
        self.pin_memory = pin_memory
        self.device = device
        self.page_size = page_size
        # The subclass hook reads ``device_pool`` and ``dtype``, so it must run
        # after the assignments above.
        self.size_per_token = self.get_size_per_token()

        token_capacity = (
            int(host_size * 1e9 // self.size_per_token)
            if host_size > 0
            else int(device_pool.size * host_to_device_ratio)
        )
        # Align the host memory pool size to the page size
        self.size = token_capacity - (token_capacity % self.page_size)
        self.start_layer = device_pool.start_layer
        self.end_layer = device_pool.end_layer

        assert (
            self.size > device_pool.size
        ), "The host memory should be larger than the device memory with the current protocol"

        # Verify there is enough available host memory.
        host_mem = psutil.virtual_memory()
        requested_bytes = self.size * self.size_per_token
        # preserve at least 10GB for other usage
        reserved_bytes = 10 * (1024**3)
        if requested_bytes > host_mem.available - reserved_bytes:
            raise ValueError(
                f"Not enough host memory available. Requesting "
                f"{requested_bytes / 1e9:.2f} GB but only have "
                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
                f"size of the hierarchical cache."
            )
        logger.info(
            f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
        )

        self.kv_buffer = self.init_kv_buffer()

        # A lock for synchronized operations on memory allocation and state transitions.
        self.lock = threading.RLock()
        # Per-slot state tracking is only active when DEBUG logging is enabled.
        self.debug = logger.isEnabledFor(logging.DEBUG)
        self.clear()

    @abc.abstractmethod
    def get_size_per_token(self):
        """Return the number of bytes one token occupies across all layers."""
        raise NotImplementedError()

    @abc.abstractmethod
    def init_kv_buffer(self):
        """Allocate and return the host buffer stored as ``kv_buffer``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def transfer(self, indices, flat_data):
        """Store ``flat_data`` into the host buffer at ``indices``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_flat_data(self, indices):
        """Return the host data at ``indices`` for all layers."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_flat_data_by_layer(self, indices, layer_id):
        """Return the host data at ``indices`` for a single layer."""
        raise NotImplementedError()

    @abc.abstractmethod
    def assign_flat_data(self, indices, flat_data):
        """Write ``flat_data`` into the host buffer at ``indices``."""
        raise NotImplementedError()

    @synchronized()
    def clear(self):
        """Reset every slot to state 0 and mark all ``size`` slots as free."""
        # Initialize memory states and tracking structures.
        self.mem_state = torch.zeros(
            (self.size,), dtype=torch.uint8, device=self.device
        )
        self.free_slots = torch.arange(self.size, dtype=torch.int64)

    def available_size(self):
        """Return the number of currently free slots."""
        return len(self.free_slots)

    @synchronized()
    def alloc(self, need_size: int) -> torch.Tensor:
        """Take ``need_size`` slots from the front of the free list.

        Returns:
            The selected slot indices, or ``None`` when fewer than
            ``need_size`` slots are free.
        """
        if need_size > self.available_size():
            return None

        select_index, self.free_slots = (
            self.free_slots[:need_size],
            self.free_slots[need_size:],
        )

        if self.debug:
            self.mem_state[select_index] = MemoryStateInt.RESERVED

        return select_index

    @synchronized()
    def free(self, indices: torch.Tensor) -> int:
        """Append ``indices`` to the free list and return how many were freed."""
        self.free_slots = torch.cat([self.free_slots, indices])
        if self.debug:
            self.mem_state[indices] = MemoryStateInt.IDLE
        return len(indices)

    # NOTE: every method below is wrapped with ``synchronized(debug_only=True)``;
    # when ``self.debug`` is false the wrapper returns ``True`` without running
    # the method body.

    @synchronized(debug_only=True)
    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
        """Return the single state shared by all slots in ``indices``.

        Raises:
            AssertionError: If ``indices`` is empty or the slots disagree.
        """
        assert len(indices) > 0, "The indices should not be empty"
        states = self.mem_state[indices]
        assert (
            states == states[0]
        ).all(), "The memory slots should have the same state {}".format(states)
        return MemoryStateInt(states[0].item())

    @synchronized(debug_only=True)
    def is_reserved(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the RESERVED state."""
        return self.get_state(indices) == MemoryStateInt.RESERVED

    @synchronized(debug_only=True)
    def is_protected(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the PROTECTED state."""
        return self.get_state(indices) == MemoryStateInt.PROTECTED

    @synchronized(debug_only=True)
    def is_synced(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the SYNCED state."""
        return self.get_state(indices) == MemoryStateInt.SYNCED

    @synchronized(debug_only=True)
    def is_backup(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the BACKUP state."""
        return self.get_state(indices) == MemoryStateInt.BACKUP

    @synchronized(debug_only=True)
    def update_backup(self, indices: torch.Tensor):
        """Move SYNCED slots to BACKUP; raise ``ValueError`` if not SYNCED."""
        if not self.is_synced(indices):
            raise ValueError(
                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.BACKUP

    @synchronized(debug_only=True)
    def update_synced(self, indices: torch.Tensor):
        """Mark the slots as SYNCED without checking their previous state."""
        self.mem_state[indices] = MemoryStateInt.SYNCED

    @synchronized(debug_only=True)
    def protect_write(self, indices: torch.Tensor):
        """Move RESERVED slots to PROTECTED; raise ``ValueError`` otherwise."""
        if not self.is_reserved(indices):
            raise ValueError(
                f"The host memory slots should be RESERVED before write operations. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized(debug_only=True)
    def protect_load(self, indices: torch.Tensor):
        """Move BACKUP slots to PROTECTED; raise ``ValueError`` otherwise."""
        if not self.is_backup(indices):
            raise ValueError(
                f"The host memory slots should be in BACKUP state before load operations. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized(debug_only=True)
    def complete_io(self, indices: torch.Tensor):
        """Move PROTECTED slots to SYNCED; raise ``ValueError`` otherwise."""
        if not self.is_protected(indices):
            raise ValueError(
                f"The host memory slots should be PROTECTED during I/O operations. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.SYNCED
class MHATokenToKVPoolHost(HostKVCache):
    """Host KV pool paired with an ``MHATokenToKVPool`` device pool.

    ``kv_buffer`` has shape ``(2, layer_num, size, head_num, head_dim)``.
    Index 0 of the first axis is copied from/to ``device_pool.k_buffer`` and
    index 1 from/to ``device_pool.v_buffer``.
    """

    device_pool: MHATokenToKVPool

    def __init__(
        self,
        device_pool: MHATokenToKVPool,
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        pin_memory: bool = True,
        device: str = "cpu",
    ):
        """Forward the arguments to ``HostKVCache.__init__`` (note its different order)."""
        super().__init__(
            device_pool=device_pool,
            host_to_device_ratio=host_to_device_ratio,
            host_size=host_size,
            pin_memory=pin_memory,
            device=device,
            page_size=page_size,
        )

    def get_size_per_token(self):
        """Cache the head/layer geometry and return the bytes used per token."""
        pool = self.device_pool
        self.head_num = pool.head_num
        self.head_dim = pool.head_dim
        self.layer_num = pool.layer_num

        # Trailing factor 2 matches the leading axis of size 2 in ``kv_buffer``.
        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2

    def init_kv_buffer(self):
        """Allocate the uninitialized host buffer."""
        shape = (2, self.layer_num, self.size, self.head_num, self.head_dim)
        return torch.empty(
            shape,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        """Copy ``flat_data`` to the host device and store it at ``indices``."""
        # backup prepared data from device to host
        host_copy = flat_data.to(device=self.device, non_blocking=False)
        self.kv_buffer[:, :, indices] = host_copy

    def get_flat_data(self, indices):
        """Return the data at ``indices`` for all layers."""
        return self.kv_buffer[:, :, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        """Return the data at ``indices`` for ``layer_id`` (offset by ``start_layer``)."""
        layer = layer_id - self.start_layer
        return self.kv_buffer[:, layer, indices]

    def assign_flat_data(self, indices, flat_data):
        """Write ``flat_data`` at ``indices`` for all layers."""
        self.kv_buffer[:, :, indices] = flat_data

    def write_page_all_layers(self, host_indices, device_indices, device_pool):
        """Copy whole pages from ``device_pool`` into the host buffer, all layers.

        Every ``page_size``-th entry of the index tensors is used as the start
        of a ``page_size``-long contiguous range.
        NOTE(review): assumes both index tensors are laid out page by page
        with contiguous slots inside a page -- confirm against callers.
        """
        page = self.page_size
        page_starts_cpu = device_indices[::page].cpu()
        for page_no, d_start in enumerate(page_starts_cpu):
            h_start = host_indices[page_no * page]
            for layer in range(self.layer_num):
                self.kv_buffer[0, layer, h_start : h_start + page].copy_(
                    device_pool.k_buffer[layer][d_start : d_start + page],
                    non_blocking=True,
                )
                self.kv_buffer[1, layer, h_start : h_start + page].copy_(
                    device_pool.v_buffer[layer][d_start : d_start + page],
                    non_blocking=True,
                )

    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
        """Copy whole pages of one layer from the host buffer into ``device_pool``.

        Uses the same page-start convention as ``write_page_all_layers``.
        """
        page = self.page_size
        layer = layer_id - self.start_layer
        page_starts_cpu = device_indices[::page].cpu()
        for page_no, d_start in enumerate(page_starts_cpu):
            h_start = host_indices[page_no * page]
            device_pool.k_buffer[layer][d_start : d_start + page].copy_(
                self.kv_buffer[0, layer, h_start : h_start + page],
                non_blocking=True,
            )
            device_pool.v_buffer[layer][d_start : d_start + page].copy_(
                self.kv_buffer[1, layer, h_start : h_start + page],
                non_blocking=True,
            )
class MLATokenToKVPoolHost(HostKVCache):
    """Host KV pool paired with an ``MLATokenToKVPool`` device pool.

    ``kv_buffer`` has shape
    ``(layer_num, size, 1, kv_lora_rank + qk_rope_head_dim)`` and is copied
    from/to ``device_pool.kv_buffer``.
    """

    device_pool: MLATokenToKVPool

    def __init__(
        self,
        device_pool: MLATokenToKVPool,
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        pin_memory: bool = True,
        device: str = "cpu",
    ):
        """Forward the arguments to ``HostKVCache.__init__`` (note its different order)."""
        super().__init__(
            device_pool=device_pool,
            host_to_device_ratio=host_to_device_ratio,
            host_size=host_size,
            pin_memory=pin_memory,
            device=device,
            page_size=page_size,
        )

    def get_size_per_token(self):
        """Cache the latent geometry and return the bytes used per token."""
        pool = self.device_pool
        self.kv_lora_rank = pool.kv_lora_rank
        self.qk_rope_head_dim = pool.qk_rope_head_dim
        self.layer_num = pool.layer_num

        token_width = self.kv_lora_rank + self.qk_rope_head_dim
        # The factor 1 matches the singleton third axis of ``kv_buffer``.
        return token_width * 1 * self.dtype.itemsize * self.layer_num

    def init_kv_buffer(self):
        """Allocate the uninitialized host buffer."""
        shape = (
            self.layer_num,
            self.size,
            1,
            self.kv_lora_rank + self.qk_rope_head_dim,
        )
        return torch.empty(
            shape,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        """Copy ``flat_data`` to the host device and store it at ``indices``."""
        # backup prepared data from device to host
        host_copy = flat_data.to(device=self.device, non_blocking=False)
        self.kv_buffer[:, indices] = host_copy

    def get_flat_data(self, indices):
        """Return the data at ``indices`` for all layers."""
        return self.kv_buffer[:, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        """Return the data at ``indices`` for ``layer_id`` (offset by ``start_layer``)."""
        layer = layer_id - self.start_layer
        return self.kv_buffer[layer, indices]

    def assign_flat_data(self, indices, flat_data):
        """Write ``flat_data`` at ``indices`` for all layers."""
        self.kv_buffer[:, indices] = flat_data

    def write_page_all_layers(self, host_indices, device_indices, device_pool):
        """Copy whole pages from ``device_pool`` into the host buffer, all layers.

        Every ``page_size``-th entry of the index tensors is used as the start
        of a ``page_size``-long contiguous range.
        NOTE(review): assumes both index tensors are laid out page by page
        with contiguous slots inside a page -- confirm against callers.
        """
        page = self.page_size
        page_starts_cpu = device_indices[::page].cpu()
        for page_no, d_start in enumerate(page_starts_cpu):
            h_start = host_indices[page_no * page]
            for layer in range(self.layer_num):
                self.kv_buffer[layer, h_start : h_start + page].copy_(
                    device_pool.kv_buffer[layer][d_start : d_start + page],
                    non_blocking=True,
                )

    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
        """Copy whole pages of one layer from the host buffer into ``device_pool``.

        Uses the same page-start convention as ``write_page_all_layers``.
        """
        page = self.page_size
        layer = layer_id - self.start_layer
        page_starts_cpu = device_indices[::page].cpu()
        for page_no, d_start in enumerate(page_starts_cpu):
            h_start = host_indices[page_no * page]
            device_pool.kv_buffer[layer][d_start : d_start + page].copy_(
                self.kv_buffer[layer, h_start : h_start + page],
                non_blocking=True,
            )
test/srt/test_hicache_page.py
View file @
a023856b
...
@@ -26,7 +26,7 @@ class TestHiCachePage(CustomTestCase):
...
@@ -26,7 +26,7 @@ class TestHiCachePage(CustomTestCase):
"--page-size"
,
"--page-size"
,
32
,
32
,
"--hicache-write-policy"
,
"--hicache-write-policy"
,
"write
-
back"
,
"write
_
back"
,
],
],
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment