Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5ea5d221
"tests/vscode:/vscode.git/clone" did not exist on "93b39729ee02693cd6315dd4dadd8e3e624e1d6b"
Unverified
Commit
5ea5d221
authored
Jun 22, 2025
by
Liangsheng Yin
Committed by
GitHub
Jun 22, 2025
Browse files
Fix CPU offloading for MLA memory pool (#7409)
parent
fdfd5224
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
8 deletions
+44
-8
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+44
-8
No files found.
python/sglang/srt/mem_cache/memory_pool.py
View file @
5ea5d221
...
...
@@ -123,6 +123,9 @@ class KVCache(abc.ABC):
enable
=
enable_memory_saver
)
# used for chunked cpu-offloading
self
.
cpu_offloading_chunk_size
=
8192
@
abc
.
abstractmethod
def
get_key_buffer
(
self
,
layer_id
:
int
)
->
torch
.
Tensor
:
raise
NotImplementedError
()
...
...
@@ -157,6 +160,12 @@ class KVCache(abc.ABC):
def
register_layer_transfer_counter
(
self
,
layer_transfer_counter
):
self
.
layer_transfer_counter
=
layer_transfer_counter
def
get_cpu_copy
(
self
,
indices
):
raise
NotImplementedError
()
def
load_cpu_copy
(
self
,
kv_cache_cpu
,
indices
):
raise
NotImplementedError
()
class
TokenToKVPoolAllocator
:
"""An allocator managing the indices to kv cache data."""
...
...
@@ -280,8 +289,6 @@ class MHATokenToKVPool(KVCache):
self
.
_create_buffers
()
# used for chunked cpu-offloading
self
.
chunk_size
=
8192
self
.
layer_transfer_counter
=
None
self
.
device_module
=
torch
.
get_device_module
(
self
.
device
)
self
.
alt_stream
=
self
.
device_module
.
Stream
()
if
_is_cuda
else
None
...
...
@@ -378,10 +385,11 @@ class MHATokenToKVPool(KVCache):
def
get_cpu_copy
(
self
,
indices
):
torch
.
cuda
.
synchronize
()
kv_cache_cpu
=
[]
chunk_size
=
self
.
cpu_offloading_chunk_size
for
layer_id
in
range
(
self
.
layer_num
):
kv_cache_cpu
.
append
([])
for
i
in
range
(
0
,
len
(
indices
),
self
.
chunk_size
):
chunk_indices
=
indices
[
i
:
i
+
self
.
chunk_size
]
for
i
in
range
(
0
,
len
(
indices
),
chunk_size
):
chunk_indices
=
indices
[
i
:
i
+
chunk_size
]
k_cpu
=
self
.
k_buffer
[
layer_id
][
chunk_indices
].
to
(
"cpu"
,
non_blocking
=
True
)
...
...
@@ -394,12 +402,13 @@ class MHATokenToKVPool(KVCache):
def
load_cpu_copy
(
self
,
kv_cache_cpu
,
indices
):
torch
.
cuda
.
synchronize
()
chunk_size
=
self
.
cpu_offloading_chunk_size
for
layer_id
in
range
(
self
.
layer_num
):
for
i
in
range
(
0
,
len
(
indices
),
self
.
chunk_size
):
chunk_indices
=
indices
[
i
:
i
+
self
.
chunk_size
]
for
i
in
range
(
0
,
len
(
indices
),
chunk_size
):
chunk_indices
=
indices
[
i
:
i
+
chunk_size
]
k_cpu
,
v_cpu
=
(
kv_cache_cpu
[
layer_id
][
i
//
self
.
chunk_size
][
0
],
kv_cache_cpu
[
layer_id
][
i
//
self
.
chunk_size
][
1
],
kv_cache_cpu
[
layer_id
][
i
//
chunk_size
][
0
],
kv_cache_cpu
[
layer_id
][
i
//
chunk_size
][
1
],
)
assert
k_cpu
.
shape
[
0
]
==
v_cpu
.
shape
[
0
]
==
len
(
chunk_indices
)
k_chunk
=
k_cpu
.
to
(
self
.
k_buffer
[
0
].
device
,
non_blocking
=
True
)
...
...
@@ -724,6 +733,33 @@ class MLATokenToKVPool(KVCache):
flat_data
=
flat_data
.
to
(
device
=
self
.
device
,
non_blocking
=
False
)
self
.
kv_buffer
[
layer_id
-
self
.
start_layer
][
indices
]
=
flat_data
def
get_cpu_copy
(
self
,
indices
):
torch
.
cuda
.
synchronize
()
kv_cache_cpu
=
[]
chunk_size
=
self
.
cpu_offloading_chunk_size
for
layer_id
in
range
(
self
.
layer_num
):
kv_cache_cpu
.
append
([])
for
i
in
range
(
0
,
len
(
indices
),
chunk_size
):
chunk_indices
=
indices
[
i
:
i
+
chunk_size
]
kv_cpu
=
self
.
kv_buffer
[
layer_id
][
chunk_indices
].
to
(
"cpu"
,
non_blocking
=
True
)
kv_cache_cpu
[
-
1
].
append
(
kv_cpu
)
torch
.
cuda
.
synchronize
()
return
kv_cache_cpu
def
load_cpu_copy
(
self
,
kv_cache_cpu
,
indices
):
torch
.
cuda
.
synchronize
()
chunk_size
=
self
.
cpu_offloading_chunk_size
for
layer_id
in
range
(
self
.
layer_num
):
for
i
in
range
(
0
,
len
(
indices
),
chunk_size
):
chunk_indices
=
indices
[
i
:
i
+
chunk_size
]
kv_cpu
=
kv_cache_cpu
[
layer_id
][
i
//
chunk_size
]
assert
kv_cpu
.
shape
[
0
]
==
len
(
chunk_indices
)
kv_chunk
=
kv_cpu
.
to
(
self
.
kv_buffer
[
0
].
device
,
non_blocking
=
True
)
self
.
kv_buffer
[
layer_id
][
chunk_indices
]
=
kv_chunk
torch
.
cuda
.
synchronize
()
class
DoubleSparseTokenToKVPool
(
KVCache
):
def
__init__
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment