sglang / Commits / 9dae4078

Commit 9dae4078 (unverified), authored Aug 11, 2024 by Lianmin Zheng and committed via GitHub on Aug 11, 2024.

Improve type annotation (#1029)

Parent: fcc0f5ed

Showing 6 changed files with 83 additions and 41 deletions (+83 -41).
Changed files:

  python/sglang/srt/managers/policy_scheduler.py    +17 -9
  python/sglang/srt/managers/tp_worker.py           +28 -10
  python/sglang/srt/mem_cache/base_prefix_cache.py  +4  -3
  python/sglang/srt/mem_cache/chunk_cache.py        +10 -5
  python/sglang/srt/mem_cache/memory_pool.py        +2  -2
  python/sglang/srt/mem_cache/radix_cache.py        +22 -12
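Most of the hunks below follow a few recurring typing patterns: bare container types (List, Dict), Optional[...] = None defaults for omittable arguments, Callable for callbacks, and Union for parameters that accept more than one shape. Below is a self-contained sketch of these patterns on toy functions (not sglang code), included only as illustration:

    from typing import Callable, List, Optional, Union


    def evict(num_tokens: int, evict_callback: Callable) -> None:
        # Bare `Callable` documents that a callback is expected without
        # constraining its exact signature.
        evict_callback(list(range(num_tokens)))


    def free(free_index: Union[int, List[int]]) -> List[int]:
        # Accept either a single index or a batch of indices.
        return [free_index] if isinstance(free_index, int) else list(free_index)


    def cache_finished_req(token_ids: Optional[List[int]] = None) -> List[int]:
        # `Optional[...] = None` documents that the caller may omit the argument.
        return token_ids if token_ids is not None else []


    if __name__ == "__main__":
        evict(3, print)
        print(free(7), free([1, 2, 3]))
        print(cache_finished_req(), cache_finished_req([4, 5]))

Annotations like these change nothing at runtime; they only give a static checker such as mypy enough information to flag mismatched call sites.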
python/sglang/srt/managers/policy_scheduler.py

@@ -18,13 +18,15 @@ limitations under the License.
 import random
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import List
+from typing import Dict, List

 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.radix_cache import TreeNode


 class PolicyScheduler:
-    def __init__(self, policy, tree_cache):
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
             policy = "fcfs"

@@ -72,12 +74,18 @@ class PolicyScheduler:
         else:
             raise ValueError(f"Unknown schedule_policy: {self.policy}")

-    def calc_weight(self, cur_node, node_to_weight):
+    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
         for child in cur_node.children.values():
             self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]

-    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+    def get_dfs_priority(
+        self,
+        cur_node: TreeNode,
+        node_to_priority: Dict,
+        last_node_to_reqs: Dict,
+        q: List,
+    ):
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:

@@ -88,10 +96,10 @@ class PolicyScheduler:
 class PrefillAdder:
     def __init__(
         self,
-        tree_cache,
-        rem_total_tokens,
-        rem_input_tokens,
-        rem_chunk_tokens,
+        tree_cache: BasePrefixCache,
+        rem_total_tokens: int,
+        rem_input_tokens: int,
+        rem_chunk_tokens: int,
     ):
         self.tree_cache = tree_cache
         self.rem_total_tokens = rem_total_tokens

@@ -151,7 +159,7 @@ class PrefillAdder:
         return req if truncated else None

     @contextmanager
-    def _lock_node(self, last_node):
+    def _lock_node(self, last_node: TreeNode):
         try:
             delta = self.tree_cache.inc_lock_ref(last_node)
             self.rem_total_tokens += delta
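The body of calc_weight is fully visible in the hunk above. A toy, self-contained version (with a minimal stand-in TreeNode, not sglang's) showing the bottom-up weight accumulation it performs:

    from collections import defaultdict
    from typing import Dict


    class TreeNode:
        def __init__(self) -> None:
            self.children: Dict[int, "TreeNode"] = {}


    def calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
        # After the call, each node's weight includes the weights of all of
        # its descendants.
        for child in cur_node.children.values():
            calc_weight(child, node_to_weight)
            node_to_weight[cur_node] += node_to_weight[child]


    root = TreeNode()
    root.children = {0: TreeNode(), 1: TreeNode()}
    weights: Dict[TreeNode, int] = defaultdict(int)
    weights[root.children[0]] = 2
    weights[root.children[1]] = 3
    calc_weight(root, weights)
    assert weights[root] == 5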
python/sglang/srt/managers/tp_worker.py

@@ -21,15 +21,17 @@ import os
 import pickle
 import time
 import warnings
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union

 import torch
 import torch.distributed
 import torch.distributed as dist

 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,

@@ -62,6 +64,10 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)

+# TODO: Rename "CI" to "SGLANG_IS_IN_CI".
+crash_on_warning = os.getenv("CI", "false") == "true"
+

 class ModelTpServer:
     def __init__(
         self,

@@ -198,7 +204,7 @@ class ModelTpServer:
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay

-    def exposed_step(self, recv_reqs):
+    def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
             for recv_req in recv_reqs:

@@ -247,7 +253,7 @@ class ModelTpServer:
             # Print stats
             if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                self.print_stats()
+                self.print_decode_stats()

             if self.running_batch.is_empty():
                 self.running_batch = None

@@ -259,7 +265,7 @@ class ModelTpServer:
             self.check_memory()
             self.new_token_ratio = global_config.init_new_token_ratio

-    def print_stats(self):
+    def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )

@@ -276,7 +282,6 @@ class ModelTpServer:
         )

     def check_memory(self):
-        crash = os.getenv("CI", "false") == "true"
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )

@@ -286,7 +291,7 @@ class ModelTpServer:
                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                 "KV cache pool leak detected!"
            )
-            exit(1) if crash else None
+            exit(1) if crash_on_warning else None

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            warnings.warn(

@@ -295,7 +300,7 @@ class ModelTpServer:
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
            )
-            exit(1) if crash else None
+            exit(1) if crash_on_warning else None

    def handle_generate_request(
        self,

@@ -511,7 +516,14 @@ class ModelTpServer:
             self.handle_finished_requests(batch)

-    def add_logprob_return_values(self, i, req: Req, pt, next_token_ids, output):
+    def add_logprob_return_values(
+        self,
+        i,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        output: LogitProcessorOutput,
+    ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

@@ -786,7 +798,11 @@ def run_tp_server(
 def launch_tp_servers(
-    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+    gpu_ids: List[int],
+    tp_rank_range: List[int],
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []

@@ -801,7 +817,9 @@ def launch_tp_servers(
     return procs


-def broadcast_recv_input(data, rank, dist_group):
+def broadcast_recv_input(
+    data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
+):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
     if rank == 0:
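The check_memory change above replaces a locally computed crash flag with the module-level crash_on_warning read once from the "CI" environment variable. A minimal sketch of that pattern, independent of sglang:

    import os
    import warnings

    # Read the CI flag once at import time and reuse it everywhere.
    crash_on_warning = os.getenv("CI", "false") == "true"


    def report_leak(message: str) -> None:
        # Warn in normal runs; turn the warning into a hard failure in CI so a
        # leak fails the pipeline instead of being silently logged.
        warnings.warn(message)
        if crash_on_warning:
            raise SystemExit(1)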
python/sglang/srt/mem_cache/base_prefix_cache.py

 from abc import ABC, abstractmethod
+from typing import Callable


 class BasePrefixCache(ABC):

@@ -25,7 +26,7 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         pass

     @abstractmethod

@@ -41,7 +42,7 @@ class BasePrefixCache(ABC):
         pass

     def total_size(self):
-        raise NotImplementedError
+        raise NotImplementedError()

     def pretty_print(self):
-        raise NotImplementedError
+        raise NotImplementedError()
python/sglang/srt/mem_cache/chunk_cache.py

+from __future__ import annotations
+
 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req

@@ -15,7 +18,9 @@ class ChunkCacheEntry:

 class ChunkCache(BasePrefixCache):
-    def __init__(self, req_to_token_pool, token_to_kv_pool):
+    def __init__(
+        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+    ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool

@@ -32,7 +37,7 @@ class ChunkCache(BasePrefixCache):
         entry = self.entries[rid]
         return entry.value, entry

-    def cache_finished_req(self, req: "Req", token_ids=None):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]

@@ -45,7 +50,7 @@ class ChunkCache(BasePrefixCache):
         if req.rid in self.entries:
             del self.entries[req.rid]

-    def cache_unfinished_req(self, req: "Req", token_ids=None):
+    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_ids = req.fill_ids

@@ -64,7 +69,7 @@ class ChunkCache(BasePrefixCache):
     def insert(self):
         raise NotImplementedError

-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         pass

     def inc_lock_ref(self, node):
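chunk_cache.py (and radix_cache.py below) switch from the string annotation req: "Req" to req: Req by adding `from __future__ import annotations`, which defers annotation evaluation so the Req import can stay behind TYPE_CHECKING without creating a runtime import cycle. A standalone sketch of the pattern (the module path mypackage.schedule is hypothetical, not from sglang):

    from __future__ import annotations

    from typing import TYPE_CHECKING, List, Optional

    if TYPE_CHECKING:
        # Imported only while type checking; no runtime import cycle.
        from mypackage.schedule import Req  # hypothetical path


    class Cache:
        def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None) -> None:
            # With deferred evaluation, this annotation stays a plain string at
            # runtime, so `Req` does not need to exist when the module loads.
            ...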
python/sglang/srt/mem_cache/memory_pool.py

@@ -16,7 +16,7 @@ limitations under the License.
 """Memory pool."""

 import logging
-from typing import List
+from typing import List, Union

 import torch

@@ -42,7 +42,7 @@ class ReqToTokenPool:
         return select_index

-    def free(self, free_index):
+    def free(self, free_index: Union[int, List[int]]):
         if isinstance(free_index, (int,)):
             self.free_slots.append(free_index)
         else:
python/sglang/srt/mem_cache/radix_cache.py

+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");

@@ -25,6 +27,7 @@ from typing import TYPE_CHECKING
 import torch

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req

@@ -43,7 +46,7 @@ class TreeNode:
         return self.last_access_time < other.last_access_time


-def _key_match(key0, key1):
+def _key_match(key0: List, key1: List):
     i = 0
     for k0, k1 in zip(key0, key1):
         if k0 != k1:

@@ -53,7 +56,12 @@ def _key_match(key0, key1):

 class RadixCache(BasePrefixCache):
-    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: BaseTokenToKVPool,
+        disable: bool = False,
+    ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable

@@ -68,7 +76,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0

-    def match_prefix(self, key, **kwargs):
+    def match_prefix(self, key: List, **kwargs):
         if self.disable:
             return [], self.root_node

@@ -81,7 +89,7 @@ class RadixCache(BasePrefixCache):
             value = torch.tensor([], dtype=torch.int32)
         return value, last_node[0]

-    def insert(self, key, value=None):
+    def insert(self, key: List, value=None):
         if self.disable:
             return 0

@@ -89,7 +97,7 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(self, req: "Req", token_ids=None):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]

@@ -110,7 +118,7 @@ class RadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)

-    def cache_unfinished_req(self, req: "Req", token_ids=None):
+    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it is unfinished."""
         if self.disable:
             return

@@ -145,7 +153,7 @@ class RadixCache(BasePrefixCache):
     def total_size(self):
         return self._total_size_helper(self.root_node)

-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         if self.disable:
             return

@@ -199,7 +207,9 @@ class RadixCache(BasePrefixCache):
     ##### Internal Helper Functions #####

-    def _match_prefix_helper(self, node, key, value, last_node):
+    def _match_prefix_helper(
+        self, node: TreeNode, key: List, value, last_node: TreeNode
+    ):
         node.last_access_time = time.time()
         if len(key) == 0:
             return

@@ -216,7 +226,7 @@ class RadixCache(BasePrefixCache):
                 last_node[0] = child
                 self._match_prefix_helper(child, key[prefix_len:], value, last_node)

-    def _split_node(self, key, child: TreeNode, split_len):
+    def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {key[split_len:][0]: child}

@@ -230,7 +240,7 @@ class RadixCache(BasePrefixCache):
         new_node.parent.children[key[:split_len][0]] = new_node
         return new_node

-    def _insert_helper(self, node, key, value):
+    def _insert_helper(self, node: TreeNode, key: List, value):
         node.last_access_time = time.time()
         if len(key) == 0:
             return 0

@@ -261,7 +271,7 @@ class RadixCache(BasePrefixCache):
             self.evictable_size_ += len(value)
         return 0

-    def _print_helper(self, node: TreeNode, indent):
+    def _print_helper(self, node: TreeNode, indent: int):
         for _, child in node.children.items():
             print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)

@@ -273,7 +283,7 @@ class RadixCache(BasePrefixCache):
         del node.parent.children[k]
         self.evictable_size_ -= len(node.key)

-    def _total_size_helper(self, node):
+    def _total_size_helper(self, node: TreeNode):
         x = len(node.value)
         for child in node.children.values():
             x += self._total_size_helper(child)
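For reference, _key_match (annotated above) computes the length of the longest common prefix of two token-id lists, which RadixCache uses to decide where to split tree nodes. Only the start of its body appears in the excerpt, so the break/return below are filled in as the conventional completion rather than quoted from the file:

    from typing import List


    def _key_match(key0: List[int], key1: List[int]) -> int:
        # Count matching leading elements until the first mismatch.
        i = 0
        for k0, k1 in zip(key0, key1):
            if k0 != k1:
                break
            i += 1
        return i


    assert _key_match([1, 2, 3, 4], [1, 2, 9]) == 2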