Commit 5d6e9467 (unverified), authored Jan 10, 2025 by Zhiqiang Xie, committed by GitHub on Jan 10, 2025

Cache controller for hierarchical caching (#2804)

Parent: a47bf391
Showing 1 changed file with 307 additions and 0 deletions.

python/sglang/srt/managers/cache_controller.py  (new file, mode 100644)  +307 -0
from __future__ import annotations

"""
Copyright 2023-2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import threading
from queue import PriorityQueue, Queue
from typing import Optional

import torch

from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost

logger = logging.getLogger(__name__)
class CacheOperation:

    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        device_indices: torch.Tensor,
        node_id: int,
        priority: Optional[int] = None,
    ):
        self.host_indices = host_indices
        self.device_indices = device_indices
        self.node_ids = [node_id]
        self.data = None

        self.id = CacheOperation.counter
        CacheOperation.counter += 1
        # default priority is the order of creation
        self.priority = priority if priority is not None else self.id

    def merge(self, other: "CacheOperation") -> None:
        # multiple operations can be merged into a single operation for batch processing
        self.host_indices = torch.cat([self.host_indices, other.host_indices])
        self.device_indices = torch.cat([self.device_indices, other.device_indices])
        self.priority = min(self.priority, other.priority)
        self.node_ids.extend(other.node_ids)

    def __lt__(self, other: "CacheOperation"):
        return self.priority < other.priority
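
# Note on the merge semantics above: merging keeps the minimum (most urgent)
# priority of the two operations and accumulates their node_ids, so a single
# batched transfer can acknowledge every request it absorbed.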
class TransferBuffer:
    """
    Overlapping buffer preparation and transfer operations to improve throughput.
    """

    def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
        self.buffers = Queue(maxsize=buffer_count)
        # todo: adjust the buffer size based on throughput profile of the system
        self.max_buffer_size = max_buffer_size

    def full(self) -> bool:
        return self.buffers.full()

    def empty(self) -> bool:
        return self.buffers.empty()

    def put(self, item, block=True) -> None:
        self.buffers.put(item, block=block)

    def get(self, block=True) -> Optional[CacheOperation]:
        try:
            return self.buffers.get(block=block)
        except Exception as e:
            logger.error(e)
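
# Design note: buffers is a bounded Queue, so an auxiliary (producer) thread
# can stage the next merged batch while the transfer (consumer) thread is
# still copying the previous one, overlapping batch preparation with
# host-device copies; buffer_count bounds the staged batches in flight.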
class HiCacheController:

    def __init__(
        self,
        mem_pool_device: MHATokenToKVPool,
        mem_pool_host: MLATokenToKVPoolHost,
        write_policy: str = "write_through_selective",
    ):
        self.mem_pool_device = mem_pool_device
        self.mem_pool_host = mem_pool_host
        self.write_policy = write_policy

        if write_policy not in [
            "write_through",
            "write_through_selective",
            "write_back",
        ]:
            raise ValueError(f"Invalid write policy: {write_policy}")

        self.write_queue = PriorityQueue()
        self.load_queue = PriorityQueue()

        self.ack_write_queue = Queue()
        self.ack_load_queue = Queue()

        self.write_buffer = TransferBuffer()
        self.load_buffer = TransferBuffer()

        self.write_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

        self.write_thread = threading.Thread(
            target=self.write_thread_func_buffer, daemon=True
        )
        self.load_thread = threading.Thread(
            target=self.load_thread_func_buffer, daemon=True
        )
        self.write_thread.start()
        self.load_thread.start()
    def write(
        self,
        device_indices: torch.Tensor,
        priority: Optional[int] = None,
        node_id: int = 0,
    ) -> Optional[torch.Tensor]:
        """
        Back up KV caches from device memory to host memory.
        """
        host_indices = self.mem_pool_host.alloc(len(device_indices))
        if host_indices is None:
            return None
        self.write_queue.put(
            CacheOperation(host_indices, device_indices, node_id, priority)
        )
        self.mem_pool_host.protect_write(host_indices)
        return host_indices
    def load(
        self,
        host_indices: torch.Tensor,
        priority: Optional[int] = None,
        node_id: int = 0,
    ) -> Optional[torch.Tensor]:
        """
        Load KV caches from host memory to device memory.
        """
        device_indices = self.mem_pool_device.alloc(len(host_indices))
        if device_indices is None:
            return None
        self.load_queue.put(
            CacheOperation(host_indices, device_indices, node_id, priority)
        )
        self.mem_pool_host.protect_load(host_indices)
        return device_indices
    def write_thread_func_direct(self):
        """
        Directly write through KV caches to host memory without buffering.
        """
        with torch.cuda.stream(self.write_stream):
            while True:
                try:
                    operation = self.write_queue.get(block=True)
                    operation.data = self.mem_pool_device.get_flat_data(
                        operation.device_indices
                    )
                    self.mem_pool_host.transfer(operation.host_indices, operation.data)
                    self.mem_pool_host.complete_io(operation.host_indices)
                    for node_id in operation.node_ids:
                        self.ack_write_queue.put(node_id)
                except Exception as e:
                    logger.error(e)
    def load_thread_func_direct(self):
        """
        Directly load KV caches from host memory to device memory without buffering.
        """
        with torch.cuda.stream(self.load_stream):
            while True:
                try:
                    operation = self.load_queue.get(block=True)
                    operation.data = self.mem_pool_host.get_flat_data(
                        operation.host_indices
                    )
                    self.mem_pool_device.transfer(
                        operation.device_indices, operation.data
                    )
                    self.mem_pool_host.complete_io(operation.host_indices)
                    for node_id in operation.node_ids:
                        self.ack_load_queue.put(node_id)
                except Exception as e:
                    logger.error(e)
    def write_aux_func(self, no_wait=False):
        """
        Auxiliary function to prepare the buffer for write operations.
        """
        buffer = None
        while True:
            try:
                operation = self.write_queue.get(block=True)
                if buffer is None:
                    buffer = operation
                else:
                    buffer.merge(operation)
                # flush the merged batch once it is large enough, or when no
                # more work is queued / the transfer thread has run dry
                if (
                    no_wait
                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
                    or self.write_queue.empty()
                    or self.write_buffer.empty()
                ):
                    assert (
                        buffer.device_indices.is_cuda
                    ), "Device indices should be on GPU"
                    buffer.data = self.mem_pool_device.get_flat_data(
                        buffer.device_indices
                    ).contiguous()
                    self.write_buffer.put(buffer, block=True)
                    buffer = None
            except Exception as e:
                logger.error(e)
    def load_aux_func(self):
        """
        Auxiliary function to prepare the buffer for load operations.
        """
        buffer = None
        while True:
            try:
                operation = self.load_queue.get(block=True)
                if buffer is None:
                    buffer = operation
                else:
                    buffer.merge(operation)
                if (
                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
                    or self.load_queue.empty()
                    or self.load_buffer.empty()
                ):
                    # pinning host memory allows faster, asynchronous
                    # host-to-device copies on the load stream
                    buffer.data = (
                        self.mem_pool_host.get_flat_data(buffer.host_indices)
                        .contiguous()
                        .pin_memory()
                    )
                    self.load_buffer.put(buffer, block=True)
                    buffer = None
            except Exception as e:
                logger.error(e)
    def write_thread_func_buffer(self):
        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
        aux_thread.start()
        with torch.cuda.stream(self.write_stream):
            while True:
                operation = self.write_buffer.get()
                if operation is None:
                    continue
                self.mem_pool_host.transfer(operation.host_indices, operation.data)
                self.mem_pool_host.complete_io(operation.host_indices)
                for node_id in operation.node_ids:
                    self.ack_write_queue.put(node_id)
    def load_thread_func_buffer(self):
        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
        aux_thread.start()
        with torch.cuda.stream(self.load_stream):
            while True:
                operation = self.load_buffer.get()
                if operation is None:
                    continue
                self.mem_pool_device.transfer(operation.device_indices, operation.data)
                self.mem_pool_host.complete_io(operation.host_indices)
                for node_id in operation.node_ids:
                    self.ack_load_queue.put(node_id)
    def evict_device(
        self, device_indices: torch.Tensor, host_indices: torch.Tensor
    ) -> int:
        if self.mem_pool_host.is_synced(host_indices):
            self.mem_pool_device.free(device_indices)
            self.mem_pool_host.update_backup(host_indices)
            return len(device_indices)
        else:
            raise ValueError(
                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
            )
    def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
        if not backup_only:
            raise ValueError("Other eviction policies are not supported yet.")

        if self.mem_pool_host.is_backup(host_indices):
            self.mem_pool_host.free(host_indices)
            return len(host_indices)
        else:
            raise ValueError(
                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
            )