zhaoyu6 / sglang · Commits

Commit 10b544ae (Unverified)
Hierarchical Caching Refactoring and Fixing TP issue (#4082)
Authored Mar 12, 2025 by Zhiqiang Xie · Committed by GitHub on Mar 12, 2025
Parent: 01090e8a
Showing 6 changed files with 194 additions and 56 deletions:

- python/sglang/srt/managers/cache_controller.py  +63 −2
- python/sglang/srt/managers/schedule_batch.py    +16 −4
- python/sglang/srt/managers/schedule_policy.py   +30 −5
- python/sglang/srt/managers/scheduler.py         +19 −28
- python/sglang/srt/mem_cache/hiradix_cache.py    +45 −17
- python/sglang/srt/mem_cache/memory_pool.py      +21 −0
python/sglang/srt/managers/cache_controller.py

```diff
@@ -30,6 +30,26 @@ from sglang.srt.mem_cache.memory_pool import (

 logger = logging.getLogger(__name__)


+class LayerDoneCounter:
+    def __init__(self, num_layers):
+        self.counter = num_layers
+        self.condition = threading.Condition()
+
+    def increment(self):
+        with self.condition:
+            self.counter += 1
+            self.condition.notify_all()
+
+    def wait_until(self, threshold):
+        with self.condition:
+            while self.counter <= threshold:
+                self.condition.wait()
+
+    def reset(self):
+        with self.condition:
+            self.counter = 0
+
+
 class CacheOperation:
     counter = 0
```
```diff
@@ -132,6 +152,7 @@ class HiCacheController:
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: MHATokenToKVPoolHost,
+        load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
```
```diff
@@ -139,6 +160,10 @@ class HiCacheController:
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy

+        self.load_cache_event = load_cache_event
+        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
+
         if write_policy not in [
             "write_through",
             "write_through_selective",
```
```diff
@@ -165,7 +190,7 @@ class HiCacheController:
             target=self.write_thread_func_buffer, daemon=True
         )
         self.load_thread = threading.Thread(
-            target=self.load_thread_func_buffer, daemon=True
+            target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.write_thread.start()
         self.load_thread.start()
```
```diff
@@ -186,7 +211,7 @@ class HiCacheController:
             target=self.write_thread_func_buffer, daemon=True
         )
         self.load_thread = threading.Thread(
-            target=self.load_thread_func_buffer, daemon=True
+            target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.stop_event.clear()
         self.write_thread.start()
```
```diff
@@ -273,6 +298,42 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)

+    def load_thread_func_layer_by_layer(self):
+        """
+        Load KV caches from host memory to device memory layer by layer.
+        """
+        with torch.cuda.stream(self.load_stream):
+            while not self.stop_event.is_set():
+                self.load_cache_event.wait(timeout=1)
+                if not self.load_cache_event.is_set():
+                    continue
+                self.load_cache_event.clear()
+
+                batch_operation = None
+                while self.load_queue.qsize() > 0:
+                    op = self.load_queue.get(block=True)
+                    if batch_operation is None:
+                        batch_operation = op
+                    else:
+                        batch_operation.merge(op)
+                if batch_operation is None:
+                    continue
+
+                self.layer_done_counter.reset()
+                for i in range(self.mem_pool_host.layer_num):
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                    self.layer_done_counter.increment()
+
+                self.mem_pool_host.complete_io(batch_operation.host_indices)
+                for node_id in batch_operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
```
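The counter is what lets compute start before the whole prefix has arrived on device: the loader thread above increments it once per transferred layer, while the device pool's `get_key_buffer`/`get_value_buffer` (see `memory_pool.py` below) block until the layer they need is ready. A minimal, runnable sketch of that handshake, using the `LayerDoneCounter` from this diff with toy stand-ins for the transfer and the forward pass (not sglang code):

```python
import threading
import time

# Condensed copy of the LayerDoneCounter added in this commit.
class LayerDoneCounter:
    def __init__(self, num_layers):
        self.counter = num_layers
        self.condition = threading.Condition()

    def increment(self):
        with self.condition:
            self.counter += 1
            self.condition.notify_all()

    def wait_until(self, threshold):
        with self.condition:
            # returns once counter > threshold, i.e. layer `threshold` is done
            while self.counter <= threshold:
                self.condition.wait()

    def reset(self):
        with self.condition:
            self.counter = 0

NUM_LAYERS = 4
counter = LayerDoneCounter(NUM_LAYERS)
counter.reset()  # a load is pending: no layer is ready yet

def loader():
    # mirrors load_thread_func_layer_by_layer: one increment per copied layer
    for _ in range(NUM_LAYERS):
        time.sleep(0.05)  # stand-in for one layer's host-to-device copy
        counter.increment()

def forward_pass():
    # mirrors get_key_buffer: layer i may only be read once counter > i
    for layer in range(NUM_LAYERS):
        counter.wait_until(layer)
        print(f"computing layer {layer}; later layers may still be in flight")

t = threading.Thread(target=loader, daemon=True)
t.start()
forward_pass()
t.join()
```

Note the initial value: the counter starts at `num_layers`, so readers never block unless a load has explicitly `reset()` it, which matches requests that have no pending host-to-device transfer.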
python/sglang/srt/managers/schedule_batch.py
```diff
@@ -315,6 +315,7 @@ class Req:
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
         self.last_node = None
+        self.last_node_global = None

         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
```
```diff
@@ -389,13 +390,24 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None

-    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+    def init_next_round_input(
+        self,
+        tree_cache: Optional[BasePrefixCache] = None,
+        enable_hierarchical_cache=False,
+    ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
             # tree cache is None if the prefix is not computed with tree cache.
-            self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                rid=self.rid, key=self.adjust_max_prefix_ids()
-            )
+            if enable_hierarchical_cache:
+                self.prefix_indices, self.last_node, self.last_node_global = (
+                    tree_cache.match_prefix(
+                        key=self.adjust_max_prefix_ids(), include_evicted=True
+                    )
+                )
+            else:
+                self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                    rid=self.rid, key=self.adjust_max_prefix_ids()
+                )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):
```
python/sglang/srt/managers/schedule_policy.py
```diff
@@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum):
 class SchedulePolicy:
     Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

-    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+    def __init__(
+        self,
+        policy: str,
+        tree_cache: BasePrefixCache,
+        enable_hierarchical_cache: bool = False,
+    ):
         self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
+        self.enable_hierarchical_cache = enable_hierarchical_cache

         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
```
```diff
@@ -149,9 +155,14 @@ class SchedulePolicy:
             prefix_ids = r.adjust_max_prefix_ids()

             # NOTE: the prefix_indices must always be aligned with last_node
-            r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                rid=r.rid, key=prefix_ids
-            )
+            if self.enable_hierarchical_cache:
+                r.prefix_indices, r.last_node, r.last_node_global = (
+                    self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
+                )
+            else:
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=prefix_ids
+                )

             # NOTE(sang): This logic is for in-batch prefix caching;
             # If there are more than 1 request that have small matching prefix from
```
```diff
@@ -428,7 +439,9 @@ class PrefillAdder:
         return self.budget_state()

-    def add_one_req(self, req: Req, has_chunked_req: bool):
+    def add_one_req(
+        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
+    ):
         if req.sampling_params.ignore_eos and self.tree_cache.disable:
             return self.add_one_req_ignore_eos(req, has_chunked_req)
```
```diff
@@ -448,6 +461,18 @@ class PrefillAdder:
         if total_tokens > self.rem_total_tokens:
             return AddReqResult.NO_TOKEN

+        if (
+            enable_hierarchical_cache
+            and req.last_node_global is not None
+            and req.last_node_global.evicted
+        ):
+            req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
+                req.last_node_global, req.prefix_indices
+            )
+            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+            input_tokens = req.extend_input_len
+            prefix_len = len(req.prefix_indices)
+
         if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
             # Non-chunked prefill
             self.can_run_list.append(req)
```
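The recomputation after `init_load_back` matters because re-materializing a host-resident prefix lengthens `prefix_indices`, which shrinks the tokens that actually need prefilling and therefore the budget `add_one_req` charges. A toy recreation of that bookkeeping (plain lists as stand-ins for the real tensors; the numbers are illustrative only):

```python
# Toy sketch of the extend-length arithmetic above, not the real Req fields.
fill_ids = list(range(1000))        # prompt + generated tokens so far
prefix_indices = list(range(300))   # device-resident prefix from match_prefix

extend_input_len = len(fill_ids) - len(prefix_indices)
assert extend_input_len == 700      # without load-back: 700 tokens to prefill

# init_load_back copies a further 400 host-resident tokens onto the device,
# returning a longer prefix (sketched here as a plain list extension).
prefix_indices = list(range(700))

extend_input_len = len(fill_ids) - len(prefix_indices)
assert extend_input_len == 300      # after load-back: only 300 tokens to prefill
input_tokens = extend_input_len     # what the chunk/token budget is checked against
```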
python/sglang/srt/managers/scheduler.py
```diff
@@ -265,12 +265,10 @@ class Scheduler:
             f"context_len={self.model_config.context_len}"
         )

-        # Init memory pool and cache
         self.init_memory_pool_and_cache()

         # Init running status
         self.waiting_queue: List[Req] = []
-        self.staging_reqs = {}
         # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
         # The current forward batch
```
```diff
@@ -308,7 +306,9 @@ class Scheduler:
             self.grammar_backend = None

         # Init schedule policy and new token estimation
-        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+        self.policy = SchedulePolicy(
+            self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
+        )
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
```
```diff
@@ -431,6 +431,7 @@ class Scheduler:
             self.tree_cache = HiRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                tp_cache_group=self.tp_worker.get_tp_cpu_group(),
             )
         else:
             self.tree_cache = RadixCache(
```
```diff
@@ -1005,6 +1006,11 @@ class Scheduler:
                 self.batch_is_full = True
             return None

+        if self.enable_hierarchical_cache:
+            # check for completion of hierarchical cache activities to release memory
+            self.tree_cache.writing_check()
+            self.tree_cache.loading_check()
+
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)
```
```diff
@@ -1048,32 +1054,14 @@ class Scheduler:
                     self.batch_is_full = True
                 break

-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
-
-            if self.enable_hierarchical_cache and req.last_node is not None:
-                if req.last_node.evicted:
-                    # loading KV cache for the request
-                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                        req.last_node,
-                        req.prefix_indices,
-                        adder.rem_total_tokens,
-                    )
-                    if req.last_node.loading:
-                        # to prevent frequent cache invalidation
-                        if req.rid in self.staging_reqs:
-                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                        self.tree_cache.inc_lock_ref(req.last_node)
-                        self.staging_reqs[req.rid] = req.last_node
-                        continue
-                elif req.last_node.loading:
-                    if not self.tree_cache.loading_complete(req.last_node):
-                        continue
-                    if req.rid in self.staging_reqs:
-                        self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                        del self.staging_reqs[req.rid]
-
-            res = adder.add_one_req(req, self.chunked_req)
+            req.init_next_round_input(
+                None if prefix_computed else self.tree_cache,
+                self.enable_hierarchical_cache,
+            )
+
+            res = adder.add_one_req(
+                req, self.chunked_req, self.enable_hierarchical_cache
+            )
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
```
```diff
@@ -1094,6 +1082,9 @@ class Scheduler:
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]

+        if self.enable_hierarchical_cache:
+            self.tree_cache.read_to_load_cache()
+
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req
```
python/sglang/srt/mem_cache/hiradix_cache.py
```diff
 import heapq
 import logging
+import threading
 import time
 from typing import List, Optional

 import torch

 from sglang.srt.managers.cache_controller import HiCacheController
-from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPoolHost,
     ReqToTokenPool,
```
```diff
@@ -22,12 +24,18 @@ class HiRadixCache(RadixCache):
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
     ):
         self.token_to_kv_pool_host = MHATokenToKVPoolHost(
             token_to_kv_pool_allocator.get_kvcache()
         )
+        self.tp_group = tp_cache_group
+
+        self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
-            token_to_kv_pool_allocator, self.token_to_kv_pool_host
+            token_to_kv_pool_allocator,
+            self.token_to_kv_pool_host,
+            load_cache_event=self.load_cache_event,
         )

         # record the nodes with ongoing write through
```
```diff
@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache):
     def write_backup(self, node: TreeNode):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
-            priority=-self.get_height(node),
             node_id=node.id,
         )
         if host_indices is None:
             self.evict_host(len(node.value))
             host_indices = self.cache_controller.write(
                 device_indices=node.value,
-                priority=-self.get_height(node),
                 node_id=node.id,
             )
         if host_indices is not None:
```
```diff
@@ -83,14 +89,20 @@ class HiRadixCache(RadixCache):
         node.hit_count = 0

     def writing_check(self):
-        while not self.cache_controller.ack_write_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_write_queue.get_nowait()
-                self.dec_lock_ref(self.ongoing_write_through[ack_id])
-                # clear the reference
-                del self.ongoing_write_through[ack_id]
-            except Exception:
-                break
+        queue_size = torch.tensor(
+            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchrnoize TP workers to make the same update to radix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id = self.cache_controller.ack_write_queue.get()
+            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            del self.ongoing_write_through[ack_id]

     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
```
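This rewrite is the TP fix from the commit title: each tensor-parallel rank drains its own ack queue, and under the old `get_nowait` loop different ranks could observe different queue depths and apply a different number of radix-tree updates, letting their trees diverge. Reducing the queue size with `ReduceOp.MIN` first makes every rank pop the same number of acks. A standalone sketch of the synchronization idea using the gloo backend (toy queue contents, not the real `HiCacheController`):

```python
import os
import queue

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Toy ack queues: rank 0 has seen 3 completions, rank 1 only 2 so far.
    ack_queue = queue.Queue()
    for ack_id in range(3 if rank == 0 else 2):
        ack_queue.put(ack_id)

    # Old behavior would pop until empty, so rank 0 applies one more update
    # than rank 1. The fix: agree on the minimum depth before popping.
    queue_size = torch.tensor(ack_queue.qsize(), dtype=torch.int)
    dist.all_reduce(queue_size, op=dist.ReduceOp.MIN)

    for _ in range(queue_size.item()):
        ack_id = ack_queue.get()
        print(f"rank {rank} applies update {ack_id}")  # identical on all ranks

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

The straggler rank's extra acks are not lost; they simply stay in the queue until a later `writing_check` call, by which point every rank can agree on them.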
```diff
@@ -108,8 +120,6 @@ class HiRadixCache(RadixCache):
                 break

     def evictable_size(self):
-        self.writing_check()
-        self.loading_check()
         return self.evictable_size_

     def evict(self, num_tokens: int, evict_callback=None):
```
```diff
@@ -242,10 +252,6 @@ class HiRadixCache(RadixCache):
         return device_indices

-    def loading_complete(self, node: TreeNode):
-        self.loading_check()
-        return node.loading == False
-
     def init_load_back(
         self,
         last_node: TreeNode,
```
```diff
@@ -272,6 +278,28 @@ class HiRadixCache(RadixCache):
         return last_node, prefix_indices

+    def read_to_load_cache(self):
+        self.load_cache_event.set()
+
+    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+        if self.disable:
+            return [], self.root_node
+
+        value, last_node = self._match_prefix_helper(self.root_node, key)
+        if value:
+            value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int32)
+
+        last_node_global = last_node
+        while last_node.evicted:
+            last_node = last_node.parent
+
+        if include_evicted:
+            return value, last_node, last_node_global
+        else:
+            return value, last_node
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()

         value = []
```
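The distinction between `last_node` and `last_node_global` is the crux of the new match path: the helper may match deeper into the tree than what is still device-resident, so the walk up `parent` pointers finds the deepest non-evicted ancestor (`last_node`) while `last_node_global` remembers the full match for a later `init_load_back`. A minimal sketch with stand-in nodes (not the real `TreeNode`):

```python
# Toy recreation of the evicted-ancestor walk in match_prefix above.
class Node:
    def __init__(self, name, parent=None, evicted=False):
        self.name, self.parent, self.evicted = name, parent, evicted

root = Node("root")
a = Node("a", parent=root)             # KV still on device
b = Node("b", parent=a, evicted=True)  # KV only in host memory
c = Node("c", parent=b, evicted=True)  # deepest match, also evicted

last_node_global = last_node = c
while last_node.evicted:
    last_node = last_node.parent

assert last_node is a           # deepest device-resident ancestor
assert last_node_global is c    # full match, handed to init_load_back
```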
python/sglang/srt/mem_cache/memory_pool.py
```diff
@@ -206,6 +206,8 @@ class MHATokenToKVPool(KVCache):
         self.layer_num = layer_num
         self._create_buffers()

+        self.layer_transfer_counter = None
+
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
```
```diff
@@ -267,12 +269,28 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]

+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        self.k_buffer[layer_id][indices] = k_data
+        self.v_buffer[layer_id][indices] = v_data
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
```
```diff
@@ -530,6 +548,9 @@ class MHATokenToKVPoolHost:
     def get_flat_data(self, indices):
         return self.kv_buffer[:, :, indices]

+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
```
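`get_flat_data_by_layer` works because the host pool keeps all KV data in one tensor; from the indexing used in this diff its layout can be inferred as `[2 (K/V), layer_num, size, head_num, head_dim]`, so selecting `[:, layer_id, indices]` yields exactly the `[2, len(indices), head_num, head_dim]` slab that `transfer_per_layer` unpacks into `k_data, v_data`. A quick shape check with toy dimensions (the 5-D layout is an assumption reconstructed from this diff, not quoted from the file):

```python
import torch

# Toy dimensions; layout assumed to be [2, layer_num, size, head_num, head_dim].
layer_num, size, head_num, head_dim = 4, 32, 2, 8
kv_buffer = torch.zeros(2, layer_num, size, head_num, head_dim)

indices = torch.tensor([3, 7, 11])

# get_flat_data: all layers at once -> [2, layer_num, 3, head_num, head_dim]
assert kv_buffer[:, :, indices].shape == (2, layer_num, 3, head_num, head_dim)

# get_flat_data_by_layer: a single layer -> [2, 3, head_num, head_dim]
flat_data = kv_buffer[:, 1, indices]
assert flat_data.shape == (2, 3, head_num, head_dim)

# transfer_per_layer unpacks this into the per-layer device buffers
k_data, v_data = flat_data[0], flat_data[1]
assert k_data.shape == (3, head_num, head_dim)
```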