change / sglang · Commits · 2fc824b8

Kernels for efficient KV cache IO (#7313)

Unverified commit 2fc824b8, authored Jul 06, 2025 by Zhiqiang Xie, committed via GitHub on Jul 06, 2025.
Parent: 253454de
Changes: 7 changed files with 184 additions and 371 deletions (+184 −371).

  python/sglang/srt/managers/cache_controller.py    +41  −195
  python/sglang/srt/managers/scheduler.py            +6    −0
  python/sglang/srt/mem_cache/hiradix_cache.py       +2    −0
  python/sglang/srt/mem_cache/memory_pool.py       +113   −63
  python/sglang/srt/mem_cache/memory_pool_host.py    +6  −109
  python/sglang/srt/mem_cache/radix_cache.py         +8    −4
  python/sglang/srt/server_args.py                   +8    −0
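At a glance, the commit swaps the old host-staging path (get_flat_data / transfer / transfer_per_layer) for per-layer transfer kernels imported from sgl_kernel.kvcacheio, selected by a new io_backend setting. A minimal sketch of the new call shape, using the keyword names visible in the memory_pool.py diff below; shapes, sizes, and index values are illustrative only, and running it assumes a CUDA device plus the sgl_kernel package:

import torch
from sgl_kernel.kvcacheio import transfer_kv_per_layer  # dependency introduced by this commit

head_num, head_dim, page_size = 8, 128, 1
dst_k = torch.zeros(1024, head_num, head_dim, device="cuda")        # device K buffer, one layer
dst_v = torch.zeros_like(dst_k)
src_k = torch.randn(4096, head_num, head_dim).pin_memory()          # host K buffer, one layer
src_v = torch.randn(4096, head_num, head_dim).pin_memory()
host_indices = torch.arange(16, dtype=torch.int64, device="cuda")   # "kernel" backend expects GPU-resident indices
device_indices = torch.arange(16, dtype=torch.int64, device="cuda")

transfer_kv_per_layer(
    src_k=src_k, dst_k=dst_k,
    src_v=src_v, dst_v=dst_v,
    src_indices=host_indices, dst_indices=device_indices,
    io_backend="kernel",               # or "direct"; see --hicache-io-backend below
    page_size=page_size,
    item_size=head_num * head_dim,     # token_stride in the diff
)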
python/sglang/srt/managers/cache_controller.py
...
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-import concurrent.futures
 import logging
 import math
 import threading
...
...
@@ -169,12 +168,23 @@ class HiCacheController:
         page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
+        io_backend: str = "",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
+        # using kernel for small page KV cache transfer and DMA for large pages
+        if not io_backend:
+            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
+            self.io_backend = (
+                "direct"
+                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
+                else "kernel"
+            )
+        else:
+            self.io_backend = io_backend

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
...
...
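The comment in the constructor above summarizes the default backend choice; restated as a tiny standalone helper (illustration only, mirroring the logic above, not new behavior):

def default_io_backend(page_size: int, threshold: int = 64) -> str:
    # pages of at least 64 tokens use DMA-style "direct" copies; smaller pages use the kernels
    return "direct" if page_size >= threshold else "kernel"

assert default_io_backend(1) == "kernel"
assert default_io_backend(32) == "kernel"
assert default_io_backend(64) == "direct"
assert default_io_backend(128) == "direct"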
@@ -203,12 +213,7 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()

         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
...
...
@@ -229,12 +234,7 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()

         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
...
...
@@ -281,6 +281,15 @@ class HiCacheController:
         )
         return device_indices

+    def move_indices(self, host_indices, device_indices):
+        # move indices to GPU if using kernels, to host if using direct indexing
+        if self.io_backend == "kernel":
+            return host_indices.to(self.mem_pool_device.device), device_indices
+        elif self.io_backend == "direct":
+            return host_indices, device_indices.cpu()
+        else:
+            raise ValueError(f"Unsupported io backend")
+
     def write_thread_func_direct(self):
         """
         Directly write through KV caches to host memory without buffering.
...
...
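A small illustration (not part of the diff) of where the new move_indices helper leaves each index tensor, assuming a CUDA device is available:

import torch

host_indices = torch.arange(8, dtype=torch.int64)      # host pool indices start on the CPU
device_indices = torch.arange(8, dtype=torch.int64)

if torch.cuda.is_available():
    device_indices = device_indices.cuda()
    # "kernel": both index tensors end up on the GPU so the transfer kernels can read them
    k_host, k_dev = host_indices.to("cuda"), device_indices
    # "direct": both end up on the CPU so host memory can be indexed directly
    d_host, d_dev = host_indices, device_indices.cpu()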
@@ -289,10 +298,14 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                self.mem_pool_host.write_page_all_layers(
-                    operation.host_indices,
-                    operation.device_indices,
-                    self.mem_pool_device,
+                host_indices, device_indices = self.move_indices(
+                    operation.host_indices, operation.device_indices
+                )
+                self.mem_pool_device.backup_to_host_all_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    self.io_backend,
                 )
                 self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
...
...
@@ -304,27 +317,6 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)

-    def load_thread_func_direct(self):
-        """
-        Directly load KV caches from host memory to device memory without buffering.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_host.get_flat_data(
-                    operation.host_indices
-                )
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
...
...
@@ -349,22 +341,18 @@ class HiCacheController:
             # start layer-wise KV cache transfer from CPU to GPU
             self.layer_done_counter.reset()
+            host_indices, device_indices = self.move_indices(
+                batch_operation.host_indices, batch_operation.device_indices
+            )
             for i in range(self.mem_pool_host.layer_num):
-                if self.page_size == 1:
-                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                        batch_operation.host_indices, i
-                    )
-                    self.mem_pool_device.transfer_per_layer(
-                        batch_operation.device_indices, flat_data, i
-                    )
-                else:
-                    self.mem_pool_host.load_page_per_layer(
-                        batch_operation.host_indices,
-                        batch_operation.device_indices,
-                        self.mem_pool_device,
-                        i,
-                    )
-                    self.load_stream.synchronize()
+                self.mem_pool_device.load_from_host_per_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    i,
+                    self.io_backend,
+                )
+                self.load_stream.synchronize()
                 self.layer_done_counter.increment()

             self.mem_pool_host.complete_io(batch_operation.host_indices)
...
...
@@ -372,148 +360,6 @@ class HiCacheController:
                 if node_id != 0:
                     self.ack_load_queue.put(node_id)

-    def write_aux_func(self, no_wait=False):
-        """
-        Auxiliary function to prepare the buffer for write operations.
-        """
-        torch.cuda.set_stream(self.write_stream)
-
-        def _to_op(op_):
-            assert op_.device_indices.is_cuda, "Device indices should be on GPU"
-            op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
-                self.mem_pool_host.device
-            )
-            self.write_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                factor = (
-                    len(operation.device_indices) // self.write_buffer.max_buffer_size
-                )
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _to_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _to_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        for op_ in split_ops:
-                            _to_op(op_)
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    no_wait
-                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                    or self.write_queue.empty()
-                    or self.write_buffer.empty()
-                ):
-                    _to_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_aux_func(self):
-        """
-        Auxiliary function to prepare the buffer for load operations.
-        """
-
-        def _pin_op(op_, put=True):
-            op_.data = (
-                self.mem_pool_host.get_flat_data(op_.host_indices)
-                .contiguous()
-                .pin_memory()
-            )
-            if put:
-                self.load_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _pin_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _pin_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        split_args = [(op_, True) for op_ in split_ops[:-1]]
-                        split_args.append((split_ops[-1], False))
-                        # Spawn threads to pin each op concurrently
-                        with concurrent.futures.ThreadPoolExecutor() as executor:
-                            pinned_ops = list(
-                                executor.map(
-                                    lambda x: _pin_op(x[0], put=x[1]), split_args
-                                )
-                            )
-                        # preserve the order of last op to ensure correct ack
-                        self.load_buffer.put(pinned_ops[-1])
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
-                    or self.load_queue.empty()
-                    or self.load_buffer.empty()
-                ):
-                    _pin_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    # todo (zhiqiang): double buffering to be deprecated
-    def write_thread_func_buffer(self):
-        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
-        aux_thread.start()
-
-        while not self.stop_event.is_set():
-            operation = self.write_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_write_queue.put(node_id)
-        aux_thread.join()
-
-    def load_thread_func_buffer(self):
-        torch.cuda.set_stream(self.load_stream)
-        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
-        aux_thread.start()
-        while not self.stop_event.is_set():
-            operation = self.load_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_device.transfer(operation.device_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
-        aux_thread.join()
-
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
     ) -> int:
...
...
python/sglang/srt/managers/scheduler.py
...
...
@@ -591,6 +591,12 @@ class Scheduler(
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
+                hicache_io_backend=(
+                    "direct"
+                    if server_args.attention_backend == "fa3"
+                    # hot fix for incompatibility
+                    else server_args.hicache_io_backend
+                ),
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
...
...
@@ -34,6 +34,7 @@ class HiRadixCache(RadixCache):
         hicache_ratio: float,
         hicache_size: int,
         hicache_write_policy: str,
+        hicache_io_backend: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
...
...
@@ -56,6 +57,7 @@ class HiRadixCache(RadixCache):
             page_size,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
+            io_backend=hicache_io_backend,
         )

         # record the nodes with ongoing write through
...
...
python/sglang/srt/mem_cache/memory_pool.py
...
...
@@ -34,10 +34,11 @@ import torch
 import torch.distributed as dist
 import triton
 import triton.language as tl
+from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

 logger = logging.getLogger(__name__)
...
...
@@ -150,13 +151,16 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    def transfer(self, indices, flat_data):
+    @abc.abstractmethod
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
         raise NotImplementedError()

-    def transfer_per_layer(self, indices, flat_data, layer_id):
+    @abc.abstractmethod
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
         raise NotImplementedError()

     def register_layer_transfer_counter(self, layer_transfer_counter):
...
...
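For a pool type outside this diff, the two new abstract methods would be filled in along these lines; a toy, framework-free sketch using plain tensor indexing (the real MHA/MLA pools below delegate to the transfer kernels instead):

import torch

class ToyKVCache:
    def __init__(self, layer_num: int, size: int, dim: int):
        self.k_buffer = [torch.zeros(size, dim) for _ in range(layer_num)]
        self.v_buffer = [torch.zeros(size, dim) for _ in range(layer_num)]

    def load_from_host_per_layer(
        self, host_pool, host_indices, device_indices, layer_id, io_backend
    ):
        # host -> device for a single layer
        self.k_buffer[layer_id][device_indices] = host_pool.k_buffer[layer_id][host_indices]
        self.v_buffer[layer_id][device_indices] = host_pool.v_buffer[layer_id][host_indices]

    def backup_to_host_all_layer(self, host_pool, host_indices, device_indices, io_backend):
        # device -> host across all layers
        for layer_id in range(len(self.k_buffer)):
            host_pool.k_buffer[layer_id][host_indices] = self.k_buffer[layer_id][device_indices]
            host_pool.v_buffer[layer_id][host_indices] = self.v_buffer[layer_id][device_indices]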
@@ -247,7 +251,7 @@ class MHATokenToKVPool(KVCache):
             )
             for _ in range(self.layer_num)
         ]
+        self.token_stride = self.head_num * self.head_dim
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
             dtype=torch.uint64,
...
...
@@ -281,24 +285,24 @@ class MHATokenToKVPool(KVCache):
         # layer_num x [seq_len, head_num, head_dim]
         # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr()
+            self._get_key_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i).data_ptr()
+            self._get_value_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_data_lens = [
-            self.get_key_buffer(i).nbytes
+            self._get_key_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i).nbytes
+            self._get_value_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes * self.page_size
+            self._get_key_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i)[0].nbytes * self.page_size
+            self._get_value_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
...
...
@@ -341,49 +345,73 @@ class MHATokenToKVPool(KVCache):
         self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()

-    # Todo: different memory layout
-    def get_flat_data(self, indices):
-        # prepare a large chunk of contiguous data for efficient transfer
-        flatten = torch.stack(
-            [
-                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
-                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
-            ]
-        )
-        return flatten
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        for i in range(self.layer_num):
-            self.k_buffer[i][indices] = k_data[i]
-            self.v_buffer[i][indices] = v_data[i]
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id - self.start_layer][indices] = k_data
-        self.v_buffer[layer_id - self.start_layer][indices] = v_data
-
-    def get_key_buffer(self, layer_id: int):
-        if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-
-        if self.store_dtype != self.dtype:
-            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
-        return self.k_buffer[layer_id - self.start_layer]
-
-    def get_value_buffer(self, layer_id: int):
-        if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-
-        if self.store_dtype != self.dtype:
-            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
-        return self.v_buffer[layer_id - self.start_layer]
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        transfer_kv_per_layer(
+            src_k=host_pool.k_buffer[layer_id],
+            dst_k=self.k_buffer[layer_id],
+            src_v=host_pool.v_buffer[layer_id],
+            dst_v=self.v_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
+        )
+
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.k_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer(
+                src_k=self.k_buffer[layer_id],
+                dst_k=host_pool.k_buffer[layer_id],
+                src_v=self.v_buffer[layer_id],
+                dst_v=host_pool.v_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )
+
+    def _get_key_buffer(self, layer_id: int):
+        # for internal use of referencing
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]
+
+    def get_key_buffer(self, layer_id: int):
+        # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
+        # it is supposed to be used only by attention backend not for information purpose
+        # same applies to get_value_buffer and get_kv_buffer
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self._get_key_buffer(layer_id)
+
+    def _get_value_buffer(self, layer_id: int):
+        # for internal use of referencing
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]
+
+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self._get_value_buffer(layer_id)

     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
...
...
@@ -761,6 +789,7 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]

+        self.token_stride = kv_lora_rank + qk_rope_head_dim
         self.layer_transfer_counter = None

         kv_size = self.get_kv_size_bytes()
...
...
@@ -846,21 +875,37 @@ class MLATokenToKVPool(KVCache):
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )

-    def get_flat_data(self, indices):
-        # prepare a large chunk of contiguous data for efficient transfer
-        return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        for i in range(self.layer_num):
-            self.kv_buffer[i][indices] = flat_data[i]
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        transfer_kv_per_layer_mla(
+            src=host_pool.kv_buffer[layer_id],
+            dst=self.kv_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
+        )
+
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer_mla(
+                src=self.kv_buffer[layer_id],
+                dst=host_pool.kv_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )

     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
...
...
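The item_size passed above is the token_stride set in the MLA constructor; with illustrative DeepSeek-style dimensions (assumed here, not stated in the diff) it works out as:

# Assumed example values, not from the diff: an MLA token packs the compressed KV
# together with the rotary part, so the per-token stride is their sum.
kv_lora_rank = 512
qk_rope_head_dim = 64
token_stride = kv_lora_rank + qk_rope_head_dim
assert token_stride == 576   # multiply by the element size to get bytes per token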
@@ -1046,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label

-    def get_flat_data(self, indices):
-        pass
-
-    def transfer(self, indices, flat_data):
-        pass
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )

-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        pass
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )


 @triton.jit
...
...
python/sglang/srt/mem_cache/memory_pool_host.py
...
...
@@ -8,7 +8,6 @@ import psutil
 import torch

 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import debug_timing

 logger = logging.getLogger(__name__)
...
...
@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()

-    @abc.abstractmethod
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data_by_layer(self, indices, layer_id):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def assign_flat_data(self, indices, flat_data):
-        raise NotImplementedError()
-
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.
...
...
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )

-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, :, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-            device_pool.v_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]
+
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]


 class MLATokenToKVPoolHost(HostKVCache):
...
...
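The k_buffer / v_buffer properties kept above give the transfer kernels zero-copy views into the packed host buffer; a minimal illustration with assumed shapes (not taken from the diff):

import torch

layer_num, size, head_num, head_dim = 2, 16, 4, 8
kv_buffer = torch.zeros(2, layer_num, size, head_num, head_dim)   # [K/V, layer, token, head, dim]
k_buffer, v_buffer = kv_buffer[0], kv_buffer[1]                   # views, no copies
assert k_buffer.data_ptr() == kv_buffer.data_ptr()
k_buffer[0, :3] = 1.0                                             # writes through to the packed buffer
assert bool(kv_buffer[0, 0, :3].eq(1.0).all())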
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
-                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
python/sglang/srt/mem_cache/radix_cache.py
...
...
@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):
         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
             self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.clone()
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(
...
...
@@ -226,10 +228,12 @@ class RadixCache(BasePrefixCache):
         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.clone()
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
         page_aligned_token_ids = token_ids[:page_aligned_len]

         # Radix Cache takes one ref in memory pool
...
...
python/sglang/srt/server_args.py
...
...
@@ -217,6 +217,7 @@ class ServerArgs:
     hicache_ratio: float = 2.0
     hicache_size: int = 0
     hicache_write_policy: str = "write_through_selective"
+    hicache_io_backend: str = ""
     flashinfer_mla_disable_ragged: bool = False
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False
...
...
@@ -1530,6 +1531,13 @@ class ServerArgs:
             default=ServerArgs.hicache_write_policy,
             help="The write policy of hierarchical cache.",
         )
+        parser.add_argument(
+            "--hicache-io-backend",
+            type=str,
+            choices=["direct", "kernel"],
+            default=ServerArgs.hicache_io_backend,
+            help="The IO backend for KV cache transfer between CPU and GPU",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
...
...
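Usage note: the new flag slots in next to the existing hierarchical-cache options. A hypothetical argument list; everything except --hicache-io-backend is assumed from existing sglang usage, and the model path is a placeholder:

# Hypothetical server arguments; only --hicache-io-backend is introduced by this commit.
launch_args = [
    "--model-path", "your/model/path",
    "--enable-hierarchical-cache",
    "--hicache-write-policy", "write_through_selective",
    "--hicache-io-backend", "kernel",   # or "direct"; leaving it empty picks a backend by page size
]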