Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
08104b56
Unverified
Commit
08104b56
authored
Jan 27, 2025
by
Zhiqiang Xie
Committed by
GitHub
Jan 27, 2025
Browse files
Sanity check to prevent performance regression (#3171)
Co-authored-by:
Lianmin Zheng
<
lianminzheng@gmail.com
>
parent
cf142b6e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
60 additions
and
4 deletions
+60
-4
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+17
-3
python/sglang/srt/mem_cache/base_prefix_cache.py
python/sglang/srt/mem_cache/base_prefix_cache.py
+4
-0
python/sglang/srt/mem_cache/chunk_cache.py
python/sglang/srt/mem_cache/chunk_cache.py
+3
-0
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+30
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+6
-0
No files found.
python/sglang/srt/managers/scheduler.py
View file @
08104b56
...
@@ -149,6 +149,7 @@ class Scheduler:
...
@@ -149,6 +149,7 @@ class Scheduler:
if
not
self
.
spec_algorithm
.
is_none
()
if
not
self
.
spec_algorithm
.
is_none
()
else
1
else
1
)
)
self
.
enable_hierarchical_cache
=
server_args
.
enable_hierarchical_cache
# Distributed rank info
# Distributed rank info
self
.
dp_size
=
server_args
.
dp_size
self
.
dp_size
=
server_args
.
dp_size
...
@@ -831,10 +832,16 @@ class Scheduler:
...
@@ -831,10 +832,16 @@ class Scheduler:
available_size
=
(
available_size
=
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
)
if
available_size
!=
self
.
max_total_num_tokens
:
protected_size
=
self
.
tree_cache
.
protected_size
()
memory_leak
=
available_size
!=
(
self
.
max_total_num_tokens
if
not
self
.
enable_hierarchical_cache
else
self
.
max_total_num_tokens
-
protected_size
)
if
memory_leak
:
msg
=
(
msg
=
(
"KV cache pool leak detected!"
"KV cache pool leak detected!"
f
"
{
available_size
=
}
,
{
self
.
max_total_num_tokens
=
}
\n
"
f
"
{
available_size
=
}
,
{
protected_size
=
}
,
{
self
.
max_total_num_tokens
=
}
\n
"
)
)
warnings
.
warn
(
msg
)
warnings
.
warn
(
msg
)
if
crash_on_warnings
():
if
crash_on_warnings
():
...
@@ -949,7 +956,14 @@ class Scheduler:
...
@@ -949,7 +956,14 @@ class Scheduler:
res
=
adder
.
add_one_req
(
req
)
res
=
adder
.
add_one_req
(
req
)
if
res
!=
AddReqResult
.
CONTINUE
:
if
res
!=
AddReqResult
.
CONTINUE
:
if
res
==
AddReqResult
.
NO_TOKEN
:
if
res
==
AddReqResult
.
NO_TOKEN
:
self
.
batch_is_full
=
True
if
self
.
enable_hierarchical_cache
:
# Set batch_is_full after making sure there are requests that can be served
self
.
batch_is_full
=
len
(
adder
.
can_run_list
)
>
0
or
(
self
.
running_batch
is
not
None
and
not
self
.
running_batch
.
is_empty
()
)
else
:
self
.
batch_is_full
=
True
break
break
if
self
.
server_args
.
prefill_only_one_req
:
if
self
.
server_args
.
prefill_only_one_req
:
break
break
...
...
python/sglang/srt/mem_cache/base_prefix_cache.py
View file @
08104b56
...
@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
...
@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
def
evictable_size
(
self
):
def
evictable_size
(
self
):
pass
pass
@
abstractmethod
def
protected_size
(
self
):
raise
NotImplementedError
()
def
total_size
(
self
):
def
total_size
(
self
):
raise
NotImplementedError
()
raise
NotImplementedError
()
...
...
python/sglang/srt/mem_cache/chunk_cache.py
View file @
08104b56
...
@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):
...
@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):
def
evictable_size
(
self
):
def
evictable_size
(
self
):
return
0
return
0
def
protected_size
(
self
):
return
0
python/sglang/srt/mem_cache/radix_cache.py
View file @
08104b56
...
@@ -34,7 +34,10 @@ if TYPE_CHECKING:
...
@@ -34,7 +34,10 @@ if TYPE_CHECKING:
class
TreeNode
:
class
TreeNode
:
def
__init__
(
self
):
counter
=
0
def
__init__
(
self
,
id
:
Optional
[
int
]
=
None
):
self
.
children
=
defaultdict
(
TreeNode
)
self
.
children
=
defaultdict
(
TreeNode
)
self
.
parent
=
None
self
.
parent
=
None
self
.
key
=
None
self
.
key
=
None
...
@@ -42,6 +45,23 @@ class TreeNode:
...
@@ -42,6 +45,23 @@ class TreeNode:
self
.
lock_ref
=
0
self
.
lock_ref
=
0
self
.
last_access_time
=
time
.
time
()
self
.
last_access_time
=
time
.
time
()
self
.
hit_count
=
0
# indicating the node is loading KV cache from host
self
.
loading
=
False
# store the host indices of KV cache
self
.
host_value
=
None
self
.
id
=
TreeNode
.
counter
if
id
is
None
else
id
TreeNode
.
counter
+=
1
@
property
def
evicted
(
self
):
return
self
.
value
is
None
@
property
def
backuped
(
self
):
return
self
.
host_value
is
not
None
def
__lt__
(
self
,
other
:
"TreeNode"
):
def
__lt__
(
self
,
other
:
"TreeNode"
):
return
self
.
last_access_time
<
other
.
last_access_time
return
self
.
last_access_time
<
other
.
last_access_time
...
@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
...
@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
self
.
root_node
.
value
=
[]
self
.
root_node
.
value
=
[]
self
.
root_node
.
lock_ref
=
1
self
.
root_node
.
lock_ref
=
1
self
.
evictable_size_
=
0
self
.
evictable_size_
=
0
self
.
protected_size_
=
0
def
match_prefix
(
self
,
key
:
List
[
int
],
**
kwargs
)
->
Tuple
[
torch
.
Tensor
,
int
]:
def
match_prefix
(
self
,
key
:
List
[
int
],
**
kwargs
)
->
Tuple
[
torch
.
Tensor
,
int
]:
"""Find the matching prefix from the radix tree.
"""Find the matching prefix from the radix tree.
...
@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
...
@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
while
node
!=
self
.
root_node
:
while
node
!=
self
.
root_node
:
if
node
.
lock_ref
==
0
:
if
node
.
lock_ref
==
0
:
self
.
evictable_size_
-=
len
(
node
.
value
)
self
.
evictable_size_
-=
len
(
node
.
value
)
self
.
protected_size_
+=
len
(
node
.
value
)
delta
-=
len
(
node
.
value
)
delta
-=
len
(
node
.
value
)
node
.
lock_ref
+=
1
node
.
lock_ref
+=
1
node
=
node
.
parent
node
=
node
.
parent
...
@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
...
@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
while
node
!=
self
.
root_node
:
while
node
!=
self
.
root_node
:
if
node
.
lock_ref
==
1
:
if
node
.
lock_ref
==
1
:
self
.
evictable_size_
+=
len
(
node
.
value
)
self
.
evictable_size_
+=
len
(
node
.
value
)
self
.
protected_size_
-=
len
(
node
.
value
)
delta
+=
len
(
node
.
value
)
delta
+=
len
(
node
.
value
)
node
.
lock_ref
-=
1
node
.
lock_ref
-=
1
node
=
node
.
parent
node
=
node
.
parent
...
@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
...
@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
def
evictable_size
(
self
):
def
evictable_size
(
self
):
return
self
.
evictable_size_
return
self
.
evictable_size_
def
protected_size
(
self
):
# protected size refers to the size of the cache that is locked
return
self
.
protected_size_
##### Internal Helper Functions #####
##### Internal Helper Functions #####
def
_match_prefix_helper
(
def
_match_prefix_helper
(
...
@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
...
@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
self
.
evictable_size_
-=
len
(
node
.
key
)
self
.
evictable_size_
-=
len
(
node
.
key
)
def
_total_size_helper
(
self
,
node
:
TreeNode
):
def
_total_size_helper
(
self
,
node
:
TreeNode
):
if
node
.
evicted
:
return
0
x
=
len
(
node
.
value
)
x
=
len
(
node
.
value
)
for
child
in
node
.
children
.
values
():
for
child
in
node
.
children
.
values
():
x
+=
self
.
_total_size_helper
(
child
)
x
+=
self
.
_total_size_helper
(
child
)
...
...
python/sglang/srt/server_args.py
View file @
08104b56
...
@@ -163,6 +163,7 @@ class ServerArgs:
...
@@ -163,6 +163,7 @@ class ServerArgs:
# Custom logit processor
# Custom logit processor
enable_custom_logit_processor
:
bool
=
False
enable_custom_logit_processor
:
bool
=
False
tool_call_parser
:
str
=
None
tool_call_parser
:
str
=
None
enable_hierarchical_cache
:
bool
=
False
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Set missing default values
# Set missing default values
...
@@ -892,6 +893,11 @@ class ServerArgs:
...
@@ -892,6 +893,11 @@ class ServerArgs:
default
=
ServerArgs
.
tool_call_parser
,
default
=
ServerArgs
.
tool_call_parser
,
help
=
"Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'."
,
help
=
"Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'."
,
)
)
parser
.
add_argument
(
"--enable-hierarchical-cache"
,
action
=
"store_true"
,
help
=
"Enable hierarchical cache"
,
)
@
classmethod
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment