Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9376ac36
Unverified
Commit
9376ac36
authored
Mar 07, 2025
by
Zhiqiang Xie
Committed by
GitHub
Mar 07, 2025
Browse files
Memory pool fix for upstream change about eagle (#4170)
parent
94a2b9d3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
27 additions
and
27 deletions
+27
-27
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+9
-6
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+17
-16
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+1
-1
python/sglang/utils.py
python/sglang/utils.py
+0
-4
No files found.
python/sglang/srt/managers/cache_controller.py
View file @
9376ac36
...
...
@@ -22,7 +22,10 @@ from typing import List, Optional
import
torch
from
sglang.srt.mem_cache.memory_pool
import
MHATokenToKVPool
,
MHATokenToKVPoolHost
from
sglang.srt.mem_cache.memory_pool
import
(
MHATokenToKVPoolHost
,
TokenToKVPoolAllocator
,
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -127,12 +130,12 @@ class HiCacheController:
def
__init__
(
self
,
mem_pool_device
:
MHA
TokenToKVPool
,
token_to_kv_pool_allocator
:
TokenToKVPool
Allocator
,
mem_pool_host
:
MHATokenToKVPoolHost
,
write_policy
:
str
=
"write_through_selective"
,
):
self
.
mem_pool_device
=
mem_pool_device
self
.
mem_pool_device_allocator
=
token_to_kv_pool_allocator
self
.
mem_pool_device
=
token_to_kv_pool_allocator
.
get_kvcache
()
self
.
mem_pool_host
=
mem_pool_host
self
.
write_policy
=
write_policy
...
...
@@ -216,7 +219,7 @@ class HiCacheController:
"""
Load KV caches from host memory to device memory.
"""
device_indices
=
self
.
mem_pool_device
.
alloc
(
len
(
host_indices
))
device_indices
=
self
.
mem_pool_device
_allocator
.
alloc
(
len
(
host_indices
))
if
device_indices
is
None
:
return
None
self
.
mem_pool_host
.
protect_load
(
host_indices
)
...
...
@@ -417,7 +420,7 @@ class HiCacheController:
self
,
device_indices
:
torch
.
Tensor
,
host_indices
:
torch
.
Tensor
)
->
int
:
if
self
.
mem_pool_host
.
is_synced
(
host_indices
):
self
.
mem_pool_device
.
free
(
device_indices
)
self
.
mem_pool_device
_allocator
.
free
(
device_indices
)
self
.
mem_pool_host
.
update_backup
(
host_indices
)
return
len
(
device_indices
)
else
:
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
9376ac36
...
...
@@ -7,9 +7,9 @@ import torch
from
sglang.srt.managers.cache_controller
import
HiCacheController
from
sglang.srt.mem_cache.memory_pool
import
(
MHATokenToKVPool
,
MHATokenToKVPoolHost
,
ReqToTokenPool
,
TokenToKVPoolAllocator
,
)
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
,
TreeNode
,
_key_match
...
...
@@ -21,11 +21,13 @@ class HiRadixCache(RadixCache):
def
__init__
(
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool
:
MHA
TokenToKVPool
,
token_to_kv_pool
_allocator
:
TokenToKVPool
Allocator
,
):
self
.
token_to_kv_pool_host
=
MHATokenToKVPoolHost
(
token_to_kv_pool
)
self
.
token_to_kv_pool_host
=
MHATokenToKVPoolHost
(
token_to_kv_pool_allocator
.
get_kvcache
()
)
self
.
cache_controller
=
HiCacheController
(
token_to_kv_pool
,
self
.
token_to_kv_pool_host
token_to_kv_pool
_allocator
,
self
.
token_to_kv_pool_host
)
# record the nodes with ongoing write through
...
...
@@ -35,7 +37,7 @@ class HiRadixCache(RadixCache):
# todo: dynamically adjust the threshold
self
.
write_through_threshold
=
1
self
.
load_back_threshold
=
10
super
().
__init__
(
req_to_token_pool
,
token_to_kv_pool
,
disable
=
False
)
super
().
__init__
(
req_to_token_pool
,
token_to_kv_pool
_allocator
,
disable
=
False
)
def
reset
(
self
):
TreeNode
.
counter
=
0
...
...
@@ -160,7 +162,7 @@ class HiRadixCache(RadixCache):
def
_evict_write_through_selective
(
self
,
node
:
TreeNode
):
# evict a node not initiated write to host
self
.
cache_controller
.
mem_pool_device
.
free
(
node
.
value
)
self
.
cache_controller
.
mem_pool_device
_allocator
.
free
(
node
.
value
)
num_evicted
=
len
(
node
.
value
)
self
.
_delete_leaf
(
node
)
return
num_evicted
...
...
@@ -270,28 +272,27 @@ class HiRadixCache(RadixCache):
return
last_node
,
prefix_indices
def
_match_prefix_helper
(
self
,
node
:
TreeNode
,
key
:
List
,
value
,
last_node
:
TreeNode
):
def
_match_prefix_helper
(
self
,
node
:
TreeNode
,
key
:
List
):
node
.
last_access_time
=
time
.
time
()
if
len
(
key
)
==
0
:
return
if
key
[
0
]
in
node
.
children
.
keys
():
value
=
[]
while
len
(
key
)
>
0
and
key
[
0
]
in
node
.
children
.
keys
():
child
=
node
.
children
[
key
[
0
]]
child
.
last_access_time
=
time
.
time
()
prefix_len
=
_key_match
(
child
.
key
,
key
)
if
prefix_len
<
len
(
child
.
key
):
new_node
=
self
.
_split_node
(
child
.
key
,
child
,
prefix_len
)
self
.
inc_hit_count
(
new_node
)
if
not
new_node
.
evicted
:
value
.
append
(
new_node
.
value
)
last_node
[
0
]
=
new_node
node
=
new_node
break
else
:
self
.
inc_hit_count
(
child
)
if
not
child
.
evicted
:
value
.
append
(
child
.
value
)
last_node
[
0
]
=
child
self
.
_match_prefix_helper
(
child
,
key
[
prefix_len
:],
value
,
last_node
)
node
=
child
key
=
key
[
prefix_len
:]
return
value
,
node
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
:
int
):
# child node split into new_node -> child
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
9376ac36
...
...
@@ -470,7 +470,7 @@ class MHATokenToKVPoolHost:
def
__init__
(
self
,
device_pool
:
MHATokenToKVPool
,
host_to_device_ratio
:
float
=
2
.0
,
host_to_device_ratio
:
float
=
3
.0
,
pin_memory
:
bool
=
False
,
# no need to use pin memory with the double buffering
device
:
str
=
"cpu"
,
):
...
...
python/sglang/utils.py
View file @
9376ac36
...
...
@@ -24,14 +24,10 @@ import requests
from
IPython.display
import
HTML
,
display
from
tqdm
import
tqdm
from
sglang.srt.openai_api.protocol
import
ChatCompletionMessageContentPart
from
sglang.srt.utils
import
kill_process_tree
logger
=
logging
.
getLogger
(
__name__
)
# type of content fields, can be only prompts or with images/videos
MsgContent
=
Union
[
str
,
List
[
ChatCompletionMessageContentPart
]]
def
get_exception_traceback
():
etype
,
value
,
tb
=
sys
.
exc_info
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment