Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
51caee74
"test/vscode:/vscode.git/clone" did not exist on "dc0705a504fc423cbf38376eb864c898578f7c9a"
Unverified
Commit
51caee74
authored
Jan 07, 2025
by
Zhiqiang Xie
Committed by
GitHub
Jan 07, 2025
Browse files
Host memory pool for hierarchical caching (#2771)
parent
58f9060e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
230 additions
and
1 deletion
+230
-1
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+206
-1
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+24
-0
No files found.
python/sglang/srt/mem_cache/memory_pool.py
View file @
51caee74
...
@@ -22,12 +22,16 @@ BaseTokenToKVPool maps a token location to its KV cache data.
...
@@ -22,12 +22,16 @@ BaseTokenToKVPool maps a token location to its KV cache data.
"""
"""
import
logging
import
logging
import
threading
from
enum
import
IntEnum
from
functools
import
wraps
from
typing
import
List
,
Tuple
,
Union
from
typing
import
List
,
Tuple
,
Union
import
psutil
import
torch
import
torch
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.utils
import
get_compiler_backend
from
sglang.srt.utils
import
debug_timing
,
get_compiler_backend
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -213,6 +217,26 @@ class MHATokenToKVPool(BaseTokenToKVPool):
...
@@ -213,6 +217,26 @@ class MHATokenToKVPool(BaseTokenToKVPool):
del
self
.
k_buffer
del
self
.
k_buffer
del
self
.
v_buffer
del
self
.
v_buffer
# Todo: different memory layout
def
get_flat_data
(
self
,
indices
):
# prepare a large chunk of contiguous data for efficient transfer
flatten
=
torch
.
stack
(
[
torch
.
stack
([
self
.
k_buffer
[
i
][
indices
]
for
i
in
range
(
self
.
layer_num
)]),
torch
.
stack
([
self
.
v_buffer
[
i
][
indices
]
for
i
in
range
(
self
.
layer_num
)]),
]
)
return
flatten
@
debug_timing
def
transfer
(
self
,
indices
,
flat_data
):
# transfer prepared data from host to device
flat_data
=
flat_data
.
to
(
device
=
self
.
device
,
non_blocking
=
False
)
k_data
,
v_data
=
flat_data
[
0
],
flat_data
[
1
]
for
i
in
range
(
self
.
layer_num
):
self
.
k_buffer
[
i
][
indices
]
=
k_data
[
i
]
self
.
v_buffer
[
i
][
indices
]
=
v_data
[
i
]
def
get_key_buffer
(
self
,
layer_id
:
int
):
def
get_key_buffer
(
self
,
layer_id
:
int
):
if
self
.
store_dtype
!=
self
.
dtype
:
if
self
.
store_dtype
!=
self
.
dtype
:
return
self
.
k_buffer
[
layer_id
].
view
(
self
.
dtype
)
return
self
.
k_buffer
[
layer_id
].
view
(
self
.
dtype
)
...
@@ -361,3 +385,184 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
...
@@ -361,3 +385,184 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
self
.
k_buffer
[
layer_id
][
loc
]
=
cache_k
self
.
k_buffer
[
layer_id
][
loc
]
=
cache_k
self
.
v_buffer
[
layer_id
][
loc
]
=
cache_v
self
.
v_buffer
[
layer_id
][
loc
]
=
cache_v
self
.
label_buffer
[
layer_id
][
loc
]
=
cache_label
self
.
label_buffer
[
layer_id
][
loc
]
=
cache_label
class MemoryStateInt(IntEnum):
    """Lifecycle state of a host memory slot in MLATokenToKVPoolHost.

    Transitions (driven by the pool's synchronized methods):
    IDLE -> RESERVED (alloc), RESERVED -> PROTECTED (protect_write),
    BACKUP -> PROTECTED (protect_load), PROTECTED -> SYNCED (complete_io),
    SYNCED -> BACKUP (update_backup), any -> IDLE (free).
    Stored as uint8 values in the pool's ``mem_state`` tensor, so the
    explicit integer values matter.
    """

    IDLE = 0  # slot unallocated; available via alloc()
    RESERVED = 1  # slot allocated, no data written yet
    PROTECTED = 2  # an I/O operation on this slot is in flight
    SYNCED = 3  # host copy written and consistent after I/O completion
    BACKUP = 4  # host copy promoted to backup (update_backup); loadable via protect_load
def synchronized(func):
    """Decorator that serializes calls to ``func`` under ``self.lock``.

    The decorated callable must be a method of an object exposing a ``lock``
    attribute (a context-manager style lock). A reentrant lock allows
    synchronized methods to call one another without deadlocking.
    """

    @wraps(func)
    def locked_call(self, *args, **kwargs):
        # hold the instance lock for the full duration of the call
        self.lock.acquire()
        try:
            return func(self, *args, **kwargs)
        finally:
            self.lock.release()

    return locked_call
class MLATokenToKVPoolHost:
    """Host (CPU) memory pool mirroring a device ``MHATokenToKVPool``.

    Allocates one large host tensor sized ``host_to_device_ratio`` times the
    device pool and tracks a per-slot state machine (``MemoryStateInt``) so
    device KV data can be backed up to host memory and loaded back later.
    State-mutating methods are serialized via ``@synchronized`` on an RLock.

    NOTE(review): despite the "MLA" prefix, the constructor takes an
    ``MHATokenToKVPool`` and copies its MHA layout (head_num x head_dim);
    confirm whether the class name is intentional.
    """

    def __init__(
        self,
        device_pool: MHATokenToKVPool,
        host_to_device_ratio: float = 2.0,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
        assert (
            host_to_device_ratio >= 1
        ), "The host memory should be larger than the device memory with the current protocol"
        # todo, other ways of configuring the size
        self.device_pool = device_pool
        self.host_to_device_ratio = host_to_device_ratio
        self.pin_memory = pin_memory
        self.device = device

        # Host capacity in tokens; geometry is copied from the device pool.
        self.size = int(device_pool.size * host_to_device_ratio)
        self.dtype = device_pool.store_dtype
        self.head_num = device_pool.head_num
        self.head_dim = device_pool.head_dim
        self.layer_num = device_pool.layer_num
        # Bytes per token: one K and one V entry (factor 2) across all layers.
        self.size_per_token = (
            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
        )

        # Verify there is enough available host memory.
        host_mem = psutil.virtual_memory()
        requested_bytes = self.size * self.size_per_token
        # preserve at least 10GB for other usage
        ten_gb = 10 * (1024**3)
        if requested_bytes > host_mem.available - ten_gb:
            raise ValueError(
                f"Not enough host memory available. Requesting "
                f"{requested_bytes / 1e9:.2f} GB but only have "
                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
                f"size of the hierarchical cache."
            )
        else:
            logger.info(
                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
            )

        # Layout matches get_flat_data/transfer: dim 0 is K (0) / V (1).
        self.kv_buffer = torch.empty(
            (2, self.layer_num, self.size, self.head_num, self.head_dim),
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

        # Initialize memory states and tracking structures.
        # mem_state holds one MemoryStateInt value (as uint8) per slot;
        # zeros == MemoryStateInt.IDLE.
        self.mem_state = torch.zeros(
            (self.size,), dtype=torch.uint8, device=self.device
        )
        self.free_slots = torch.arange(self.size, dtype=torch.int32)
        self.can_use_mem_size = self.size

        # A lock for synchronized operations on memory allocation and state transitions.
        # RLock so @synchronized methods may call each other (e.g. update_backup -> is_synced).
        self.lock = threading.RLock()

    def get_flat_data(self, indices):
        """Return the host KV entries at ``indices`` across both K/V and all layers."""
        return self.kv_buffer[:, :, indices]

    @debug_timing
    def transfer(self, indices, flat_data):
        # backup prepared data from device to host
        self.kv_buffer[:, :, indices] = flat_data.to(
            device=self.device, non_blocking=False
        )

    @synchronized
    def clear(self):
        """Reset every slot to IDLE and make the whole pool allocatable again."""
        self.mem_state.fill_(0)
        self.can_use_mem_size = self.size
        self.free_slots = torch.arange(self.size, dtype=torch.int32)

    @synchronized
    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
        """Return the common state of the slots at ``indices``.

        Asserts that ``indices`` is non-empty and that all referenced slots
        share the same state.
        """
        assert len(indices) > 0, "The indices should not be empty"
        states = self.mem_state[indices]
        assert (
            states == states[0]
        ).all(), "The memory slots should have the same state {}".format(states)
        return MemoryStateInt(states[0].item())

    @synchronized
    def alloc(self, need_size: int) -> torch.Tensor:
        """Reserve ``need_size`` slots; returns their indices, or None if full."""
        if need_size > self.can_use_mem_size:
            return None

        # todo: de-fragementation
        # Free slots are handed out from the front of the free list.
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        self.mem_state[select_index] = MemoryStateInt.RESERVED
        self.can_use_mem_size -= need_size

        return select_index

    @synchronized
    def is_reserved(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.RESERVED

    @synchronized
    def is_protected(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.PROTECTED

    @synchronized
    def is_synced(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.SYNCED

    @synchronized
    def is_backup(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.BACKUP

    @synchronized
    def update_backup(self, indices: torch.Tensor):
        """Transition SYNCED slots to BACKUP (asserts the precondition)."""
        assert self.is_synced(indices), (
            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.BACKUP

    @synchronized
    def update_synced(self, indices: torch.Tensor):
        # Unconditional transition: no precondition is asserted here.
        self.mem_state[indices] = MemoryStateInt.SYNCED

    @synchronized
    def protect_write(self, indices: torch.Tensor):
        """Mark RESERVED slots PROTECTED before a host write begins."""
        assert self.is_reserved(indices), (
            f"The host memory slots should be RESERVED before write operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized
    def protect_load(self, indices: torch.Tensor):
        """Mark BACKUP slots PROTECTED before loading them back to device."""
        assert self.is_backup(indices), (
            f"The host memory slots should be in BACKUP state before load operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized
    def complete_io(self, indices: torch.Tensor):
        """Finish an I/O: transition PROTECTED slots to SYNCED."""
        assert self.is_protected(indices), (
            f"The host memory slots should be PROTECTED during I/O operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.SYNCED

    def available_size(self):
        # NOTE(review): reads free_slots without @synchronized — value may be
        # stale under concurrent alloc/free; confirm this is acceptable.
        return len(self.free_slots)

    @synchronized
    def free(self, indices: torch.Tensor) -> int:
        """Return slots to the pool (state -> IDLE); returns how many were freed."""
        self.mem_state[indices] = MemoryStateInt.IDLE
        self.free_slots = torch.concat([self.free_slots, indices])
        self.can_use_mem_size += len(indices)
        return len(indices)
python/sglang/srt/utils.py
View file @
51caee74
...
@@ -1349,3 +1349,27 @@ class MultiprocessingSerializer:
...
@@ -1349,3 +1349,27 @@ class MultiprocessingSerializer:
@
staticmethod
@
staticmethod
def
deserialize
(
data
):
def
deserialize
(
data
):
return
ForkingPickler
.
loads
(
data
)
return
ForkingPickler
.
loads
(
data
)
def debug_timing(func):
    """Decorator that logs elapsed time and token throughput of a transfer.

    Timing only happens when the module logger is enabled for DEBUG, so the
    wrapped function pays no instrumentation cost in normal operation. The
    token count is taken from the ``indices`` argument (keyword, or the second
    positional), matching the ``transfer(self, indices, flat_data)`` methods
    this decorates.
    """
    # todo: replace with a more organized instrumentation
    # Local imports keep this edit self-contained w.r.t. the module header.
    import time
    from functools import wraps

    @wraps(func)  # preserve the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        if logger.isEnabledFor(logging.DEBUG):
            if torch.cuda.is_available():
                # CUDA events measure device-side time of the transfer.
                tic = torch.cuda.Event(enable_timing=True)
                toc = torch.cuda.Event(enable_timing=True)
                tic.record()
                result = func(*args, **kwargs)
                toc.record()
                torch.cuda.synchronize()  # Ensure all CUDA operations are complete
                elapsed = tic.elapsed_time(toc)  # milliseconds
            else:
                # Fix: original unconditionally created CUDA events, which
                # raises on CPU-only builds. Fall back to wall-clock time.
                start = time.perf_counter()
                result = func(*args, **kwargs)
                elapsed = (time.perf_counter() - start) * 1000.0  # milliseconds
            indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
            num_tokens = len(indices) if indices is not None else 0
            throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
            # Lazy %-style args: formatting is skipped if DEBUG gets disabled.
            logger.debug(
                "Transfer time: %s ms, throughput: %s tokens/s", elapsed, throughput
            )
            return result
        else:
            return func(*args, **kwargs)

    return wrapper
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment