change / sglang · Commit 10b544ae
Unverified commit, authored Mar 12, 2025 by Zhiqiang Xie; committed by GitHub on Mar 12, 2025.
Hierarchical Caching Refactoring and Fixing TP issue (#4082)
Parent: 01090e8a
Showing 6 changed files, with 194 additions and 56 deletions:

- python/sglang/srt/managers/cache_controller.py (+63 -2)
- python/sglang/srt/managers/schedule_batch.py (+16 -4)
- python/sglang/srt/managers/schedule_policy.py (+30 -5)
- python/sglang/srt/managers/scheduler.py (+19 -28)
- python/sglang/srt/mem_cache/hiradix_cache.py (+45 -17)
- python/sglang/srt/mem_cache/memory_pool.py (+21 -0)
python/sglang/srt/managers/cache_controller.py

```diff
@@ -30,6 +30,26 @@ from sglang.srt.mem_cache.memory_pool import (
 logger = logging.getLogger(__name__)
 
 
+class LayerDoneCounter:
+    def __init__(self, num_layers):
+        self.counter = num_layers
+        self.condition = threading.Condition()
+
+    def increment(self):
+        with self.condition:
+            self.counter += 1
+            self.condition.notify_all()
+
+    def wait_until(self, threshold):
+        with self.condition:
+            while self.counter <= threshold:
+                self.condition.wait()
+
+    def reset(self):
+        with self.condition:
+            self.counter = 0
+
+
 class CacheOperation:
     counter = 0
@@ -132,6 +152,7 @@ class HiCacheController:
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: MHATokenToKVPoolHost,
+        load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
@@ -139,6 +160,10 @@ class HiCacheController:
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
 
+        self.load_cache_event = load_cache_event
+        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
+
         if write_policy not in [
             "write_through",
             "write_through_selective",
@@ -165,7 +190,7 @@ class HiCacheController:
             target=self.write_thread_func_buffer, daemon=True
         )
         self.load_thread = threading.Thread(
-            target=self.load_thread_func_buffer, daemon=True
+            target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.write_thread.start()
         self.load_thread.start()
@@ -186,7 +211,7 @@ class HiCacheController:
             target=self.write_thread_func_buffer, daemon=True
         )
         self.load_thread = threading.Thread(
-            target=self.load_thread_func_buffer, daemon=True
+            target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.stop_event.clear()
         self.write_thread.start()
@@ -273,6 +298,42 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)
 
+    def load_thread_func_layer_by_layer(self):
+        """
+        Load KV caches from host memory to device memory layer by layer.
+        """
+        with torch.cuda.stream(self.load_stream):
+            while not self.stop_event.is_set():
+                self.load_cache_event.wait(timeout=1)
+                if not self.load_cache_event.is_set():
+                    continue
+                self.load_cache_event.clear()
+
+                batch_operation = None
+                while self.load_queue.qsize() > 0:
+                    op = self.load_queue.get(block=True)
+                    if batch_operation is None:
+                        batch_operation = op
+                    else:
+                        batch_operation.merge(op)
+                if batch_operation is None:
+                    continue
+
+                self.layer_done_counter.reset()
+                for i in range(self.mem_pool_host.layer_num):
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                    self.layer_done_counter.increment()
+
+                self.mem_pool_host.complete_io(batch_operation.host_indices)
+                for node_id in batch_operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
```
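The new `LayerDoneCounter` is a plain condition-variable counter: the load thread increments it once per layer after the host-to-device copy, and readers block until the layer they need has landed (this is the gate `get_key_buffer`/`get_value_buffer` apply in `memory_pool.py` below). A minimal standalone sketch of that handshake, with toy thread bodies standing in for real transfers:

```python
import threading
import time

# Same counter as in this commit: reset() before a load, increment() after
# each layer arrives, wait_until(i) blocks until layer i is resident.
class LayerDoneCounter:
    def __init__(self, num_layers):
        self.counter = num_layers  # starts "complete" until a load resets it
        self.condition = threading.Condition()

    def increment(self):
        with self.condition:
            self.counter += 1
            self.condition.notify_all()

    def wait_until(self, threshold):
        with self.condition:
            while self.counter <= threshold:
                self.condition.wait()

    def reset(self):
        with self.condition:
            self.counter = 0

NUM_LAYERS = 4
counter = LayerDoneCounter(NUM_LAYERS)
counter.reset()  # a new batched load is starting

def transfer():  # stand-in for load_thread_func_layer_by_layer
    for layer in range(NUM_LAYERS):
        time.sleep(0.01)     # pretend host-to-device copy
        counter.increment()  # layer `layer` is now on device

def compute():  # stand-in for a forward pass consuming layer buffers
    for layer in range(NUM_LAYERS):
        counter.wait_until(layer)  # the gate get_key_buffer() now applies
        print(f"layer {layer} ready, running attention")

threads = [threading.Thread(target=transfer), threading.Thread(target=compute)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

The overlap is the point: compute can start on layer 0 while later layers are still in flight.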
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -315,6 +315,7 @@ class Req:
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
         self.last_node = None
+        self.last_node_global = None
 
         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -389,13 +390,24 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None
 
-    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+    def init_next_round_input(
+        self,
+        tree_cache: Optional[BasePrefixCache] = None,
+        enable_hierarchical_cache=False,
+    ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
             # tree cache is None if the prefix is not computed with tree cache.
-            self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                rid=self.rid, key=self.adjust_max_prefix_ids()
-            )
+            if enable_hierarchical_cache:
+                self.prefix_indices, self.last_node, self.last_node_global = (
+                    tree_cache.match_prefix(
+                        key=self.adjust_max_prefix_ids(), include_evicted=True
+                    )
+                )
+            else:
+                self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                    rid=self.rid, key=self.adjust_max_prefix_ids()
+                )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
```
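The branch matters because the two cache types return different shapes: with `include_evicted=True` the hierarchical cache adds a third value, `last_node_global`, which may sit deeper in the tree than the device-resident `last_node`. A self-contained sketch of just that shape difference, using a hypothetical `_Node` stand-in for `TreeNode`:

```python
from typing import List, Optional

class _Node:  # hypothetical stand-in for TreeNode
    def __init__(self, evicted: bool, parent: Optional["_Node"] = None):
        self.evicted = evicted
        self.parent = parent

def match_prefix_sketch(include_evicted: bool):
    root = _Node(evicted=False)
    deepest = _Node(evicted=True, parent=root)  # match ends on an evicted node
    value: List[int] = [1, 2, 3]  # matched device indices (toy values)

    last_node_global = deepest
    last_node = deepest
    while last_node.evicted:  # back off to the device-resident part
        last_node = last_node.parent

    if include_evicted:
        return value, last_node, last_node_global  # hierarchical: 3-tuple
    return value, last_node                        # classic: 2-tuple

print(len(match_prefix_sketch(True)), len(match_prefix_sketch(False)))  # 3 2
```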
python/sglang/srt/managers/schedule_policy.py

```diff
@@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum):
 class SchedulePolicy:
     Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]
 
-    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+    def __init__(
+        self,
+        policy: str,
+        tree_cache: BasePrefixCache,
+        enable_hierarchical_cache: bool = False,
+    ):
         self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
+        self.enable_hierarchical_cache = enable_hierarchical_cache
 
         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
@@ -149,9 +155,14 @@ class SchedulePolicy:
             prefix_ids = r.adjust_max_prefix_ids()
             # NOTE: the prefix_indices must always be aligned with last_node
-            r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                rid=r.rid, key=prefix_ids
-            )
+            if self.enable_hierarchical_cache:
+                r.prefix_indices, r.last_node, r.last_node_global = (
+                    self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
+                )
+            else:
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=prefix_ids
+                )
 
             # NOTE(sang): This logic is for in-batch prefix caching;
             # If there are more than 1 request that have small matching prefix from
@@ -428,7 +439,9 @@ class PrefillAdder:
         return self.budget_state()
 
-    def add_one_req(self, req: Req, has_chunked_req: bool):
+    def add_one_req(
+        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
+    ):
         if req.sampling_params.ignore_eos and self.tree_cache.disable:
             return self.add_one_req_ignore_eos(req, has_chunked_req)
@@ -448,6 +461,18 @@ class PrefillAdder:
         if total_tokens > self.rem_total_tokens:
             return AddReqResult.NO_TOKEN
 
+        if (
+            enable_hierarchical_cache
+            and req.last_node_global is not None
+            and req.last_node_global.evicted
+        ):
+            req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
+                req.last_node_global, req.prefix_indices
+            )
+            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+            input_tokens = req.extend_input_len
+            prefix_len = len(req.prefix_indices)
+
         if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
             # Non-chunked prefill
             self.can_run_list.append(req)
```
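The key bookkeeping in the new `add_one_req` block: once `init_load_back()` extends the usable prefix with tokens pulled back from host, `extend_input_len` must be recomputed before the chunk-budget check, or the adder would budget for tokens it no longer needs to prefill. Toy arithmetic with made-up token counts:

```python
# A request with 1000 tokens to fill, 300 of which matched on device.
fill_ids = list(range(1000))
prefix_indices = list(range(300))
extend_input_len = len(fill_ids) - len(prefix_indices)
assert extend_input_len == 700  # tokens the prefill would recompute

# Suppose init_load_back() brings 200 more evicted prefix tokens back:
prefix_indices = list(range(500))
extend_input_len = len(fill_ids) - len(prefix_indices)
assert extend_input_len == 500  # the chunk budget now sees the smaller number
```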
python/sglang/srt/managers/scheduler.py

```diff
@@ -265,12 +265,10 @@ class Scheduler:
             f"context_len={self.model_config.context_len}"
         )
 
-        # Init memory pool and cache
         self.init_memory_pool_and_cache()
 
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.staging_reqs = {}
         # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
         # The current forward batch
@@ -308,7 +306,9 @@ class Scheduler:
         self.grammar_backend = None
 
         # Init schedule policy and new token estimation
-        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+        self.policy = SchedulePolicy(
+            self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
+        )
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
@@ -431,6 +431,7 @@ class Scheduler:
             self.tree_cache = HiRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                tp_cache_group=self.tp_worker.get_tp_cpu_group(),
             )
         else:
             self.tree_cache = RadixCache(
@@ -1005,6 +1006,11 @@ class Scheduler:
             self.batch_is_full = True
             return None
 
+        if self.enable_hierarchical_cache:
+            # check for completion of hierarchical cache activities to release memory
+            self.tree_cache.writing_check()
+            self.tree_cache.loading_check()
+
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)
@@ -1048,32 +1054,14 @@ class Scheduler:
                 self.batch_is_full = True
                 break
 
-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
-
-            if self.enable_hierarchical_cache and req.last_node is not None:
-                if req.last_node.evicted:
-                    # loading KV cache for the request
-                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                        req.last_node,
-                        req.prefix_indices,
-                        adder.rem_total_tokens,
-                    )
-                    if req.last_node.loading:
-                        # to prevent frequent cache invalidation
-                        if req.rid in self.staging_reqs:
-                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                        self.tree_cache.inc_lock_ref(req.last_node)
-                        self.staging_reqs[req.rid] = req.last_node
-                        continue
-                elif req.last_node.loading:
-                    if not self.tree_cache.loading_complete(req.last_node):
-                        continue
-                    if req.rid in self.staging_reqs:
-                        self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                        del self.staging_reqs[req.rid]
-
-            res = adder.add_one_req(req, self.chunked_req)
+            req.init_next_round_input(
+                None if prefix_computed else self.tree_cache,
+                self.enable_hierarchical_cache,
+            )
+            res = adder.add_one_req(
+                req, self.chunked_req, self.enable_hierarchical_cache
+            )
 
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
@@ -1094,6 +1082,9 @@ class Scheduler:
                 x for x in self.waiting_queue if x not in set(can_run_list)
             ]
 
+        if self.enable_hierarchical_cache:
+            self.tree_cache.read_to_load_cache()
+
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req
```
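With the per-request staging logic removed, the scheduler's remaining hierarchical-cache duty here is to signal the controller after batch assembly: `read_to_load_cache()` sets `load_cache_event`, and the controller's load thread wakes from its one-second poll and drains the load queue. A minimal sketch of that event handoff (toy thread body, not the real controller loop):

```python
import threading
import time

load_cache_event = threading.Event()

def load_thread():
    # mirrors the outer loop of load_thread_func_layer_by_layer
    while True:
        load_cache_event.wait(timeout=1)
        if not load_cache_event.is_set():
            continue  # timed out; the real loop re-checks stop_event here
        load_cache_event.clear()
        print("draining load_queue, starting layer-by-layer transfer")
        break

t = threading.Thread(target=load_thread, daemon=True)
t.start()
time.sleep(0.05)        # scheduler assembles the batch...
load_cache_event.set()  # ...then calls tree_cache.read_to_load_cache()
t.join()
```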
python/sglang/srt/mem_cache/hiradix_cache.py

```diff
 import heapq
 import logging
+import threading
 import time
 from typing import List, Optional
 
 import torch
 
 from sglang.srt.managers.cache_controller import HiCacheController
-from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPoolHost,
     ReqToTokenPool,
@@ -22,12 +24,18 @@ class HiRadixCache(RadixCache):
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
     ):
         self.token_to_kv_pool_host = MHATokenToKVPoolHost(
             token_to_kv_pool_allocator.get_kvcache()
         )
+        self.tp_group = tp_cache_group
+
+        self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
-            token_to_kv_pool_allocator, self.token_to_kv_pool_host
+            token_to_kv_pool_allocator,
+            self.token_to_kv_pool_host,
+            load_cache_event=self.load_cache_event,
         )
 
         # record the nodes with ongoing write through
@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache):
     def write_backup(self, node: TreeNode):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
-            priority=-self.get_height(node),
             node_id=node.id,
         )
         if host_indices is None:
             self.evict_host(len(node.value))
             host_indices = self.cache_controller.write(
                 device_indices=node.value,
-                priority=-self.get_height(node),
                 node_id=node.id,
             )
         if host_indices is not None:
@@ -83,14 +89,20 @@ class HiRadixCache(RadixCache):
         node.hit_count = 0
 
     def writing_check(self):
-        while not self.cache_controller.ack_write_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_write_queue.get_nowait()
-                self.dec_lock_ref(self.ongoing_write_through[ack_id])
-                # clear the reference
-                del self.ongoing_write_through[ack_id]
-            except Exception:
-                break
+        queue_size = torch.tensor(
+            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to radix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id = self.cache_controller.ack_write_queue.get()
+            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            del self.ongoing_write_through[ack_id]
 
     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
@@ -108,8 +120,6 @@ class HiRadixCache(RadixCache):
                 break
 
     def evictable_size(self):
-        self.writing_check()
-        self.loading_check()
         return self.evictable_size_
 
     def evict(self, num_tokens: int, evict_callback=None):
@@ -242,10 +252,6 @@ class HiRadixCache(RadixCache):
         return device_indices
 
-    def loading_complete(self, node: TreeNode):
-        self.loading_check()
-        return node.loading == False
-
     def init_load_back(
         self,
         last_node: TreeNode,
@@ -272,6 +278,28 @@ class HiRadixCache(RadixCache):
         return last_node, prefix_indices
 
+    def read_to_load_cache(self):
+        self.load_cache_event.set()
+
+    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+        if self.disable:
+            return [], self.root_node
+
+        value, last_node = self._match_prefix_helper(self.root_node, key)
+        if value:
+            value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int32)
+
+        last_node_global = last_node
+        while last_node.evicted:
+            last_node = last_node.parent
+
+        if include_evicted:
+            return value, last_node, last_node_global
+        else:
+            return value, last_node
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
         value = []
```
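The TP fix is the `writing_check` rewrite: each tensor-parallel worker may have observed a different number of write acks at the moment it checks, so the ranks all-reduce their local queue sizes with `ReduceOp.MIN` and every rank pops exactly the agreed count, keeping the radix trees identical across workers. A runnable single-process sketch of the consensus step (with real TP there would be one process per rank, sharing a CPU `gloo` group):

```python
import os
import torch
import torch.distributed as dist

def agreed_ack_count(local_queue_size: int, group=None) -> int:
    # Every rank proposes how many acks it has seen; MIN keeps only the
    # count all ranks can safely process this step. Stragglers' remaining
    # acks stay queued and are picked up on a later writing_check().
    t = torch.tensor(local_queue_size, dtype=torch.int)
    if dist.get_world_size(group=group) > 1:
        dist.all_reduce(t, op=dist.ReduceOp.MIN, group=group)
    return int(t.item())

if __name__ == "__main__":
    # single-process demo; rank/world_size are placeholders for real TP
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(agreed_ack_count(3))  # world_size == 1, so no reduction: 3
    dist.destroy_process_group()
```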
python/sglang/srt/mem_cache/memory_pool.py

```diff
@@ -206,6 +206,8 @@ class MHATokenToKVPool(KVCache):
         self.layer_num = layer_num
         self._create_buffers()
 
+        self.layer_transfer_counter = None
+
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
@@ -267,12 +269,28 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]
 
+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        self.k_buffer[layer_id][indices] = k_data
+        self.v_buffer[layer_id][indices] = v_data
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
@@ -530,6 +548,9 @@ class MHATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]
 
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
```
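`get_flat_data_by_layer` is what lets the controller ship one layer at a time instead of the whole prefix at once. A small sketch of the slicing, assuming a host layout of `kv_buffer[kv, layer, token, head, head_dim]` (consistent with the `[:, :, indices]` indexing above; the sizes here are made up):

```python
import torch

layer_num, pool_size, head_num, head_dim = 4, 16, 2, 8
# dim 0 selects K (0) or V (1); dim 1 is the layer; dim 2 the token slot
kv_buffer = torch.zeros(2, layer_num, pool_size, head_num, head_dim)

indices = torch.tensor([3, 5, 7])  # token slots being loaded back

flat_all = kv_buffer[:, :, indices]     # get_flat_data: all layers at once
flat_layer1 = kv_buffer[:, 1, indices]  # get_flat_data_by_layer(indices, 1)

print(flat_all.shape)     # torch.Size([2, 4, 3, 2, 8])
print(flat_layer1.shape)  # torch.Size([2, 3, 2, 8])

# transfer_per_layer then splits the per-layer slab into K and V halves:
k_data, v_data = flat_layer1[0], flat_layer1[1]
```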