sglang · Commits · 15521495

refactor: Extract repeated member variables in KVCache subclasses to base class. (#6323)

Unverified commit 15521495, authored May 19, 2025 by wangxiyu191, committed by GitHub May 18, 2025.
Parent: ebe58d54
1 changed file with 63 additions and 47 deletions:

python/sglang/srt/mem_cache/memory_pool.py  +63 −47
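The change is a classic pull-up refactor: each KVCache subclass repeated the same bookkeeping in its __init__, and the commit moves those assignments into KVCache.__init__ so subclasses delegate via super().__init__(...). A minimal standalone sketch of the pattern, with illustrative names rather than the sglang code:

    import abc

    class Pool(abc.ABC):
        # Shared bookkeeping lives in the base class once, instead of
        # being copy-pasted into every subclass __init__.
        @abc.abstractmethod
        def __init__(self, size: int, device: str):
            self.size = size
            self.device = device

    class MHAPool(Pool):
        def __init__(self, size: int, device: str, head_num: int):
            super().__init__(size, device)   # delegate the shared fields
            self.head_num = head_num         # keep only subclass-specific state

    pool = MHAPool(size=8, device="cpu", head_num=4)
    print(pool.size, pool.device, pool.head_num)  # 8 cpu 4

Note that marking the base __init__ with @abc.abstractmethod, as the diff below does, still allows super().__init__(...) calls from subclasses; it only prevents instantiating the base class itself.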
python/sglang/srt/mem_cache/memory_pool.py
@@ -94,6 +94,33 @@ class ReqToTokenPool:

 class KVCache(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        layer_num: int,
+        device: str,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
+        self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
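The store_dtype branch now lives in one place. As the NOTE says, Tensor.index_put is not implemented for fp8, so the pool stores raw bytes instead. A small hypothetical demo of that workaround, not code from this file, assuming a PyTorch build with float8 dtypes:

    import torch

    fp8 = torch.float8_e5m2                       # dtype being worked around
    store = torch.zeros(8, 4, dtype=torch.uint8)  # store_dtype buffer, 1 byte/elem

    values = torch.randn(2, 4).to(fp8)            # incoming fp8 cache entries
    loc = torch.tensor([1, 5])                    # token slots to fill

    # index_put is not implemented for fp8, so write through a uint8 view ...
    store[loc] = values.view(torch.uint8)
    # ... and reinterpret as fp8 when the buffer is handed back to compute.
    print(store.view(fp8)[loc].dtype)             # torch.float8_e5m2

The reinterpretation is free: Tensor.view(dtype) only requires matching element sizes, and both fp8 formats are one byte wide like uint8.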

@@ -217,25 +244,20 @@ class MHATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
-        )
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+        )
         self.head_num = head_num
         self.head_dim = head_dim
-        self.layer_num = layer_num
         self._create_buffers()
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1

         self.layer_transfer_counter = None
         self.capture_mode = False

@@ -493,26 +515,21 @@ class MLATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+        )
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
-        self.layer_num = layer_num
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1

-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
-        )
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
                 torch.zeros(
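A side effect of the pull-up: MLATokenToKVPool previously built a local memory_saver_adapter and discarded it, and now reuses the instance the base class stores on self. The adapter's region() is a context manager scoping the large buffer allocations. A no-op stand-in sketching that shape, hypothetical rather than sglang's actual TorchMemorySaverAdapter:

    from contextlib import contextmanager

    class NoopMemorySaverAdapter:
        """Stand-in: the real adapter presumably wraps a memory saver when enabled."""

        @staticmethod
        def create(enable: bool) -> "NoopMemorySaverAdapter":
            # The real create() would return a functional adapter when
            # enable=True; this sketch always degrades to a no-op.
            return NoopMemorySaverAdapter()

        @contextmanager
        def region(self):
            # The real version would tag allocations made in this scope so
            # the backing GPU memory can later be released and restored.
            yield

    adapter = NoopMemorySaverAdapter.create(enable=False)
    with adapter.region():
        buffers = [bytearray(1024) for _ in range(4)]  # allocations to track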

@@ -636,20 +653,18 @@ class DoubleSparseTokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
-        )
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+        )

-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(

@@ -672,9 +687,6 @@ class DoubleSparseTokenToKVPool(KVCache):
                 for _ in range(layer_num)
             ]
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1

     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id - self.start_layer]
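The context lines show why start_layer and end_layer were worth centralizing: per-layer buffer lists are indexed by layer_id - self.start_layer, so a pool can hold only a slice of the model's layers (e.g. under pipeline-parallel setups) while callers keep passing global layer ids. A toy illustration with made-up values:

    # This rank's pool holds global layers 8..15 in an 8-element local list.
    start_layer, layer_num = 8, 16
    end_layer = layer_num - 1
    k_buffer = [f"k-buffer for layer {i}" for i in range(start_layer, end_layer + 1)]

    def get_key_buffer(layer_id: int) -> str:
        # Global layer id -> local list index, as in the diff above.
        return k_buffer[layer_id - start_layer]

    assert get_key_buffer(10) == "k-buffer for layer 10"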

@@ -742,7 +754,7 @@ class HostKVCache(abc.ABC):
     def __init__(
         self,
-        device_pool: MHATokenToKVPool,
+        device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
         pin_memory: bool,

@@ -914,6 +926,8 @@ class HostKVCache(abc.ABC):
 class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
     def __init__(
         self,
         device_pool: MHATokenToKVPool,

@@ -997,6 +1011,8 @@ class MHATokenToKVPoolHost(HostKVCache):
 class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
     def __init__(
         self,
         device_pool: MLATokenToKVPool,
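The last three hunks are one typing change: HostKVCache.__init__ now accepts any KVCache, and each concrete host pool narrows the attribute back with a class-level annotation so type checkers still see the specific device pool it wraps. A condensed sketch of that pattern, with class bodies elided:

    class KVCache: ...
    class MHATokenToKVPool(KVCache): ...

    class HostKVCache:
        def __init__(self, device_pool: KVCache):
            self.device_pool = device_pool  # widened: any KVCache is accepted

    class MHATokenToKVPoolHost(HostKVCache):
        # Class-level annotation narrows the inherited attribute, so
        # self.device_pool is typed as MHATokenToKVPool inside this class.
        device_pool: MHATokenToKVPool

        def __init__(self, device_pool: MHATokenToKVPool):
            super().__init__(device_pool)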