Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d2cb3024
Unverified
Commit
d2cb3024
authored
May 10, 2025
by
huangtingwei
Committed by
GitHub
May 09, 2025
Browse files
fix bug that gpu0 occupies more memory when hicache is turned on (#5778)
Co-authored-by:
Zhiqiang Xie
<
xiezhq@stanford.edu
>
parent
1940cdec
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
115 additions
and
119 deletions
+115
-119
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+115
-119
No files found.
python/sglang/srt/managers/cache_controller.py
View file @
d2cb3024
...
@@ -268,7 +268,7 @@ class HiCacheController:
...
@@ -268,7 +268,7 @@ class HiCacheController:
"""
"""
Directly write through KV caches to host memory without buffering.
Directly write through KV caches to host memory without buffering.
"""
"""
with
torch
.
cuda
.
stream
(
self
.
write_stream
)
:
torch
.
cuda
.
set_
stream
(
self
.
write_stream
)
while
not
self
.
stop_event
.
is_set
():
while
not
self
.
stop_event
.
is_set
():
try
:
try
:
operation
=
self
.
write_queue
.
get
(
block
=
True
,
timeout
=
1
)
operation
=
self
.
write_queue
.
get
(
block
=
True
,
timeout
=
1
)
...
@@ -291,7 +291,7 @@ class HiCacheController:
...
@@ -291,7 +291,7 @@ class HiCacheController:
"""
"""
Directly load KV caches from host memory to device memory without buffering.
Directly load KV caches from host memory to device memory without buffering.
"""
"""
with
torch
.
cuda
.
stream
(
self
.
load_stream
)
:
torch
.
cuda
.
set_
stream
(
self
.
load_stream
)
while
not
self
.
stop_event
.
is_set
():
while
not
self
.
stop_event
.
is_set
():
try
:
try
:
operation
=
self
.
load_queue
.
get
(
block
=
True
,
timeout
=
1
)
operation
=
self
.
load_queue
.
get
(
block
=
True
,
timeout
=
1
)
...
@@ -299,9 +299,7 @@ class HiCacheController:
...
@@ -299,9 +299,7 @@ class HiCacheController:
operation
.
data
=
self
.
mem_pool_host
.
get_flat_data
(
operation
.
data
=
self
.
mem_pool_host
.
get_flat_data
(
operation
.
host_indices
operation
.
host_indices
)
)
self
.
mem_pool_device
.
transfer
(
self
.
mem_pool_device
.
transfer
(
operation
.
device_indices
,
operation
.
data
)
operation
.
device_indices
,
operation
.
data
)
self
.
mem_pool_host
.
complete_io
(
operation
.
host_indices
)
self
.
mem_pool_host
.
complete_io
(
operation
.
host_indices
)
for
node_id
in
operation
.
node_ids
:
for
node_id
in
operation
.
node_ids
:
if
node_id
!=
0
:
if
node_id
!=
0
:
...
@@ -315,7 +313,7 @@ class HiCacheController:
...
@@ -315,7 +313,7 @@ class HiCacheController:
"""
"""
Load KV caches from host memory to device memory layer by layer.
Load KV caches from host memory to device memory layer by layer.
"""
"""
with
torch
.
cuda
.
stream
(
self
.
load_stream
)
:
torch
.
cuda
.
set_
stream
(
self
.
load_stream
)
while
not
self
.
stop_event
.
is_set
():
while
not
self
.
stop_event
.
is_set
():
self
.
load_cache_event
.
wait
(
timeout
=
1
)
self
.
load_cache_event
.
wait
(
timeout
=
1
)
if
not
self
.
load_cache_event
.
is_set
():
if
not
self
.
load_cache_event
.
is_set
():
...
@@ -360,6 +358,7 @@ class HiCacheController:
...
@@ -360,6 +358,7 @@ class HiCacheController:
"""
"""
Auxiliary function to prepare the buffer for write operations.
Auxiliary function to prepare the buffer for write operations.
"""
"""
torch
.
cuda
.
set_stream
(
self
.
write_stream
)
def
_to_op
(
op_
):
def
_to_op
(
op_
):
assert
op_
.
device_indices
.
is_cuda
,
"Device indices should be on GPU"
assert
op_
.
device_indices
.
is_cuda
,
"Device indices should be on GPU"
...
@@ -370,13 +369,11 @@ class HiCacheController:
...
@@ -370,13 +369,11 @@ class HiCacheController:
return
op_
return
op_
buffer
=
None
buffer
=
None
with
torch
.
cuda
.
stream
(
self
.
write_stream
):
while
not
self
.
stop_event
.
is_set
():
while
not
self
.
stop_event
.
is_set
():
try
:
try
:
operation
=
self
.
write_queue
.
get
(
block
=
True
,
timeout
=
1
)
operation
=
self
.
write_queue
.
get
(
block
=
True
,
timeout
=
1
)
factor
=
(
factor
=
(
len
(
operation
.
device_indices
)
len
(
operation
.
device_indices
)
//
self
.
write_buffer
.
max_buffer_size
//
self
.
write_buffer
.
max_buffer_size
)
)
if
factor
>=
1
:
if
factor
>=
1
:
...
@@ -484,10 +481,9 @@ class HiCacheController:
...
@@ -484,10 +481,9 @@ class HiCacheController:
aux_thread
.
join
()
aux_thread
.
join
()
def
load_thread_func_buffer
(
self
):
def
load_thread_func_buffer
(
self
):
torch
.
cuda
.
set_stream
(
self
.
load_stream
)
aux_thread
=
threading
.
Thread
(
target
=
self
.
load_aux_func
,
daemon
=
True
)
aux_thread
=
threading
.
Thread
(
target
=
self
.
load_aux_func
,
daemon
=
True
)
aux_thread
.
start
()
aux_thread
.
start
()
with
torch
.
cuda
.
stream
(
self
.
load_stream
):
while
not
self
.
stop_event
.
is_set
():
while
not
self
.
stop_event
.
is_set
():
operation
=
self
.
load_buffer
.
get
()
operation
=
self
.
load_buffer
.
get
()
if
operation
is
None
:
if
operation
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment