Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9dae4078
"vscode:/vscode.git/clone" did not exist on "fbc8d7685d68d4815346ffddfe23ce899664c6f5"
Unverified
Commit
9dae4078
authored
Aug 11, 2024
by
Lianmin Zheng
Committed by
GitHub
Aug 11, 2024
Browse files
Improve type annotation (#1029)
parent
fcc0f5ed
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
83 additions
and
41 deletions
+83
-41
python/sglang/srt/managers/policy_scheduler.py
python/sglang/srt/managers/policy_scheduler.py
+17
-9
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+28
-10
python/sglang/srt/mem_cache/base_prefix_cache.py
python/sglang/srt/mem_cache/base_prefix_cache.py
+4
-3
python/sglang/srt/mem_cache/chunk_cache.py
python/sglang/srt/mem_cache/chunk_cache.py
+10
-5
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+2
-2
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+22
-12
No files found.
python/sglang/srt/managers/policy_scheduler.py
View file @
9dae4078
...
...
@@ -18,13 +18,15 @@ limitations under the License.
import
random
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
List
from
typing
import
Dict
,
List
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.radix_cache
import
TreeNode
class
PolicyScheduler
:
def
__init__
(
self
,
policy
,
tree_cache
):
def
__init__
(
self
,
policy
:
str
,
tree_cache
:
BasePrefixCache
):
if
tree_cache
.
disable
and
policy
in
[
"lpm"
,
"dfs-weight"
]:
# LPM and DFS-weight are meaningless when the tree cache is disabled.
policy
=
"fcfs"
...
...
@@ -72,12 +74,18 @@ class PolicyScheduler:
else
:
raise
ValueError
(
f
"Unknown schedule_policy:
{
self
.
policy
}
"
)
def
calc_weight
(
self
,
cur_node
,
node_to_weight
):
def
calc_weight
(
self
,
cur_node
:
TreeNode
,
node_to_weight
:
Dict
):
for
child
in
cur_node
.
children
.
values
():
self
.
calc_weight
(
child
,
node_to_weight
)
node_to_weight
[
cur_node
]
+=
node_to_weight
[
child
]
def
get_dfs_priority
(
self
,
cur_node
,
node_to_priority
,
last_node_to_reqs
,
q
):
def
get_dfs_priority
(
self
,
cur_node
:
TreeNode
,
node_to_priority
:
Dict
,
last_node_to_reqs
:
Dict
,
q
:
List
,
):
childs
=
[
child
for
child
in
cur_node
.
children
.
values
()]
childs
.
sort
(
key
=
lambda
x
:
-
node_to_priority
[
x
])
for
child
in
childs
:
...
...
@@ -88,10 +96,10 @@ class PolicyScheduler:
class
PrefillAdder
:
def
__init__
(
self
,
tree_cache
,
rem_total_tokens
,
rem_input_tokens
,
rem_chunk_tokens
,
tree_cache
:
BasePrefixCache
,
rem_total_tokens
:
int
,
rem_input_tokens
:
int
,
rem_chunk_tokens
:
int
,
):
self
.
tree_cache
=
tree_cache
self
.
rem_total_tokens
=
rem_total_tokens
...
...
@@ -151,7 +159,7 @@ class PrefillAdder:
return
req
if
truncated
else
None
@
contextmanager
def
_lock_node
(
self
,
last_node
):
def
_lock_node
(
self
,
last_node
:
TreeNode
):
try
:
delta
=
self
.
tree_cache
.
inc_lock_ref
(
last_node
)
self
.
rem_total_tokens
+=
delta
...
...
python/sglang/srt/managers/tp_worker.py
View file @
9dae4078
...
...
@@ -21,15 +21,17 @@ import os
import
pickle
import
time
import
warnings
from
typing
import
List
,
Optional
,
Union
from
typing
import
Any
,
List
,
Optional
,
Union
import
torch
import
torch.distributed
import
torch.distributed
as
dist
from
sglang.global_config
import
global_config
from
sglang.srt.constrained.fsm_cache
import
FSMCache
from
sglang.srt.constrained.jump_forward
import
JumpForwardCache
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.layers.logits_processor
import
LogitProcessorOutput
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
BatchEmbeddingOut
,
...
...
@@ -62,6 +64,10 @@ from sglang.utils import get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
# TODO: Rename "CI" to "SGLANG_IS_IN_CI".
crash_on_warning
=
os
.
getenv
(
"CI"
,
"false"
)
==
"true"
class
ModelTpServer
:
def
__init__
(
self
,
...
...
@@ -198,7 +204,7 @@ class ModelTpServer:
self
.
new_token_ratio
=
self
.
min_new_token_ratio
self
.
new_token_ratio_decay
=
global_config
.
new_token_ratio_decay
def
exposed_step
(
self
,
recv_reqs
):
def
exposed_step
(
self
,
recv_reqs
:
List
):
try
:
# Recv requests
for
recv_req
in
recv_reqs
:
...
...
@@ -247,7 +253,7 @@ class ModelTpServer:
# Print stats
if
self
.
tp_rank
==
0
and
self
.
decode_forward_ct
%
40
==
0
:
self
.
print_stats
()
self
.
print_
decode_
stats
()
if
self
.
running_batch
.
is_empty
():
self
.
running_batch
=
None
...
...
@@ -259,7 +265,7 @@ class ModelTpServer:
self
.
check_memory
()
self
.
new_token_ratio
=
global_config
.
init_new_token_ratio
def
print_stats
(
self
):
def
print_
decode_
stats
(
self
):
num_used
=
self
.
max_total_num_tokens
-
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
...
...
@@ -276,7 +282,6 @@ class ModelTpServer:
)
def
check_memory
(
self
):
crash
=
os
.
getenv
(
"CI"
,
"false"
)
==
"true"
available_size
=
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
...
...
@@ -286,7 +291,7 @@ class ModelTpServer:
f
"available_size=
{
available_size
}
, max_total_num_tokens=
{
self
.
max_total_num_tokens
}
\n
"
"KV cache pool leak detected!"
)
exit
(
1
)
if
crash
else
None
exit
(
1
)
if
crash
_on_warning
else
None
if
len
(
self
.
req_to_token_pool
.
free_slots
)
!=
self
.
req_to_token_pool
.
size
:
warnings
.
warn
(
...
...
@@ -295,7 +300,7 @@ class ModelTpServer:
f
"total slots=
{
self
.
req_to_token_pool
.
size
}
\n
"
"Memory pool leak detected!"
)
exit
(
1
)
if
crash
else
None
exit
(
1
)
if
crash
_on_warning
else
None
def
handle_generate_request
(
self
,
...
...
@@ -511,7 +516,14 @@ class ModelTpServer:
self
.
handle_finished_requests
(
batch
)
def
add_logprob_return_values
(
self
,
i
,
req
:
Req
,
pt
,
next_token_ids
,
output
):
def
add_logprob_return_values
(
self
,
i
,
req
:
Req
,
pt
:
int
,
next_token_ids
:
List
[
int
],
output
:
LogitProcessorOutput
,
):
if
req
.
normalized_prompt_logprob
is
None
:
req
.
normalized_prompt_logprob
=
output
.
normalized_prompt_logprobs
[
i
]
...
...
@@ -786,7 +798,11 @@ def run_tp_server(
def
launch_tp_servers
(
gpu_ids
,
tp_rank_range
,
server_args
,
nccl_port
,
model_overide_args
gpu_ids
:
List
[
int
],
tp_rank_range
:
List
[
int
],
server_args
:
ServerArgs
,
nccl_port
:
int
,
model_overide_args
:
dict
,
):
"""Launch multiple tensor parallel servers."""
procs
=
[]
...
...
@@ -801,7 +817,9 @@ def launch_tp_servers(
return
procs
def
broadcast_recv_input
(
data
,
rank
,
dist_group
):
def
broadcast_recv_input
(
data
:
Any
,
rank
:
int
,
dist_group
:
torch
.
distributed
.
ProcessGroup
):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
if
rank
==
0
:
...
...
python/sglang/srt/mem_cache/base_prefix_cache.py
View file @
9dae4078
from
abc
import
ABC
,
abstractmethod
from
typing
import
Callable
class
BasePrefixCache
(
ABC
):
...
...
@@ -25,7 +26,7 @@ class BasePrefixCache(ABC):
pass
@
abstractmethod
def
evict
(
self
,
num_tokens
,
evict_callback
):
def
evict
(
self
,
num_tokens
:
int
,
evict_callback
:
Callable
):
pass
@
abstractmethod
...
...
@@ -41,7 +42,7 @@ class BasePrefixCache(ABC):
pass
def
total_size
(
self
):
raise
NotImplementedError
raise
NotImplementedError
()
def
pretty_print
(
self
):
raise
NotImplementedError
raise
NotImplementedError
()
python/sglang/srt/mem_cache/chunk_cache.py
View file @
9dae4078
from
__future__
import
annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Callable
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
...
...
@@ -15,7 +18,9 @@ class ChunkCacheEntry:
class
ChunkCache
(
BasePrefixCache
):
def
__init__
(
self
,
req_to_token_pool
,
token_to_kv_pool
):
def
__init__
(
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool
:
BaseTokenToKVPool
):
self
.
disable
=
True
self
.
req_to_token_pool
=
req_to_token_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
...
...
@@ -32,7 +37,7 @@ class ChunkCache(BasePrefixCache):
entry
=
self
.
entries
[
rid
]
return
entry
.
value
,
entry
def
cache_finished_req
(
self
,
req
:
"
Req
"
,
token_ids
=
None
):
def
cache_finished_req
(
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
):
if
token_ids
is
None
:
token_ids
=
(
req
.
origin_input_ids
+
req
.
output_ids
)[:
-
1
]
...
...
@@ -45,7 +50,7 @@ class ChunkCache(BasePrefixCache):
if
req
.
rid
in
self
.
entries
:
del
self
.
entries
[
req
.
rid
]
def
cache_unfinished_req
(
self
,
req
:
"
Req
"
,
token_ids
=
None
):
def
cache_unfinished_req
(
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
):
if
token_ids
is
None
:
token_ids
=
req
.
fill_ids
...
...
@@ -64,7 +69,7 @@ class ChunkCache(BasePrefixCache):
def
insert
(
self
):
raise
NotImplementedError
def
evict
(
self
,
num_tokens
,
evict_callback
):
def
evict
(
self
,
num_tokens
:
int
,
evict_callback
:
Callable
):
pass
def
inc_lock_ref
(
self
,
node
):
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
9dae4078
...
...
@@ -16,7 +16,7 @@ limitations under the License.
"""Memory pool."""
import
logging
from
typing
import
List
from
typing
import
List
,
Union
import
torch
...
...
@@ -42,7 +42,7 @@ class ReqToTokenPool:
return
select_index
def
free
(
self
,
free_index
):
def
free
(
self
,
free_index
:
Union
[
int
,
List
[
int
]]
):
if
isinstance
(
free_index
,
(
int
,)):
self
.
free_slots
.
append
(
free_index
)
else
:
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
9dae4078
from
__future__
import
annotations
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -25,6 +27,7 @@ from typing import TYPE_CHECKING
import
torch
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
if
TYPE_CHECKING
:
from
sglang.srt.managers.schedule_batch
import
Req
...
...
@@ -43,7 +46,7 @@ class TreeNode:
return
self
.
last_access_time
<
other
.
last_access_time
def
_key_match
(
key0
,
key1
):
def
_key_match
(
key0
:
List
,
key1
:
List
):
i
=
0
for
k0
,
k1
in
zip
(
key0
,
key1
):
if
k0
!=
k1
:
...
...
@@ -53,7 +56,12 @@ def _key_match(key0, key1):
class
RadixCache
(
BasePrefixCache
):
def
__init__
(
self
,
req_to_token_pool
,
token_to_kv_pool
,
disable
:
bool
=
False
):
def
__init__
(
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool
:
BaseTokenToKVPool
,
disable
:
bool
=
False
,
):
self
.
req_to_token_pool
=
req_to_token_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
disable
=
disable
...
...
@@ -68,7 +76,7 @@ class RadixCache(BasePrefixCache):
self
.
root_node
.
lock_ref
=
1
self
.
evictable_size_
=
0
def
match_prefix
(
self
,
key
,
**
kwargs
):
def
match_prefix
(
self
,
key
:
List
,
**
kwargs
):
if
self
.
disable
:
return
[],
self
.
root_node
...
...
@@ -81,7 +89,7 @@ class RadixCache(BasePrefixCache):
value
=
torch
.
tensor
([],
dtype
=
torch
.
int32
)
return
value
,
last_node
[
0
]
def
insert
(
self
,
key
,
value
=
None
):
def
insert
(
self
,
key
:
List
,
value
=
None
):
if
self
.
disable
:
return
0
...
...
@@ -89,7 +97,7 @@ class RadixCache(BasePrefixCache):
value
=
[
x
for
x
in
key
]
return
self
.
_insert_helper
(
self
.
root_node
,
key
,
value
)
def
cache_finished_req
(
self
,
req
:
"
Req
"
,
token_ids
=
None
):
def
cache_finished_req
(
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
):
"""Cache request when it finishes."""
if
token_ids
is
None
:
token_ids
=
(
req
.
origin_input_ids
+
req
.
output_ids
)[:
-
1
]
...
...
@@ -110,7 +118,7 @@ class RadixCache(BasePrefixCache):
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
dec_lock_ref
(
req
.
last_node
)
def
cache_unfinished_req
(
self
,
req
:
"
Req
"
,
token_ids
=
None
):
def
cache_unfinished_req
(
self
,
req
:
Req
,
token_ids
:
Optional
[
List
[
int
]]
=
None
):
"""Cache request when it is unfinished."""
if
self
.
disable
:
return
...
...
@@ -145,7 +153,7 @@ class RadixCache(BasePrefixCache):
def
total_size
(
self
):
return
self
.
_total_size_helper
(
self
.
root_node
)
def
evict
(
self
,
num_tokens
,
evict_callback
):
def
evict
(
self
,
num_tokens
:
int
,
evict_callback
:
Callable
):
if
self
.
disable
:
return
...
...
@@ -199,7 +207,9 @@ class RadixCache(BasePrefixCache):
##### Internal Helper Functions #####
def
_match_prefix_helper
(
self
,
node
,
key
,
value
,
last_node
):
def
_match_prefix_helper
(
self
,
node
:
TreeNode
,
key
:
List
,
value
,
last_node
:
TreeNode
):
node
.
last_access_time
=
time
.
time
()
if
len
(
key
)
==
0
:
return
...
...
@@ -216,7 +226,7 @@ class RadixCache(BasePrefixCache):
last_node
[
0
]
=
child
self
.
_match_prefix_helper
(
child
,
key
[
prefix_len
:],
value
,
last_node
)
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
):
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
:
int
):
# new_node -> child
new_node
=
TreeNode
()
new_node
.
children
=
{
key
[
split_len
:][
0
]:
child
}
...
...
@@ -230,7 +240,7 @@ class RadixCache(BasePrefixCache):
new_node
.
parent
.
children
[
key
[:
split_len
][
0
]]
=
new_node
return
new_node
def
_insert_helper
(
self
,
node
,
key
,
value
):
def
_insert_helper
(
self
,
node
:
TreeNode
,
key
:
List
,
value
):
node
.
last_access_time
=
time
.
time
()
if
len
(
key
)
==
0
:
return
0
...
...
@@ -261,7 +271,7 @@ class RadixCache(BasePrefixCache):
self
.
evictable_size_
+=
len
(
value
)
return
0
def
_print_helper
(
self
,
node
:
TreeNode
,
indent
):
def
_print_helper
(
self
,
node
:
TreeNode
,
indent
:
int
):
for
_
,
child
in
node
.
children
.
items
():
print
(
" "
*
indent
,
len
(
child
.
key
),
child
.
key
[:
10
],
f
"r=
{
child
.
lock_ref
}
"
)
self
.
_print_helper
(
child
,
indent
=
indent
+
2
)
...
...
@@ -273,7 +283,7 @@ class RadixCache(BasePrefixCache):
del
node
.
parent
.
children
[
k
]
self
.
evictable_size_
-=
len
(
node
.
key
)
def
_total_size_helper
(
self
,
node
):
def
_total_size_helper
(
self
,
node
:
TreeNode
):
x
=
len
(
node
.
value
)
for
child
in
node
.
children
.
values
():
x
+=
self
.
_total_size_helper
(
child
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment