Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9376ac36
Unverified
Commit
9376ac36
authored
Mar 07, 2025
by
Zhiqiang Xie
Committed by
GitHub
Mar 07, 2025
Browse files
Memory pool fix for upstream change about eagle (#4170)
parent
94a2b9d3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
27 additions
and
27 deletions
+27
-27
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+9
-6
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+17
-16
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+1
-1
python/sglang/utils.py
python/sglang/utils.py
+0
-4
No files found.
python/sglang/srt/managers/cache_controller.py
View file @
9376ac36
...
@@ -22,7 +22,10 @@ from typing import List, Optional
...
@@ -22,7 +22,10 @@ from typing import List, Optional
import
torch
import
torch
from
sglang.srt.mem_cache.memory_pool
import
MHATokenToKVPool
,
MHATokenToKVPoolHost
from
sglang.srt.mem_cache.memory_pool
import
(
MHATokenToKVPoolHost
,
TokenToKVPoolAllocator
,
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -127,12 +130,12 @@ class HiCacheController:
...
@@ -127,12 +130,12 @@ class HiCacheController:
def
__init__
(
def
__init__
(
self
,
self
,
mem_pool_device
:
MHA
TokenToKVPool
,
token_to_kv_pool_allocator
:
TokenToKVPool
Allocator
,
mem_pool_host
:
MHATokenToKVPoolHost
,
mem_pool_host
:
MHATokenToKVPoolHost
,
write_policy
:
str
=
"write_through_selective"
,
write_policy
:
str
=
"write_through_selective"
,
):
):
self
.
mem_pool_device_allocator
=
token_to_kv_pool_allocator
self
.
mem_pool_device
=
mem_pool_device
self
.
mem_pool_device
=
token_to_kv_pool_allocator
.
get_kvcache
()
self
.
mem_pool_host
=
mem_pool_host
self
.
mem_pool_host
=
mem_pool_host
self
.
write_policy
=
write_policy
self
.
write_policy
=
write_policy
...
@@ -216,7 +219,7 @@ class HiCacheController:
...
@@ -216,7 +219,7 @@ class HiCacheController:
"""
"""
Load KV caches from host memory to device memory.
Load KV caches from host memory to device memory.
"""
"""
device_indices
=
self
.
mem_pool_device
.
alloc
(
len
(
host_indices
))
device_indices
=
self
.
mem_pool_device
_allocator
.
alloc
(
len
(
host_indices
))
if
device_indices
is
None
:
if
device_indices
is
None
:
return
None
return
None
self
.
mem_pool_host
.
protect_load
(
host_indices
)
self
.
mem_pool_host
.
protect_load
(
host_indices
)
...
@@ -417,7 +420,7 @@ class HiCacheController:
...
@@ -417,7 +420,7 @@ class HiCacheController:
self
,
device_indices
:
torch
.
Tensor
,
host_indices
:
torch
.
Tensor
self
,
device_indices
:
torch
.
Tensor
,
host_indices
:
torch
.
Tensor
)
->
int
:
)
->
int
:
if
self
.
mem_pool_host
.
is_synced
(
host_indices
):
if
self
.
mem_pool_host
.
is_synced
(
host_indices
):
self
.
mem_pool_device
.
free
(
device_indices
)
self
.
mem_pool_device
_allocator
.
free
(
device_indices
)
self
.
mem_pool_host
.
update_backup
(
host_indices
)
self
.
mem_pool_host
.
update_backup
(
host_indices
)
return
len
(
device_indices
)
return
len
(
device_indices
)
else
:
else
:
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
9376ac36
...
@@ -7,9 +7,9 @@ import torch
...
@@ -7,9 +7,9 @@ import torch
from
sglang.srt.managers.cache_controller
import
HiCacheController
from
sglang.srt.managers.cache_controller
import
HiCacheController
from
sglang.srt.mem_cache.memory_pool
import
(
from
sglang.srt.mem_cache.memory_pool
import
(
MHATokenToKVPool
,
MHATokenToKVPoolHost
,
MHATokenToKVPoolHost
,
ReqToTokenPool
,
ReqToTokenPool
,
TokenToKVPoolAllocator
,
)
)
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
,
TreeNode
,
_key_match
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
,
TreeNode
,
_key_match
...
@@ -21,11 +21,13 @@ class HiRadixCache(RadixCache):
...
@@ -21,11 +21,13 @@ class HiRadixCache(RadixCache):
def
__init__
(
def
__init__
(
self
,
self
,
req_to_token_pool
:
ReqToTokenPool
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool
:
MHA
TokenToKVPool
,
token_to_kv_pool
_allocator
:
TokenToKVPool
Allocator
,
):
):
self
.
token_to_kv_pool_host
=
MHATokenToKVPoolHost
(
token_to_kv_pool
)
self
.
token_to_kv_pool_host
=
MHATokenToKVPoolHost
(
token_to_kv_pool_allocator
.
get_kvcache
()
)
self
.
cache_controller
=
HiCacheController
(
self
.
cache_controller
=
HiCacheController
(
token_to_kv_pool
,
self
.
token_to_kv_pool_host
token_to_kv_pool
_allocator
,
self
.
token_to_kv_pool_host
)
)
# record the nodes with ongoing write through
# record the nodes with ongoing write through
...
@@ -35,7 +37,7 @@ class HiRadixCache(RadixCache):
...
@@ -35,7 +37,7 @@ class HiRadixCache(RadixCache):
# todo: dynamically adjust the threshold
# todo: dynamically adjust the threshold
self
.
write_through_threshold
=
1
self
.
write_through_threshold
=
1
self
.
load_back_threshold
=
10
self
.
load_back_threshold
=
10
super
().
__init__
(
req_to_token_pool
,
token_to_kv_pool
,
disable
=
False
)
super
().
__init__
(
req_to_token_pool
,
token_to_kv_pool
_allocator
,
disable
=
False
)
def
reset
(
self
):
def
reset
(
self
):
TreeNode
.
counter
=
0
TreeNode
.
counter
=
0
...
@@ -160,7 +162,7 @@ class HiRadixCache(RadixCache):
...
@@ -160,7 +162,7 @@ class HiRadixCache(RadixCache):
def
_evict_write_through_selective
(
self
,
node
:
TreeNode
):
def
_evict_write_through_selective
(
self
,
node
:
TreeNode
):
# evict a node not initiated write to host
# evict a node not initiated write to host
self
.
cache_controller
.
mem_pool_device
.
free
(
node
.
value
)
self
.
cache_controller
.
mem_pool_device
_allocator
.
free
(
node
.
value
)
num_evicted
=
len
(
node
.
value
)
num_evicted
=
len
(
node
.
value
)
self
.
_delete_leaf
(
node
)
self
.
_delete_leaf
(
node
)
return
num_evicted
return
num_evicted
...
@@ -270,28 +272,27 @@ class HiRadixCache(RadixCache):
...
@@ -270,28 +272,27 @@ class HiRadixCache(RadixCache):
return
last_node
,
prefix_indices
return
last_node
,
prefix_indices
def
_match_prefix_helper
(
def
_match_prefix_helper
(
self
,
node
:
TreeNode
,
key
:
List
):
self
,
node
:
TreeNode
,
key
:
List
,
value
,
last_node
:
TreeNode
):
node
.
last_access_time
=
time
.
time
()
node
.
last_access_time
=
time
.
time
()
if
len
(
key
)
==
0
:
value
=
[]
return
while
len
(
key
)
>
0
and
key
[
0
]
in
node
.
children
.
keys
():
if
key
[
0
]
in
node
.
children
.
keys
():
child
=
node
.
children
[
key
[
0
]]
child
=
node
.
children
[
key
[
0
]]
child
.
last_access_time
=
time
.
time
()
prefix_len
=
_key_match
(
child
.
key
,
key
)
prefix_len
=
_key_match
(
child
.
key
,
key
)
if
prefix_len
<
len
(
child
.
key
):
if
prefix_len
<
len
(
child
.
key
):
new_node
=
self
.
_split_node
(
child
.
key
,
child
,
prefix_len
)
new_node
=
self
.
_split_node
(
child
.
key
,
child
,
prefix_len
)
self
.
inc_hit_count
(
new_node
)
self
.
inc_hit_count
(
new_node
)
if
not
new_node
.
evicted
:
if
not
new_node
.
evicted
:
value
.
append
(
new_node
.
value
)
value
.
append
(
new_node
.
value
)
last_node
[
0
]
=
new_node
node
=
new_node
break
else
:
else
:
self
.
inc_hit_count
(
child
)
self
.
inc_hit_count
(
child
)
if
not
child
.
evicted
:
if
not
child
.
evicted
:
value
.
append
(
child
.
value
)
value
.
append
(
child
.
value
)
last_node
[
0
]
=
child
node
=
child
self
.
_match_prefix_helper
(
child
,
key
[
prefix_len
:],
value
,
last_node
)
key
=
key
[
prefix_len
:]
return
value
,
node
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
:
int
):
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
:
int
):
# child node split into new_node -> child
# child node split into new_node -> child
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
9376ac36
...
@@ -470,7 +470,7 @@ class MHATokenToKVPoolHost:
...
@@ -470,7 +470,7 @@ class MHATokenToKVPoolHost:
def
__init__
(
def
__init__
(
self
,
self
,
device_pool
:
MHATokenToKVPool
,
device_pool
:
MHATokenToKVPool
,
host_to_device_ratio
:
float
=
2
.0
,
host_to_device_ratio
:
float
=
3
.0
,
pin_memory
:
bool
=
False
,
# no need to use pin memory with the double buffering
pin_memory
:
bool
=
False
,
# no need to use pin memory with the double buffering
device
:
str
=
"cpu"
,
device
:
str
=
"cpu"
,
):
):
...
...
python/sglang/utils.py
View file @
9376ac36
...
@@ -24,14 +24,10 @@ import requests
...
@@ -24,14 +24,10 @@ import requests
from
IPython.display
import
HTML
,
display
from
IPython.display
import
HTML
,
display
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
sglang.srt.openai_api.protocol
import
ChatCompletionMessageContentPart
from
sglang.srt.utils
import
kill_process_tree
from
sglang.srt.utils
import
kill_process_tree
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# type of content fields, can be only prompts or with images/videos
MsgContent
=
Union
[
str
,
List
[
ChatCompletionMessageContentPart
]]
def
get_exception_traceback
():
def
get_exception_traceback
():
etype
,
value
,
tb
=
sys
.
exc_info
()
etype
,
value
,
tb
=
sys
.
exc_info
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment