sglang · Commits · 273b2834

Commit 273b2834 (unverified)
Authored Sep 06, 2025 by Xinyuan Tong; committed by GitHub Sep 05, 2025

[Minor] Refactors KV memory pool (#9842)

Parent: f84db115
Changes: 2 changed files with 60 additions and 62 deletions (+60 / -62)

    python/sglang/srt/mem_cache/memory_pool.py   (+35 / -43)
    test/srt/test_swa_unittest.py                (+25 / -19)
python/sglang/srt/mem_cache/memory_pool.py
@@ -130,6 +130,29 @@ class KVCache(abc.ABC):
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
+        # default state for optional layer-wise transfer control
+        self.layer_transfer_counter = None
+
+    def _finalize_allocation_log(self, num_tokens: int):
+        """Common logging and mem_usage computation for KV cache allocation.
+        Supports both tuple (K, V) size returns and single KV size returns.
+        """
+        kv_size_bytes = self.get_kv_size_bytes()
+        if isinstance(kv_size_bytes, tuple):
+            k_size, v_size = kv_size_bytes
+            k_size_GB = k_size / GB
+            v_size_GB = v_size / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+            )
+            self.mem_usage = k_size_GB + v_size_GB
+        else:
+            kv_size_GB = kv_size_bytes / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+            )
+            self.mem_usage = kv_size_GB
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
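Note (illustration, not part of the diff): the new _finalize_allocation_log helper centralizes the allocation logging previously duplicated across pool classes, dispatching on whether get_kv_size_bytes() returns a (K, V) tuple (MHA-style pools) or a single byte count (MLA-style pools). Below is a minimal, self-contained sketch of that dispatch using a toy class and an assumed GB = 1024**3 constant mirroring the module-level names in memory_pool.py; it is not the repository code.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
GB = 1024**3  # assumed to match the module-level constant used by memory_pool.py


class ToyPool:
    """Toy stand-in for a KVCache subclass, only to exercise the helper."""

    def __init__(self, kv_size_bytes):
        self._kv_size_bytes = kv_size_bytes
        self.mem_usage = 0.0

    def get_kv_size_bytes(self):
        # MHA-style pools return a (K, V) tuple; MLA-style pools return one number.
        return self._kv_size_bytes

    def _finalize_allocation_log(self, num_tokens: int):
        kv_size_bytes = self.get_kv_size_bytes()
        if isinstance(kv_size_bytes, tuple):
            k_size, v_size = kv_size_bytes
            logger.info(
                f"KV Cache is allocated. #tokens: {num_tokens}, "
                f"K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
            )
            self.mem_usage = (k_size + v_size) / GB
        else:
            logger.info(
                f"KV Cache is allocated. #tokens: {num_tokens}, "
                f"KV size: {kv_size_bytes / GB:.2f} GB"
            )
            self.mem_usage = kv_size_bytes / GB


# Both return shapes are handled by the same call site:
ToyPool((2 * GB, 2 * GB))._finalize_allocation_log(num_tokens=65536)
ToyPool(3 * GB)._finalize_allocation_log(num_tokens=65536)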
@@ -205,15 +228,9 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()

-        self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None

-        k_size, v_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
-        )
-        self.mem_usage = (k_size + v_size) / GB
+        self._finalize_allocation_log(size)

     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -427,43 +444,30 @@ class SWAKVPool(KVCache):
         self,
         size: int,
         size_swa: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
         swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
-        device: str,
+        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
+        **kwargs,
     ):
         self.size = size
         self.size_swa = size_swa
-        self.dtype = dtype
-        self.device = device
         self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
-        self.page_size = 1
+        kwargs["page_size"] = 1
+        kwargs["enable_memory_saver"] = False
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-        TokenToKVPoolClass = MHATokenToKVPool
-        self.swa_kv_pool = TokenToKVPoolClass(
+        self.swa_kv_pool = token_to_kv_pool_class(
             size=size_swa,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
             layer_num=self.swa_layer_nums,
-            device=device,
-            enable_memory_saver=False,
+            **kwargs,
         )
-        self.full_kv_pool = TokenToKVPoolClass(
+        self.full_kv_pool = token_to_kv_pool_class(
             size=size,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
             layer_num=self.full_layer_nums,
-            device=device,
-            enable_memory_saver=False,
+            **kwargs,
        )
         self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
         for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
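Note (illustration, not part of the diff): with this change the concrete pool class is injected through token_to_kv_pool_class, and the per-pool construction arguments (page_size, dtype, head_num, head_dim, device, ...) travel through **kwargs instead of being re-declared on SWAKVPool. The sketch below shows the same injection-plus-forwarding pattern with toy classes; the names and argument values are hypothetical and the real pool constructors take more parameters.

from typing import Callable, List


class ToyFullPool:
    """Toy stand-in for a per-layer KV pool such as MHATokenToKVPool."""

    def __init__(self, size: int, layer_num: int, **kwargs):
        self.size = size
        self.layer_num = layer_num
        self.extra = kwargs  # whatever the wrapper forwarded (dtype, device, ...)


class ToySWAPool:
    """Sketch of SWAKVPool's new shape: pool class injected, kwargs forwarded."""

    def __init__(
        self,
        size: int,
        size_swa: int,
        swa_attention_layer_ids: List[int],
        full_attention_layer_ids: List[int],
        token_to_kv_pool_class: Callable = ToyFullPool,
        **kwargs,
    ):
        kwargs["page_size"] = 1  # forced by the wrapper, as in the diff
        self.swa_kv_pool = token_to_kv_pool_class(
            size=size_swa, layer_num=len(swa_attention_layer_ids), **kwargs
        )
        self.full_kv_pool = token_to_kv_pool_class(
            size=size, layer_num=len(full_attention_layer_ids), **kwargs
        )


pool = ToySWAPool(
    size=1024,
    size_swa=256,
    swa_attention_layer_ids=[0, 1, 2],
    full_attention_layer_ids=[3],
    dtype="float16",  # forwarded untouched to the injected pool class
    device="cpu",
)
print(pool.full_kv_pool.extra)  # {'dtype': 'float16', 'device': 'cpu', 'page_size': 1}

Presumably this is what lets SWAKVPool wrap other KVCache subclasses later without growing its own parameter list.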
@@ -768,13 +772,7 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self.layer_transfer_counter = None

-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)

     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -936,13 +934,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             device=self.device,
         )
-        self.layer_transfer_counter = None

-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)

     def get_kv_size_bytes(self):
         assert hasattr(self, "k_buffer")
test/srt/test_swa_unittest.py
@@ -31,16 +31,18 @@ class TestSWA(unittest.TestCase):
             i for i in range(num_layers) if i not in full_attention_layer_ids_set
         ]
         pool = SWAKVPool(
-            size,
-            size_swa,
-            dtype,
-            num_head,
-            head_dim,
-            swa_attention_layer_ids,
-            full_attention_layer_ids,
-            device,
+            size=size,
+            size_swa=size_swa,
+            dtype=dtype,
+            num_head=num_head,
+            head_dim=head_dim,
+            swa_attention_layer_ids=swa_attention_layer_ids,
+            full_attention_layer_ids=full_attention_layer_ids,
+            device=device,
         )
-        alloc = SWATokenToKVPoolAllocator(size, size_swa, dtype, device, pool)
+        alloc = SWATokenToKVPoolAllocator(
+            size=size, size_swa=size_swa, dtype=dtype, device=device, kvcache=pool
+        )
         assert alloc.available_size() == size + size_swa
         index = alloc.alloc(1)
         assert alloc.available_size() == size_swa + size_swa - 2
@@ -75,18 +77,22 @@ class TestSWA(unittest.TestCase):
         )
         # setup kv pool
         kv_pool = SWAKVPool(
-            kv_size,
-            kv_size_swa,
-            dtype,
-            num_head,
-            head_dim,
-            swa_attention_layer_ids,
-            full_attention_layer_ids,
-            device,
+            size=kv_size,
+            size_swa=kv_size_swa,
+            dtype=dtype,
+            num_head=num_head,
+            head_dim=head_dim,
+            swa_attention_layer_ids=swa_attention_layer_ids,
+            full_attention_layer_ids=full_attention_layer_ids,
+            device=device,
         )
         # setup token to kv pool allocator
         allocator = SWATokenToKVPoolAllocator(
-            kv_size,
-            kv_size_swa,
-            dtype,
-            device,
-            kv_pool
+            size=kv_size,
+            size_swa=kv_size_swa,
+            dtype=dtype,
+            device=device,
+            kvcache=kv_pool,
         )
         # setup radix cache
         tree = SWARadixCache(