sglang / Commits / d2cb3024

Unverified commit d2cb3024, authored May 10, 2025 by huangtingwei; committed by GitHub May 09, 2025

fix bug that gpu0 occupies more memory when hicache is turned on (#5778)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>

Parent: 1940cdec
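Every function touched by this commit changes in the same way: a long-running worker thread that used to wrap its loop in `with torch.cuda.stream(...)` now calls `torch.cuda.set_stream(...)` once when the thread starts. The commit message states the symptom but not the mechanism; a plausible explanation, based on how PyTorch's stream context manager behaves, is that these loops run in freshly spawned threads, where the current CUDA device defaults to `cuda:0`. Entering the context manager first queries the current stream on that default device, which can initialize a CUDA context on GPU 0 for every tensor-parallel rank, while `torch.cuda.set_stream` switches straight to the stream's own device. A minimal sketch of the before/after pattern (the worker bodies are placeholders; only the torch calls are real):

# Illustrative sketch of the pattern this commit replaces.
import threading
import torch

def worker_before(stream: torch.cuda.Stream) -> None:
    # Entering the context manager in a fresh thread first queries the
    # current stream on the thread's default device (cuda:0), which can
    # create an unwanted CUDA context on GPU 0 even if `stream` lives
    # on another GPU.
    with torch.cuda.stream(stream):
        pass  # long-running consumer loop would go here

def worker_after(stream: torch.cuda.Stream) -> None:
    # The patched code switches to the stream (and its device) directly,
    # without touching GPU 0 first.
    torch.cuda.set_stream(stream)
    # long-running consumer loop would go here

if torch.cuda.device_count() >= 2:
    s = torch.cuda.Stream(device="cuda:1")
    t = threading.Thread(target=worker_after, args=(s,), daemon=True)
    t.start()
    t.join()
    # Run worker_before instead and compare GPU 0 usage in nvidia-smi.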
Showing 1 changed file with 115 additions and 119 deletions.

python/sglang/srt/managers/cache_controller.py (+115, -119) · View file @ d2cb3024
@@ -268,98 +268,97 @@ class HiCacheController:
         """
         Directly write through KV caches to host memory without buffering.
         """
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    self.mem_pool_host.write_page_all_layers(
-                        operation.host_indices,
-                        operation.device_indices,
-                        self.mem_pool_device,
-                    )
-                    self.write_stream.synchronize()
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_write_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.write_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_write_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
     def load_thread_func_direct(self):
         """
         Directly load KV caches from host memory to device memory without buffering.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.load_queue.get(block=True, timeout=1)
-                    # time.sleep(18e-6 * len(operation.host_indices))
-                    operation.data = self.mem_pool_host.get_flat_data(operation.host_indices)
-                    self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_load_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.load_queue.get(block=True, timeout=1)
+                # time.sleep(18e-6 * len(operation.host_indices))
+                operation.data = self.mem_pool_host.get_flat_data(operation.host_indices)
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
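The direct load path above is a plain gather-then-copy: `get_flat_data` pulls the requested entries out of the host pool, and `transfer` scatters them into the device pool on the dedicated load stream. A stand-in sketch with plain tensors (the pool classes and their methods are sglang internals, so the tensors and shapes below are made up):

# Stand-in for the get_flat_data -> transfer sequence, using plain tensors
# in place of sglang's host/device memory pools (illustrative only).
import torch

if torch.cuda.is_available():
    load_stream = torch.cuda.Stream()
    torch.cuda.set_stream(load_stream)                    # as in the patched code

    host_pool = torch.zeros(1024, 64, pin_memory=True)    # host KV pool stand-in
    device_pool = torch.zeros(1024, 64, device="cuda")    # device KV pool stand-in

    host_indices = torch.tensor([3, 7, 9])
    device_indices = torch.tensor([0, 1, 2], device="cuda")

    data = host_pool[host_indices]                        # "get_flat_data"
    device_pool[device_indices] = data.to("cuda", non_blocking=True)  # "transfer"
    load_stream.synchronize()                             # wait for the copy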
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                self.load_cache_event.wait(timeout=1)
-                if not self.load_cache_event.is_set():
-                    continue
-                self.load_cache_event.clear()
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            self.load_cache_event.wait(timeout=1)
+            if not self.load_cache_event.is_set():
+                continue
+            self.load_cache_event.clear()

-                batch_operation = None
-                while self.load_queue.qsize() > 0:
-                    op = self.load_queue.get(block=True)
-                    if batch_operation is None:
-                        batch_operation = op
-                    else:
-                        batch_operation.merge(op)
-                if batch_operation is None:
-                    continue
+            batch_operation = None
+            while self.load_queue.qsize() > 0:
+                op = self.load_queue.get(block=True)
+                if batch_operation is None:
+                    batch_operation = op
+                else:
+                    batch_operation.merge(op)
+            if batch_operation is None:
+                continue

-                self.layer_done_counter.reset()
-                for i in range(self.mem_pool_host.layer_num):
-                    if self.page_size == 1:
-                        flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                            batch_operation.host_indices, i
-                        )
-                        self.mem_pool_device.transfer_per_layer(
-                            batch_operation.device_indices, flat_data, i
-                        )
-                    else:
-                        self.mem_pool_host.load_page_per_layer(
-                            batch_operation.host_indices,
-                            batch_operation.device_indices,
-                            self.mem_pool_device,
-                            i,
-                        )
-                    self.load_stream.synchronize()
-                    self.layer_done_counter.increment()
+            self.layer_done_counter.reset()
+            for i in range(self.mem_pool_host.layer_num):
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                self.load_stream.synchronize()
+                self.layer_done_counter.increment()

-                self.mem_pool_host.complete_io(batch_operation.host_indices)
-                for node_id in batch_operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+            self.mem_pool_host.complete_io(batch_operation.host_indices)
+            for node_id in batch_operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
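The layer-by-layer variant exists so that computation can overlap with loading: after each layer's copy finishes, `layer_done_counter.increment()` tells consumers that this layer is usable while later layers are still in flight. A hedged sketch of what such a counter can look like; the actual sglang class may be implemented differently:

# Illustrative sketch of a layer-completion counter like the one signalled
# above; the real sglang implementation may differ.
import threading

class LayerDoneCounter:
    def __init__(self) -> None:
        self.count = 0
        self.cond = threading.Condition()

    def reset(self) -> None:
        # Called before a new batch starts loading.
        with self.cond:
            self.count = 0

    def increment(self) -> None:
        # Called by the loader after one more layer has landed on the device.
        with self.cond:
            self.count += 1
            self.cond.notify_all()

    def wait_until(self, layer: int) -> None:
        # Consumers block until the given layer index has been loaded.
        with self.cond:
            self.cond.wait_for(lambda: self.count > layer)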
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+        torch.cuda.set_stream(self.write_stream)

         def _to_op(op_):
             assert op_.device_indices.is_cuda, "Device indices should be on GPU"
@@ -370,44 +369,42 @@ class HiCacheController:
             return op_

         buffer = None
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    factor = (
-                        len(operation.device_indices)
-                        // self.write_buffer.max_buffer_size
-                    )
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                factor = (
+                    len(operation.device_indices) // self.write_buffer.max_buffer_size
+                )

-                    if factor >= 1:
-                        if buffer is not None:
-                            _to_op(buffer)
-                            buffer = None
-                        if factor < 2:
-                            _to_op(operation)
-                        else:
-                            split_ops = operation.split(factor)
-                            for op_ in split_ops:
-                                _to_op(op_)
-                        continue
+                if factor >= 1:
+                    if buffer is not None:
+                        _to_op(buffer)
+                        buffer = None
+                    if factor < 2:
+                        _to_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        for op_ in split_ops:
+                            _to_op(op_)
+                    continue

-                    if buffer is None:
-                        buffer = operation
-                    else:
-                        buffer.merge(operation)
-                    if (
-                        no_wait
-                        or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                        or self.write_queue.empty()
-                        or self.write_buffer.empty()
-                    ):
-                        _to_op(buffer)
-                        buffer = None
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    _to_op(buffer)
+                    buffer = None
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
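In `write_aux_func`, `factor = len(operation.device_indices) // self.write_buffer.max_buffer_size` decides each operation's fate: `factor == 0` merges it into the pending buffer, `factor == 1` flushes it as a single op, and `factor >= 2` splits it into `factor` pieces. A tiny runnable illustration of just that arithmetic, with a made-up `max_buffer_size`:

# Illustrates only the factor branch of write_aux_func; max_buffer_size
# is a made-up value and the resulting actions are shown as strings.
max_buffer_size = 100

for n_indices in (40, 130, 250):
    factor = n_indices // max_buffer_size
    if factor >= 1:
        action = "_to_op(operation)" if factor < 2 else f"operation.split({factor})"
    else:
        action = "buffer.merge(operation)"
    print(f"{n_indices:3d} indices -> factor {factor} -> {action}")
# 40 indices -> factor 0 -> buffer.merge(operation)
# 130 indices -> factor 1 -> _to_op(operation)
# 250 indices -> factor 2 -> operation.split(2)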
     def load_aux_func(self):
         """
         ...

@@ -484,19 +481,18 @@ class HiCacheController:
         aux_thread.join()

     def load_thread_func_buffer(self):
+        torch.cuda.set_stream(self.load_stream)
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                operation = self.load_buffer.get()
-                if operation is None:
-                    continue
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+        while not self.stop_event.is_set():
+            operation = self.load_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_device.transfer(operation.device_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
         aux_thread.join()

     def evict_device(
         ...