change / sglang · Commits · 51caee74

Commit 51caee74 (unverified), authored Jan 07, 2025 by Zhiqiang Xie, committed by GitHub on Jan 07, 2025.

Host memory pool for hierarchical caching (#2771)
Parent: 58f9060e

Showing 2 changed files, with 230 additions and 1 deletion (+230 −1):

- python/sglang/srt/mem_cache/memory_pool.py: +206 −1
- python/sglang/srt/utils.py: +24 −0
python/sglang/srt/mem_cache/memory_pool.py (view file @ 51caee74)
```diff
@@ -22,12 +22,16 @@ BaseTokenToKVPool maps a token location to its KV cache data.
 """

 import logging
+import threading
+from enum import IntEnum
+from functools import wraps
 from typing import List, Tuple, Union

+import psutil
 import torch

 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import debug_timing, get_compiler_backend

 logger = logging.getLogger(__name__)
```
```diff
@@ -213,6 +217,26 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         del self.k_buffer
         del self.v_buffer

+    # Todo: different memory layout
+    def get_flat_data(self, indices):
+        # prepare a large chunk of contiguous data for efficient transfer
+        flatten = torch.stack(
+            [
+                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
+                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
+            ]
+        )
+        return flatten
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        for i in range(self.layer_num):
+            self.k_buffer[i][indices] = k_data[i]
+            self.v_buffer[i][indices] = v_data[i]
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
```
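`get_flat_data` stacks K and V across all layers into one tensor, so a transfer moves a single contiguous block rather than `2 * layer_num` small copies. A minimal sketch of the resulting layout, using illustrative sizes rather than anything from this commit:

```python
import torch

# Illustrative sizes only; real pools derive these from the model config.
layer_num, head_num, head_dim = 4, 8, 128
k_buffer = [torch.zeros(1024, head_num, head_dim) for _ in range(layer_num)]
v_buffer = [torch.zeros(1024, head_num, head_dim) for _ in range(layer_num)]
indices = torch.tensor([5, 17, 42])  # token locations to gather

flat = torch.stack(
    [
        torch.stack([k_buffer[i][indices] for i in range(layer_num)]),
        torch.stack([v_buffer[i][indices] for i in range(layer_num)]),
    ]
)
print(flat.shape)  # torch.Size([2, 4, 3, 8, 128])
# Axis order [K/V, layer, token, head, dim] matches the host pool's
# kv_buffer introduced later in this diff, so kv_buffer[:, :, idx]
# lines up directly with this flat layout.
```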
```diff
@@ -361,3 +385,184 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
+
+
+class MemoryStateInt(IntEnum):
+    IDLE = 0
+    RESERVED = 1
+    PROTECTED = 2
+    SYNCED = 3
+    BACKUP = 4
+
+
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class MLATokenToKVPoolHost:
+
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float = 2.0,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        assert (
+            host_to_device_ratio >= 1
+        ), "The host memory should be larger than the device memory with the current protocol"
+        # todo, other ways of configuring the size
+
+        self.device_pool = device_pool
+        self.host_to_device_ratio = host_to_device_ratio
+        self.pin_memory = pin_memory
+        self.device = device
+
+        self.size = int(device_pool.size * host_to_device_ratio)
+        self.dtype = device_pool.store_dtype
+        self.head_num = device_pool.head_num
+        self.head_dim = device_pool.head_dim
+        self.layer_num = device_pool.layer_num
+        self.size_per_token = (
+            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+        )
+
+        # Verify there is enough available host memory.
+        host_mem = psutil.virtual_memory()
+        requested_bytes = self.size * self.size_per_token
+        # preserve at least 10GB for other usage
+        ten_gb = 10 * (1024**3)
+        if requested_bytes > host_mem.available - ten_gb:
+            raise ValueError(
+                f"Not enough host memory available. Requesting "
+                f"{requested_bytes / 1e9:.2f} GB but only have "
+                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                f"size of the hierarchical cache."
+            )
+        else:
+            logger.info(
+                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+            )
+
+        self.kv_buffer = torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+        self.can_use_mem_size = self.size
+
+        # A lock for synchronized operations on memory allocation and state transitions.
+        self.lock = threading.RLock()
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    @synchronized
+    def clear(self):
+        self.mem_state.fill_(0)
+        self.can_use_mem_size = self.size
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+
+    @synchronized
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized
+    def alloc(self, need_size: int) -> torch.Tensor:
+        if need_size > self.can_use_mem_size:
+            return None
+
+        # todo: de-fragementation
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        self.mem_state[select_index] = MemoryStateInt.RESERVED
+        self.can_use_mem_size -= need_size
+
+        return select_index
+
+    @synchronized
+    def is_reserved(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.RESERVED
+
+    @synchronized
+    def is_protected(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+    @synchronized
+    def is_synced(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.SYNCED
+
+    @synchronized
+    def is_backup(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_backup(self, indices: torch.Tensor):
+        assert self.is_synced(indices), (
+            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_synced(self, indices: torch.Tensor):
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized
+    def protect_write(self, indices: torch.Tensor):
+        assert self.is_reserved(indices), (
+            f"The host memory slots should be RESERVED before write operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def protect_load(self, indices: torch.Tensor):
+        assert self.is_backup(indices), (
+            f"The host memory slots should be in BACKUP state before load operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def complete_io(self, indices: torch.Tensor):
+        assert self.is_protected(indices), (
+            f"The host memory slots should be PROTECTED during I/O operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    @synchronized
+    def free(self, indices: torch.Tensor) -> int:
+        self.mem_state[indices] = MemoryStateInt.IDLE
+        self.free_slots = torch.concat([self.free_slots, indices])
+        self.can_use_mem_size += len(indices)
+        return len(indices)
```
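`mem_state` implements a small per-slot state machine: `alloc` takes slots from IDLE to RESERVED, `protect_write` and `protect_load` pin them as PROTECTED around I/O, `complete_io` marks them SYNCED, `update_backup` finalizes a backup, and `free` returns slots to IDLE. For sizing, each token costs `head_dim * head_num * layer_num * itemsize * 2` bytes on the host (the factor of 2 covers K plus V); for example, 32 layers with 8 heads of dimension 128 in fp16 comes to 131072 bytes, or 128 KiB per token. Below is a hedged sketch of one backup/load cycle driven by these primitives; the driver functions are hypothetical, and only the pool methods come from this diff:

```python
# Hypothetical driver code: `device_pool` is an MHATokenToKVPool and
# `host_pool` an MLATokenToKVPoolHost as defined above.

def backup_to_host(device_pool, host_pool, device_indices):
    host_indices = host_pool.alloc(len(device_indices))
    if host_indices is None:  # host pool exhausted; caller must evict or skip
        return None
    host_pool.protect_write(host_indices)   # RESERVED  -> PROTECTED
    host_pool.transfer(host_indices, device_pool.get_flat_data(device_indices))
    host_pool.complete_io(host_indices)     # PROTECTED -> SYNCED
    host_pool.update_backup(host_indices)   # SYNCED    -> BACKUP
    return host_indices

def load_from_host(device_pool, host_pool, host_indices, device_indices):
    host_pool.protect_load(host_indices)    # BACKUP    -> PROTECTED
    device_pool.transfer(device_indices, host_pool.get_flat_data(host_indices))
    host_pool.complete_io(host_indices)     # PROTECTED -> SYNCED
    host_pool.free(host_indices)            # back to IDLE once no longer needed
```

Because every state transition runs under the pool's `RLock` via `@synchronized`, a driver like this can run on a background transfer thread without extra locking around individual transitions.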
python/sglang/srt/utils.py (view file @ 51caee74)
```diff
@@ -1349,3 +1349,27 @@ class MultiprocessingSerializer:
     @staticmethod
     def deserialize(data):
         return ForkingPickler.loads(data)
+
+
+def debug_timing(func):
+    # todo: replace with a more organized instrumentation
+    def wrapper(*args, **kwargs):
+        if logger.isEnabledFor(logging.DEBUG):
+            tic = torch.cuda.Event(enable_timing=True)
+            toc = torch.cuda.Event(enable_timing=True)
+            tic.record()
+            result = func(*args, **kwargs)
+            toc.record()
+            torch.cuda.synchronize()  # Ensure all CUDA operations are complete
+            elapsed = tic.elapsed_time(toc)
+            indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
+            num_tokens = len(indices) if indices is not None else 0
+            throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
+            logger.debug(
+                f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
+            )
+            return result
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper
```
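`debug_timing` uses CUDA events so the measured interval covers the asynchronous copy itself, and it only pays that cost when DEBUG logging is enabled. It recovers the token count from the `indices` argument, either by keyword or as `args[1]`, which for a decorated method is the first argument after `self`; both `transfer` methods above follow that convention. A small usage sketch, assuming a CUDA-capable machine and DEBUG logging; `ToyPool` is illustrative only:

```python
import logging
import torch
from sglang.srt.utils import debug_timing

logging.basicConfig(level=logging.DEBUG)

class ToyPool:
    def __init__(self):
        self.buf = torch.zeros(1024, 8, 128, device="cuda")

    @debug_timing
    def transfer(self, indices, flat_data):
        # indices is args[1] here, matching what the wrapper inspects
        self.buf[indices] = flat_data.to(device="cuda", non_blocking=False)

pool = ToyPool()
pool.transfer(torch.arange(16), torch.zeros(16, 8, 128))
# Logs something like: "Transfer time: 0.1 ms, throughput: 160000.0 tokens/s"
```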