Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
dd7ca006
Unverified
Commit
dd7ca006
authored
Jul 31, 2025
by
Zhiqiang Xie
Committed by
GitHub
Aug 01, 2025
Browse files
Interface change for kvcache io to support page first layout (#8318)
parent
9305ea6c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
371 additions
and
171 deletions
+371
-171
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+5
-17
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+1
-0
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+19
-2
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+15
-118
python/sglang/srt/mem_cache/memory_pool_host.py
python/sglang/srt/mem_cache/memory_pool_host.py
+321
-33
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+10
-1
No files found.
python/sglang/srt/managers/cache_controller.py
View file @
dd7ca006
...
@@ -231,16 +231,7 @@ class HiCacheController:
...
@@ -231,16 +231,7 @@ class HiCacheController:
self
.
mem_pool_host
=
mem_pool_host
self
.
mem_pool_host
=
mem_pool_host
self
.
write_policy
=
write_policy
self
.
write_policy
=
write_policy
self
.
page_size
=
page_size
self
.
page_size
=
page_size
# using kernel for small page KV cache transfer and DMA for large pages
self
.
io_backend
=
io_backend
if
not
io_backend
:
IO_BACKEND_PAGE_SIZE_THRESHOLD
=
64
self
.
io_backend
=
(
"direct"
if
self
.
page_size
>=
IO_BACKEND_PAGE_SIZE_THRESHOLD
else
"kernel"
)
else
:
self
.
io_backend
=
io_backend
self
.
enable_storage
=
False
self
.
enable_storage
=
False
# todo: move backend initialization to storage backend module
# todo: move backend initialization to storage backend module
...
@@ -447,11 +438,8 @@ class HiCacheController:
...
@@ -447,11 +438,8 @@ class HiCacheController:
host_indices
,
device_indices
=
self
.
move_indices
(
host_indices
,
device_indices
=
self
.
move_indices
(
operation
.
host_indices
,
operation
.
device_indices
operation
.
host_indices
,
operation
.
device_indices
)
)
self
.
mem_pool_device
.
backup_to_host_all_layer
(
self
.
mem_pool_host
.
backup_from_device_all_layer
(
self
.
mem_pool_host
,
self
.
mem_pool_device
,
host_indices
,
device_indices
,
self
.
io_backend
host_indices
,
device_indices
,
self
.
io_backend
,
)
)
self
.
write_stream
.
synchronize
()
self
.
write_stream
.
synchronize
()
self
.
mem_pool_host
.
complete_io
(
operation
.
host_indices
)
self
.
mem_pool_host
.
complete_io
(
operation
.
host_indices
)
...
@@ -491,8 +479,8 @@ class HiCacheController:
...
@@ -491,8 +479,8 @@ class HiCacheController:
batch_operation
.
host_indices
,
batch_operation
.
device_indices
batch_operation
.
host_indices
,
batch_operation
.
device_indices
)
)
for
i
in
range
(
self
.
mem_pool_host
.
layer_num
):
for
i
in
range
(
self
.
mem_pool_host
.
layer_num
):
self
.
mem_pool_
device
.
load_from_host
_per_layer
(
self
.
mem_pool_
host
.
load_to_device
_per_layer
(
self
.
mem_pool_
host
,
self
.
mem_pool_
device
,
host_indices
,
host_indices
,
device_indices
,
device_indices
,
i
,
i
,
...
...
python/sglang/srt/managers/scheduler.py
View file @
dd7ca006
...
@@ -588,6 +588,7 @@ class Scheduler(
...
@@ -588,6 +588,7 @@ class Scheduler(
==
"fa3"
# hot fix for incompatibility
==
"fa3"
# hot fix for incompatibility
else
server_args
.
hicache_io_backend
else
server_args
.
hicache_io_backend
),
),
hicache_mem_layout
=
server_args
.
hicache_mem_layout
,
hicache_storage_backend
=
server_args
.
hicache_storage_backend
,
hicache_storage_backend
=
server_args
.
hicache_storage_backend
,
)
)
self
.
tp_worker
.
register_hicache_layer_transfer_counter
(
self
.
tp_worker
.
register_hicache_layer_transfer_counter
(
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
dd7ca006
...
@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
...
@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
hicache_size
:
int
,
hicache_size
:
int
,
hicache_write_policy
:
str
,
hicache_write_policy
:
str
,
hicache_io_backend
:
str
,
hicache_io_backend
:
str
,
hicache_mem_layout
:
str
,
hicache_storage_backend
:
Optional
[
str
]
=
None
,
hicache_storage_backend
:
Optional
[
str
]
=
None
,
):
):
if
hicache_io_backend
==
"direct"
:
if
hicache_mem_layout
==
"page_first"
:
hicache_mem_layout
=
"layer_first"
logger
.
warning
(
"Page first layout is not supported with direct IO backend, switching to layer first layout"
)
self
.
kv_cache
=
token_to_kv_pool_allocator
.
get_kvcache
()
self
.
kv_cache
=
token_to_kv_pool_allocator
.
get_kvcache
()
if
isinstance
(
self
.
kv_cache
,
MHATokenToKVPool
):
if
isinstance
(
self
.
kv_cache
,
MHATokenToKVPool
):
self
.
token_to_kv_pool_host
=
MHATokenToKVPoolHost
(
self
.
token_to_kv_pool_host
=
MHATokenToKVPoolHost
(
self
.
kv_cache
,
hicache_ratio
,
hicache_size
,
page_size
self
.
kv_cache
,
hicache_ratio
,
hicache_size
,
page_size
,
hicache_mem_layout
,
)
)
elif
isinstance
(
self
.
kv_cache
,
MLATokenToKVPool
):
elif
isinstance
(
self
.
kv_cache
,
MLATokenToKVPool
):
self
.
token_to_kv_pool_host
=
MLATokenToKVPoolHost
(
self
.
token_to_kv_pool_host
=
MLATokenToKVPoolHost
(
self
.
kv_cache
,
hicache_ratio
,
hicache_size
,
page_size
self
.
kv_cache
,
hicache_ratio
,
hicache_size
,
page_size
,
hicache_mem_layout
,
)
)
else
:
else
:
raise
ValueError
(
f
"HiRadixCache only supports MHA and MLA yet"
)
raise
ValueError
(
f
"HiRadixCache only supports MHA and MLA yet"
)
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
dd7ca006
...
@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union
...
@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.distributed
as
dist
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
sglang.srt.constants
import
GPU_MEMORY_TYPE_KV_CACHE
from
sglang.srt.constants
import
GPU_MEMORY_TYPE_KV_CACHE
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.utils
import
get_bool_env_var
,
is_cuda
,
is_npu
,
next_power_of_2
from
sglang.srt.utils
import
get_bool_env_var
,
is_cuda
,
next_power_of_2
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
GB
=
1024
*
1024
*
1024
GB
=
1024
*
1024
*
1024
_is_cuda
=
is_cuda
()
_is_cuda
=
is_cuda
()
_is_npu
=
is_npu
()
if
not
_is_npu
:
from
sgl_kernel.kvcacheio
import
transfer_kv_per_layer
,
transfer_kv_per_layer_mla
class
ReqToTokenPool
:
class
ReqToTokenPool
:
...
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
...
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
)
->
None
:
)
->
None
:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
abc
.
abstractmethod
def
load_from_host_per_layer
(
self
,
host_pool
,
host_indices
,
device_indices
,
layer_id
,
io_backend
):
raise
NotImplementedError
()
@
abc
.
abstractmethod
def
backup_to_host_all_layer
(
self
,
host_pool
,
host_indices
,
device_indices
,
io_backend
):
raise
NotImplementedError
()
def
register_layer_transfer_counter
(
self
,
layer_transfer_counter
):
def
register_layer_transfer_counter
(
self
,
layer_transfer_counter
):
self
.
layer_transfer_counter
=
layer_transfer_counter
self
.
layer_transfer_counter
=
layer_transfer_counter
...
@@ -253,12 +237,18 @@ class MHATokenToKVPool(KVCache):
...
@@ -253,12 +237,18 @@ class MHATokenToKVPool(KVCache):
)
)
for
_
in
range
(
self
.
layer_num
)
for
_
in
range
(
self
.
layer_num
)
]
]
self
.
token_stride
=
self
.
head_num
*
self
.
head_dim
self
.
data_ptrs
=
torch
.
tensor
(
self
.
k_data_ptrs
=
torch
.
tensor
(
[
x
.
data_ptr
()
for
x
in
self
.
k_buffer
+
self
.
v_buffer
],
[
x
.
data_ptr
()
for
x
in
self
.
k_buffer
],
dtype
=
torch
.
uint64
,
device
=
self
.
device
,
)
self
.
v_data_ptrs
=
torch
.
tensor
(
[
x
.
data_ptr
()
for
x
in
self
.
v_buffer
],
dtype
=
torch
.
uint64
,
dtype
=
torch
.
uint64
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
self
.
data_ptrs
=
torch
.
cat
([
self
.
k_data_ptrs
,
self
.
v_data_ptrs
],
dim
=
0
)
self
.
data_strides
=
torch
.
tensor
(
self
.
data_strides
=
torch
.
tensor
(
[
[
np
.
prod
(
x
.
shape
[
1
:])
*
x
.
dtype
.
itemsize
np
.
prod
(
x
.
shape
[
1
:])
*
x
.
dtype
.
itemsize
...
@@ -347,47 +337,6 @@ class MHATokenToKVPool(KVCache):
...
@@ -347,47 +337,6 @@ class MHATokenToKVPool(KVCache):
self
.
v_buffer
[
layer_id
][
chunk_indices
]
=
v_chunk
self
.
v_buffer
[
layer_id
][
chunk_indices
]
=
v_chunk
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
def
load_from_host_per_layer
(
self
,
host_pool
,
host_indices
,
device_indices
,
layer_id
,
io_backend
,
):
transfer_kv_per_layer
(
src_k
=
host_pool
.
k_buffer
[
layer_id
],
dst_k
=
self
.
k_buffer
[
layer_id
],
src_v
=
host_pool
.
v_buffer
[
layer_id
],
dst_v
=
self
.
v_buffer
[
layer_id
],
src_indices
=
host_indices
,
dst_indices
=
device_indices
,
io_backend
=
io_backend
,
page_size
=
self
.
page_size
,
item_size
=
self
.
token_stride
,
)
def
backup_to_host_all_layer
(
self
,
host_pool
,
host_indices
,
device_indices
,
io_backend
):
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
for
layer_id
in
range
(
self
.
start_layer
,
self
.
start_layer
+
self
.
layer_num
):
if
layer_id
-
self
.
start_layer
>=
len
(
host_pool
.
k_buffer
):
raise
ValueError
(
f
"Layer ID
{
layer_id
}
exceeds the number of layers in host pool."
)
transfer_kv_per_layer
(
src_k
=
self
.
k_buffer
[
layer_id
],
dst_k
=
host_pool
.
k_buffer
[
layer_id
],
src_v
=
self
.
v_buffer
[
layer_id
],
dst_v
=
host_pool
.
v_buffer
[
layer_id
],
src_indices
=
device_indices
,
dst_indices
=
host_indices
,
io_backend
=
io_backend
,
page_size
=
self
.
page_size
,
item_size
=
self
.
token_stride
,
)
def
_get_key_buffer
(
self
,
layer_id
:
int
):
def
_get_key_buffer
(
self
,
layer_id
:
int
):
# for internal use of referencing
# for internal use of referencing
if
self
.
store_dtype
!=
self
.
dtype
:
if
self
.
store_dtype
!=
self
.
dtype
:
...
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
...
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
layer_id_override
=
layer_id_pool
,
layer_id_override
=
layer_id_pool
,
)
)
def
load_from_host_per_layer
(
self
,
host_pool
,
host_indices
,
device_indices
,
layer_id
,
io_backend
):
raise
NotImplementedError
(
"HiCache not supported for SWAKVPool."
)
def
backup_to_host_all_layer
(
self
,
host_pool
,
host_indices
,
device_indices
,
io_backend
):
raise
NotImplementedError
(
"HiCache not supported for SWAKVPool."
)
class
AscendTokenToKVPool
(
MHATokenToKVPool
):
class
AscendTokenToKVPool
(
MHATokenToKVPool
):
...
@@ -823,7 +762,11 @@ class MLATokenToKVPool(KVCache):
...
@@ -823,7 +762,11 @@ class MLATokenToKVPool(KVCache):
for
_
in
range
(
layer_num
)
for
_
in
range
(
layer_num
)
]
]
self
.
token_stride
=
kv_lora_rank
+
qk_rope_head_dim
self
.
data_ptrs
=
torch
.
tensor
(
[
x
.
data_ptr
()
for
x
in
self
.
kv_buffer
],
dtype
=
torch
.
uint64
,
device
=
self
.
device
,
)
self
.
layer_transfer_counter
=
None
self
.
layer_transfer_counter
=
None
kv_size
=
self
.
get_kv_size_bytes
()
kv_size
=
self
.
get_kv_size_bytes
()
...
@@ -909,38 +852,6 @@ class MLATokenToKVPool(KVCache):
...
@@ -909,38 +852,6 @@ class MLATokenToKVPool(KVCache):
self
.
kv_buffer
[
layer_id
],
loc
,
cache_k_nope
,
cache_k_rope
self
.
kv_buffer
[
layer_id
],
loc
,
cache_k_nope
,
cache_k_rope
)
)
def
load_from_host_per_layer
(
self
,
host_pool
,
host_indices
,
device_indices
,
layer_id
,
io_backend
):
transfer_kv_per_layer_mla
(
src
=
host_pool
.
kv_buffer
[
layer_id
],
dst
=
self
.
kv_buffer
[
layer_id
],
src_indices
=
host_indices
,
dst_indices
=
device_indices
,
io_backend
=
io_backend
,
page_size
=
self
.
page_size
,
item_size
=
self
.
token_stride
,
)
def
backup_to_host_all_layer
(
self
,
host_pool
,
host_indices
,
device_indices
,
io_backend
):
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
for
layer_id
in
range
(
self
.
start_layer
,
self
.
start_layer
+
self
.
layer_num
):
if
layer_id
-
self
.
start_layer
>=
len
(
host_pool
.
kv_buffer
):
raise
ValueError
(
f
"Layer ID
{
layer_id
}
exceeds the number of layers in host pool."
)
transfer_kv_per_layer_mla
(
src
=
self
.
kv_buffer
[
layer_id
],
dst
=
host_pool
.
kv_buffer
[
layer_id
],
src_indices
=
device_indices
,
dst_indices
=
host_indices
,
io_backend
=
io_backend
,
page_size
=
self
.
page_size
,
item_size
=
self
.
token_stride
,
)
def
get_cpu_copy
(
self
,
indices
):
def
get_cpu_copy
(
self
,
indices
):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
kv_cache_cpu
=
[]
kv_cache_cpu
=
[]
...
@@ -1131,20 +1042,6 @@ class DoubleSparseTokenToKVPool(KVCache):
...
@@ -1131,20 +1042,6 @@ class DoubleSparseTokenToKVPool(KVCache):
self
.
v_buffer
[
layer_id
-
self
.
start_layer
][
loc
]
=
cache_v
self
.
v_buffer
[
layer_id
-
self
.
start_layer
][
loc
]
=
cache_v
self
.
label_buffer
[
layer_id
-
self
.
start_layer
][
loc
]
=
cache_label
self
.
label_buffer
[
layer_id
-
self
.
start_layer
][
loc
]
=
cache_label
def
load_from_host_per_layer
(
self
,
host_pool
,
host_indices
,
device_indices
,
layer_id
,
io_backend
):
raise
NotImplementedError
(
"HiCache not supported for DoubleSparseTokenToKVPool."
)
def
backup_to_host_all_layer
(
self
,
host_pool
,
host_indices
,
device_indices
,
io_backend
):
raise
NotImplementedError
(
"HiCache not supported for DoubleSparseTokenToKVPool."
)
@
triton
.
jit
@
triton
.
jit
def
copy_all_layer_kv_cache
(
def
copy_all_layer_kv_cache
(
...
...
python/sglang/srt/mem_cache/memory_pool_host.py
View file @
dd7ca006
...
@@ -8,6 +8,21 @@ import psutil
...
@@ -8,6 +8,21 @@ import psutil
import
torch
import
torch
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
MHATokenToKVPool
,
MLATokenToKVPool
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
MHATokenToKVPool
,
MLATokenToKVPool
from
sglang.srt.utils
import
is_npu
_is_npu
=
is_npu
()
if
not
_is_npu
:
from
sgl_kernel.kvcacheio
import
(
transfer_kv_all_layer
,
transfer_kv_all_layer_lf_pf
,
transfer_kv_all_layer_mla
,
transfer_kv_all_layer_mla_lf_pf
,
transfer_kv_direct
,
transfer_kv_per_layer
,
transfer_kv_per_layer_mla
,
transfer_kv_per_layer_mla_pf_lf
,
transfer_kv_per_layer_pf_lf
,
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -42,15 +57,18 @@ class HostKVCache(abc.ABC):
...
@@ -42,15 +57,18 @@ class HostKVCache(abc.ABC):
device_pool
:
KVCache
,
device_pool
:
KVCache
,
host_to_device_ratio
:
float
,
host_to_device_ratio
:
float
,
host_size
:
int
,
host_size
:
int
,
page_size
:
int
,
layout
:
str
,
pin_memory
:
bool
,
pin_memory
:
bool
,
device
:
str
,
device
:
str
,
page_size
:
int
,
):
):
self
.
device_pool
=
device_pool
self
.
device_pool
=
device_pool
self
.
dtype
=
device_pool
.
store_dtype
self
.
page_size
=
page_size
self
.
layout
=
layout
self
.
pin_memory
=
pin_memory
self
.
pin_memory
=
pin_memory
self
.
device
=
device
self
.
device
=
device
self
.
page_size
=
page_size
self
.
dtype
=
device_pool
.
store_dtype
self
.
size_per_token
=
self
.
get_size_per_token
()
self
.
size_per_token
=
self
.
get_size_per_token
()
if
host_size
>
0
:
if
host_size
>
0
:
self
.
size
=
int
(
host_size
*
1e9
//
self
.
size_per_token
)
self
.
size
=
int
(
host_size
*
1e9
//
self
.
size_per_token
)
...
@@ -98,6 +116,24 @@ class HostKVCache(abc.ABC):
...
@@ -98,6 +116,24 @@ class HostKVCache(abc.ABC):
def
init_kv_buffer
(
self
):
def
init_kv_buffer
(
self
):
raise
NotImplementedError
()
raise
NotImplementedError
()
@
abc
.
abstractmethod
def
load_to_device_per_layer
(
self
,
device_pool
,
host_indices
,
device_indices
,
layer_id
,
io_backend
)
->
None
:
"""
Load KV data from the host memory pool to the device memory pool for a specific layer.
"""
raise
NotImplementedError
()
@
abc
.
abstractmethod
def
backup_from_device_all_layer
(
self
,
device_pool
,
host_indices
,
device_indices
,
io_backend
)
->
None
:
"""
Backup KV data from the device memory pool to the host memory pool for all layers.
"""
raise
NotImplementedError
()
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
get_flat_data_page
(
self
,
index
)
->
torch
.
Tensor
:
def
get_flat_data_page
(
self
,
index
)
->
torch
.
Tensor
:
"""
"""
...
@@ -238,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
...
@@ -238,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
host_to_device_ratio
:
float
,
host_to_device_ratio
:
float
,
host_size
:
int
,
host_size
:
int
,
page_size
:
int
,
page_size
:
int
,
layout
:
str
,
pin_memory
:
bool
=
True
,
pin_memory
:
bool
=
True
,
device
:
str
=
"cpu"
,
device
:
str
=
"cpu"
,
):
):
super
().
__init__
(
super
().
__init__
(
device_pool
,
host_to_device_ratio
,
host_size
,
pin_memory
,
device
,
page_size
device_pool
,
host_to_device_ratio
,
host_size
,
page_size
,
layout
,
pin_memory
,
device
,
)
self
.
k_data_refs
=
[
self
.
k_buffer
[
i
]
for
i
in
range
(
self
.
layer_num
)]
self
.
v_data_refs
=
[
self
.
v_buffer
[
i
]
for
i
in
range
(
self
.
layer_num
)]
self
.
k_data_ptrs
=
torch
.
tensor
(
[
x
.
data_ptr
()
for
x
in
self
.
k_data_refs
],
dtype
=
torch
.
uint64
,
device
=
self
.
device_pool
.
device
,
)
self
.
v_data_ptrs
=
torch
.
tensor
(
[
x
.
data_ptr
()
for
x
in
self
.
v_data_refs
],
dtype
=
torch
.
uint64
,
device
=
self
.
device_pool
.
device
,
)
)
def
get_size_per_token
(
self
):
def
get_size_per_token
(
self
):
...
@@ -253,16 +308,128 @@ class MHATokenToKVPoolHost(HostKVCache):
...
@@ -253,16 +308,128 @@ class MHATokenToKVPoolHost(HostKVCache):
return
self
.
head_dim
*
self
.
head_num
*
self
.
layer_num
*
self
.
dtype
.
itemsize
*
2
return
self
.
head_dim
*
self
.
head_num
*
self
.
layer_num
*
self
.
dtype
.
itemsize
*
2
def
init_kv_buffer
(
self
):
def
init_kv_buffer
(
self
):
if
self
.
layout
==
"layer_first"
:
dims
=
(
2
,
self
.
layer_num
,
self
.
size
,
self
.
head_num
,
self
.
head_dim
)
elif
self
.
layout
==
"page_first"
:
dims
=
(
2
,
self
.
size
,
self
.
layer_num
,
self
.
head_num
,
self
.
head_dim
)
else
:
raise
ValueError
(
f
"Unsupported layout:
{
self
.
layout
}
"
)
self
.
token_stride_size
=
self
.
head_num
*
self
.
head_dim
*
self
.
dtype
.
itemsize
self
.
layout_dim
=
self
.
token_stride_size
*
self
.
layer_num
return
torch
.
empty
(
return
torch
.
empty
(
(
2
,
self
.
layer_num
,
self
.
size
,
self
.
head_num
,
self
.
head_dim
)
,
dims
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
)
)
# todo, page first memory layout
@
property
def
k_buffer
(
self
):
return
self
.
kv_buffer
[
0
]
@
property
def
v_buffer
(
self
):
return
self
.
kv_buffer
[
1
]
def
load_to_device_per_layer
(
self
,
device_pool
,
host_indices
,
device_indices
,
layer_id
,
io_backend
,
):
if
io_backend
==
"kernel"
:
if
self
.
layout
==
"layer_first"
:
transfer_kv_per_layer
(
src_k
=
self
.
k_buffer
[
layer_id
],
dst_k
=
device_pool
.
k_buffer
[
layer_id
],
src_v
=
self
.
v_buffer
[
layer_id
],
dst_v
=
device_pool
.
v_buffer
[
layer_id
],
src_indices
=
host_indices
,
dst_indices
=
device_indices
,
item_size
=
self
.
token_stride_size
,
)
elif
self
.
layout
==
"page_first"
:
transfer_kv_per_layer_pf_lf
(
src_k
=
self
.
k_buffer
,
dst_k
=
device_pool
.
k_buffer
[
layer_id
],
src_v
=
self
.
v_buffer
,
dst_v
=
device_pool
.
v_buffer
[
layer_id
],
src_indices
=
host_indices
,
dst_indices
=
device_indices
,
item_size
=
self
.
token_stride_size
,
src_layout_dim
=
self
.
layout_dim
,
)
else
:
raise
ValueError
(
f
"Unsupported layout:
{
self
.
layout
}
"
)
elif
io_backend
==
"direct"
:
assert
(
self
.
layout
==
"layer_first"
),
f
"Direct IO backend only supports layer_first layout."
transfer_kv_direct
(
src_layers
=
[
self
.
k_buffer
[
layer_id
],
self
.
v_buffer
[
layer_id
]],
dst_layers
=
[
device_pool
.
k_buffer
[
layer_id
],
device_pool
.
v_buffer
[
layer_id
],
],
src_indices
=
host_indices
,
dst_indices
=
device_indices
,
page_size
=
self
.
page_size
,
)
else
:
raise
ValueError
(
f
"Unsupported IO backend:
{
io_backend
}
"
)
def
backup_from_device_all_layer
(
self
,
device_pool
,
host_indices
,
device_indices
,
io_backend
):
if
io_backend
==
"kernel"
:
if
self
.
layout
==
"layer_first"
:
transfer_kv_all_layer
(
src_k_layers
=
device_pool
.
k_data_ptrs
,
dst_k_layers
=
self
.
k_data_ptrs
,
src_v_layers
=
device_pool
.
v_data_ptrs
,
dst_v_layers
=
self
.
v_data_ptrs
,
src_indices
=
device_indices
,
dst_indices
=
host_indices
,
item_size
=
self
.
token_stride_size
,
num_layers
=
self
.
layer_num
,
)
elif
self
.
layout
==
"page_first"
:
transfer_kv_all_layer_lf_pf
(
src_k_layers
=
device_pool
.
k_data_ptrs
,
dst_k
=
self
.
k_buffer
,
src_v_layers
=
device_pool
.
v_data_ptrs
,
dst_v
=
self
.
v_buffer
,
src_indices
=
device_indices
,
dst_indices
=
host_indices
,
item_size
=
self
.
token_stride_size
,
dst_layout_dim
=
self
.
layout_dim
,
num_layers
=
self
.
layer_num
,
)
else
:
raise
ValueError
(
f
"Unsupported layout:
{
self
.
layout
}
"
)
elif
io_backend
==
"direct"
:
assert
(
self
.
layout
==
"layer_first"
),
f
"Direct IO backend only supports layer_first layout."
transfer_kv_direct
(
src_layers
=
device_pool
.
k_buffer
+
device_pool
.
v_buffer
,
dst_layers
=
self
.
k_data_refs
+
self
.
v_data_refs
,
src_indices
=
device_indices
,
dst_indices
=
host_indices
,
page_size
=
self
.
page_size
,
)
else
:
raise
ValueError
(
f
"Unsupported IO backend:
{
io_backend
}
"
)
def
get_flat_data_page
(
self
,
index
)
->
torch
.
Tensor
:
def
get_flat_data_page
(
self
,
index
)
->
torch
.
Tensor
:
return
self
.
kv_buffer
[:,
:,
index
:
index
+
self
.
page_size
,
:,
:].
flatten
()
if
self
.
layout
==
"layer_first"
:
return
self
.
kv_buffer
[:,
:,
index
:
index
+
self
.
page_size
,
:,
:].
flatten
()
elif
self
.
layout
==
"page_first"
:
return
self
.
kv_buffer
[:,
index
:
index
+
self
.
page_size
,
:,
:,
:].
flatten
()
else
:
raise
ValueError
(
f
"Unsupported layout:
{
self
.
layout
}
"
)
def
get_dummy_flat_data_page
(
self
)
->
torch
.
Tensor
:
def
get_dummy_flat_data_page
(
self
)
->
torch
.
Tensor
:
return
torch
.
zeros
(
return
torch
.
zeros
(
...
@@ -273,13 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache):
...
@@ -273,13 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache):
).
flatten
()
).
flatten
()
def
set_from_flat_data_page
(
self
,
index
:
int
,
data_page
:
torch
.
Tensor
)
->
None
:
def
set_from_flat_data_page
(
self
,
index
:
int
,
data_page
:
torch
.
Tensor
)
->
None
:
self
.
kv_buffer
[:,
:,
index
:
index
+
self
.
page_size
,
:,
:]
=
data_page
.
reshape
(
if
self
.
layout
==
"layer_first"
:
2
,
self
.
kv_buffer
[:,
:,
index
:
index
+
self
.
page_size
,
:,
:]
=
(
self
.
layer_num
,
data_page
.
reshape
(
self
.
page_size
,
2
,
self
.
head_num
,
self
.
layer_num
,
self
.
head_dim
,
self
.
page_size
,
)
self
.
head_num
,
self
.
head_dim
,
)
)
elif
self
.
layout
==
"page_first"
:
self
.
kv_buffer
[:,
index
:
index
+
self
.
page_size
,
:,
:,
:]
=
(
data_page
.
reshape
(
2
,
self
.
page_size
,
self
.
layer_num
,
self
.
head_num
,
self
.
head_dim
)
)
else
:
raise
ValueError
(
f
"Unsupported layout:
{
self
.
layout
}
"
)
def
get_buffer_meta
(
self
,
keys
,
indices
):
def
get_buffer_meta
(
self
,
keys
,
indices
):
ptr_list
=
[]
ptr_list
=
[]
...
@@ -318,14 +496,6 @@ class MHATokenToKVPoolHost(HostKVCache):
...
@@ -318,14 +496,6 @@ class MHATokenToKVPoolHost(HostKVCache):
element_size_list
=
[
element_size
]
*
len
(
key_list
)
element_size_list
=
[
element_size
]
*
len
(
key_list
)
return
key_list
,
ptr_list
,
element_size_list
return
key_list
,
ptr_list
,
element_size_list
@
property
def
k_buffer
(
self
):
return
self
.
kv_buffer
[
0
]
@
property
def
v_buffer
(
self
):
return
self
.
kv_buffer
[
1
]
class
MLATokenToKVPoolHost
(
HostKVCache
):
class
MLATokenToKVPoolHost
(
HostKVCache
):
device_pool
:
MLATokenToKVPool
device_pool
:
MLATokenToKVPool
...
@@ -336,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
...
@@ -336,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
host_to_device_ratio
:
float
,
host_to_device_ratio
:
float
,
host_size
:
int
,
host_size
:
int
,
page_size
:
int
,
page_size
:
int
,
layout
:
str
,
pin_memory
:
bool
=
True
,
pin_memory
:
bool
=
True
,
device
:
str
=
"cpu"
,
device
:
str
=
"cpu"
,
):
):
super
().
__init__
(
super
().
__init__
(
device_pool
,
host_to_device_ratio
,
host_size
,
pin_memory
,
device
,
page_size
device_pool
,
host_to_device_ratio
,
host_size
,
page_size
,
layout
,
pin_memory
,
device
,
)
self
.
data_refs
=
[
self
.
kv_buffer
[
i
]
for
i
in
range
(
self
.
layer_num
)]
self
.
data_ptrs
=
torch
.
tensor
(
[
x
.
data_ptr
()
for
x
in
self
.
data_refs
],
dtype
=
torch
.
uint64
,
device
=
self
.
device_pool
.
device
,
)
)
def
get_size_per_token
(
self
):
def
get_size_per_token
(
self
):
...
@@ -356,20 +539,115 @@ class MLATokenToKVPoolHost(HostKVCache):
...
@@ -356,20 +539,115 @@ class MLATokenToKVPoolHost(HostKVCache):
)
)
def
init_kv_buffer
(
self
):
def
init_kv_buffer
(
self
):
return
torch
.
empty
(
if
self
.
layout
==
"layer_first"
:
(
dims
=
(
self
.
layer_num
,
self
.
layer_num
,
self
.
size
,
self
.
size
,
1
,
1
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
),
)
elif
self
.
layout
==
"page_first"
:
dims
=
(
self
.
size
,
self
.
layer_num
,
1
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
)
else
:
raise
ValueError
(
f
"Unsupported layout:
{
self
.
layout
}
"
)
self
.
token_stride_size
=
(
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
)
*
self
.
dtype
.
itemsize
self
.
layout_dim
=
self
.
token_stride_size
*
self
.
layer_num
return
torch
.
empty
(
dims
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
)
)
def
load_to_device_per_layer
(
self
,
device_pool
,
host_indices
,
device_indices
,
layer_id
,
io_backend
):
if
io_backend
==
"kernel"
:
if
self
.
layout
==
"layer_first"
:
transfer_kv_per_layer_mla
(
src
=
self
.
kv_buffer
[
layer_id
],
dst
=
device_pool
.
kv_buffer
[
layer_id
],
src_indices
=
host_indices
,
dst_indices
=
device_indices
,
item_size
=
self
.
token_stride_size
,
)
elif
self
.
layout
==
"page_first"
:
transfer_kv_per_layer_mla_pf_lf
(
src
=
self
.
kv_buffer
,
dst
=
device_pool
.
kv_buffer
[
layer_id
],
src_indices
=
host_indices
,
dst_indices
=
device_indices
,
item_size
=
self
.
token_stride_size
,
src_layout_dim
=
self
.
layout_dim
,
)
else
:
raise
ValueError
(
f
"Unsupported layout:
{
self
.
layout
}
"
)
elif
io_backend
==
"direct"
:
assert
(
self
.
layout
==
"layer_first"
),
f
"Direct IO backend only supports layer_first layout."
transfer_kv_direct
(
src_layers
=
[
self
.
kv_buffer
[
layer_id
]],
dst_layers
=
[
device_pool
.
kv_buffer
[
layer_id
]],
src_indices
=
host_indices
,
dst_indices
=
device_indices
,
page_size
=
self
.
page_size
,
)
def
backup_from_device_all_layer
(
self
,
device_pool
,
host_indices
,
device_indices
,
io_backend
):
if
io_backend
==
"kernel"
:
if
self
.
layout
==
"layer_first"
:
transfer_kv_all_layer_mla
(
src_layers
=
device_pool
.
data_ptrs
,
dst_layers
=
self
.
data_ptrs
,
src_indices
=
device_indices
,
dst_indices
=
host_indices
,
item_size
=
self
.
token_stride_size
,
num_layers
=
self
.
layer_num
,
)
elif
self
.
layout
==
"page_first"
:
transfer_kv_all_layer_mla_lf_pf
(
src_layers
=
device_pool
.
data_ptrs
,
dst_k
=
self
.
kv_buffer
,
src_indices
=
device_indices
,
dst_indices
=
host_indices
,
item_size
=
self
.
token_stride_size
,
dst_layout_dim
=
self
.
layout_dim
,
num_layers
=
self
.
layer_num
,
)
else
:
raise
ValueError
(
f
"Unsupported layout:
{
self
.
layout
}
"
)
elif
io_backend
==
"direct"
:
assert
(
self
.
layout
==
"layer_first"
),
f
"Direct IO backend only supports layer_first layout."
transfer_kv_direct
(
src_layers
=
device_pool
.
kv_buffer
,
dst_layers
=
self
.
data_refs
,
src_indices
=
device_indices
,
dst_indices
=
host_indices
,
page_size
=
self
.
page_size
,
)
else
:
raise
ValueError
(
f
"Unsupported IO backend:
{
io_backend
}
"
)
def
get_flat_data_page
(
self
,
index
)
->
torch
.
Tensor
:
def
get_flat_data_page
(
self
,
index
)
->
torch
.
Tensor
:
return
self
.
kv_buffer
[:,
index
:
index
+
self
.
page_size
,
:,
:].
flatten
()
if
self
.
layout
==
"layer_first"
:
return
self
.
kv_buffer
[:,
index
:
index
+
self
.
page_size
,
:,
:].
flatten
()
elif
self
.
layout
==
"page_first"
:
return
self
.
kv_buffer
[
index
:
index
+
self
.
page_size
,
:,
:,
:].
flatten
()
else
:
raise
ValueError
(
f
"Unsupported layout:
{
self
.
layout
}
"
)
def
get_dummy_flat_data_page
(
self
)
->
torch
.
Tensor
:
def
get_dummy_flat_data_page
(
self
)
->
torch
.
Tensor
:
return
torch
.
zeros
(
return
torch
.
zeros
(
...
@@ -385,12 +663,22 @@ class MLATokenToKVPoolHost(HostKVCache):
...
@@ -385,12 +663,22 @@ class MLATokenToKVPoolHost(HostKVCache):
).
flatten
()
).
flatten
()
def
set_from_flat_data_page
(
self
,
index
:
int
,
data_page
:
torch
.
Tensor
)
->
None
:
def
set_from_flat_data_page
(
self
,
index
:
int
,
data_page
:
torch
.
Tensor
)
->
None
:
self
.
kv_buffer
[:,
index
:
index
+
self
.
page_size
,
:,
:]
=
data_page
.
reshape
(
if
self
.
layout
==
"layer_first"
:
self
.
layer_num
,
self
.
kv_buffer
[:,
index
:
index
+
self
.
page_size
,
:,
:]
=
data_page
.
reshape
(
self
.
page_size
,
self
.
layer_num
,
1
,
self
.
page_size
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
1
,
)
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
)
elif
self
.
layout
==
"page_first"
:
self
.
kv_buffer
[
index
:
index
+
self
.
page_size
,
:,
:,
:]
=
data_page
.
reshape
(
self
.
page_size
,
self
.
layer_num
,
1
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
)
else
:
raise
ValueError
(
f
"Unsupported layout:
{
self
.
layout
}
"
)
def
get_buffer_meta
(
self
,
keys
,
indices
):
def
get_buffer_meta
(
self
,
keys
,
indices
):
ptr_list
=
[]
ptr_list
=
[]
...
...
python/sglang/srt/server_args.py
View file @
dd7ca006
...
@@ -198,7 +198,8 @@ class ServerArgs:
...
@@ -198,7 +198,8 @@ class ServerArgs:
hicache_ratio
:
float
=
2.0
hicache_ratio
:
float
=
2.0
hicache_size
:
int
=
0
hicache_size
:
int
=
0
hicache_write_policy
:
str
=
"write_through_selective"
hicache_write_policy
:
str
=
"write_through_selective"
hicache_io_backend
:
str
=
""
hicache_io_backend
:
str
=
"kernel"
hicache_mem_layout
:
str
=
"layer_first"
hicache_storage_backend
:
Optional
[
str
]
=
None
hicache_storage_backend
:
Optional
[
str
]
=
None
# Double Sparsity
# Double Sparsity
...
@@ -1487,6 +1488,14 @@ class ServerArgs:
...
@@ -1487,6 +1488,14 @@ class ServerArgs:
default
=
ServerArgs
.
hicache_io_backend
,
default
=
ServerArgs
.
hicache_io_backend
,
help
=
"The IO backend for KV cache transfer between CPU and GPU"
,
help
=
"The IO backend for KV cache transfer between CPU and GPU"
,
)
)
parser
.
add_argument
(
"--hicache-mem-layout"
,
type
=
str
,
choices
=
[
"layer_first"
,
"page_first"
],
default
=
ServerArgs
.
hicache_mem_layout
,
help
=
"The layout of host memory pool for hierarchical cache."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--hicache-storage-backend"
,
"--hicache-storage-backend"
,
type
=
str
,
type
=
str
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment