Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7de60345
Unverified
Commit
7de60345
authored
Aug 11, 2024
by
Liangsheng Yin
Committed by
GitHub
Aug 11, 2024
Browse files
Fix the prefix indices (#1037)
parent
d84c5e70
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
15 deletions
+25
-15
python/sglang/srt/managers/policy_scheduler.py
python/sglang/srt/managers/policy_scheduler.py
+10
-5
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+9
-6
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+5
-3
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+1
-1
No files found.
python/sglang/srt/managers/policy_scheduler.py
View file @
7de60345
...
@@ -43,11 +43,14 @@ class PolicyScheduler:
...
@@ -43,11 +43,14 @@ class PolicyScheduler:
def
calc_priority
(
self
,
waiting_queue
:
List
[
Req
]):
def
calc_priority
(
self
,
waiting_queue
:
List
[
Req
]):
# Compute matched prefix length
# Compute matched prefix length
for
r
in
waiting_queue
:
prefix_computed
=
False
# NOTE: the prefix_indices must always be aligned with last_node
if
self
.
policy
in
[
"lpm"
,
"dfs-weight"
]:
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
for
r
in
waiting_queue
:
rid
=
r
.
rid
,
key
=
r
.
adjust_max_prefix_ids
()
# NOTE: the prefix_indices must always be aligned with last_node
)
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
r
.
adjust_max_prefix_ids
()
)
prefix_computed
=
True
if
self
.
policy
==
"lpm"
:
if
self
.
policy
==
"lpm"
:
# Longest Prefix Match
# Longest Prefix Match
...
@@ -80,6 +83,8 @@ class PolicyScheduler:
...
@@ -80,6 +83,8 @@ class PolicyScheduler:
else
:
else
:
raise
ValueError
(
f
"Unknown schedule_policy:
{
self
.
policy
}
"
)
raise
ValueError
(
f
"Unknown schedule_policy:
{
self
.
policy
}
"
)
return
prefix_computed
def
calc_weight
(
self
,
cur_node
:
TreeNode
,
node_to_weight
:
Dict
):
def
calc_weight
(
self
,
cur_node
:
TreeNode
,
node_to_weight
:
Dict
):
for
child
in
cur_node
.
children
.
values
():
for
child
in
cur_node
.
children
.
values
():
self
.
calc_weight
(
child
,
node_to_weight
)
self
.
calc_weight
(
child
,
node_to_weight
)
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
7de60345
...
@@ -18,9 +18,8 @@ limitations under the License.
...
@@ -18,9 +18,8 @@ limitations under the License.
import
logging
import
logging
import
warnings
import
warnings
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Union
from
typing
import
List
,
Optional
,
Union
import
numpy
as
np
import
torch
import
torch
from
flashinfer.sampling
import
top_k_top_p_sampling_from_probs
from
flashinfer.sampling
import
top_k_top_p_sampling_from_probs
...
@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib
...
@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
from
sglang.srt.constrained
import
RegexGuide
from
sglang.srt.constrained
import
RegexGuide
from
sglang.srt.constrained.jump_forward
import
JumpForwardMap
from
sglang.srt.constrained.jump_forward
import
JumpForwardMap
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
...
@@ -164,8 +163,12 @@ class Req:
...
@@ -164,8 +163,12 @@ class Req:
def
finished
(
self
)
->
bool
:
def
finished
(
self
)
->
bool
:
return
self
.
finished_reason
is
not
None
return
self
.
finished_reason
is
not
None
def
init_next_round_input
(
self
):
def
init_next_round_input
(
self
,
tree_cache
:
Optional
[
BasePrefixCache
]
=
None
):
self
.
fill_ids
=
self
.
origin_input_ids
+
self
.
output_ids
self
.
fill_ids
=
self
.
origin_input_ids
+
self
.
output_ids
if
tree_cache
is
not
None
:
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
)
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
def
adjust_max_prefix_ids
(
self
):
def
adjust_max_prefix_ids
(
self
):
...
@@ -312,7 +315,7 @@ class ScheduleBatch:
...
@@ -312,7 +315,7 @@ class ScheduleBatch:
reqs
:
List
[
Req
]
reqs
:
List
[
Req
]
req_to_token_pool
:
ReqToTokenPool
req_to_token_pool
:
ReqToTokenPool
token_to_kv_pool
:
BaseTokenToKVPool
token_to_kv_pool
:
BaseTokenToKVPool
tree_cache
:
RadixCache
tree_cache
:
BasePrefixCache
# Batched arguments to model runner
# Batched arguments to model runner
input_ids
:
torch
.
Tensor
=
None
input_ids
:
torch
.
Tensor
=
None
...
@@ -534,7 +537,7 @@ class ScheduleBatch:
...
@@ -534,7 +537,7 @@ class ScheduleBatch:
residual_size
=
max
(
0
,
residual_size
)
residual_size
=
max
(
0
,
residual_size
)
self
.
tree_cache
.
evict
(
residual_size
,
self
.
token_to_kv_pool
.
free
)
self
.
tree_cache
.
evict
(
residual_size
,
self
.
token_to_kv_pool
.
free
)
req
.
prefix_indices
=
None
req
.
prefix_indices
=
[]
req
.
last_node
=
None
req
.
last_node
=
None
req
.
extend_input_len
=
0
req
.
extend_input_len
=
0
...
...
python/sglang/srt/managers/tp_worker.py
View file @
7de60345
...
@@ -369,7 +369,7 @@ class ModelTpServer:
...
@@ -369,7 +369,7 @@ class ModelTpServer:
return
None
return
None
# Get priority queue
# Get priority queue
self
.
scheduler
.
calc_priority
(
self
.
waiting_queue
)
prefix_computed
=
self
.
scheduler
.
calc_priority
(
self
.
waiting_queue
)
adder
=
PrefillAdder
(
adder
=
PrefillAdder
(
self
.
tree_cache
,
self
.
tree_cache
,
...
@@ -383,13 +383,15 @@ class ModelTpServer:
...
@@ -383,13 +383,15 @@ class ModelTpServer:
has_inflight
=
self
.
current_inflight_req
is
not
None
has_inflight
=
self
.
current_inflight_req
is
not
None
if
self
.
current_inflight_req
is
not
None
:
if
self
.
current_inflight_req
is
not
None
:
self
.
current_inflight_req
.
init_next_round_input
()
self
.
current_inflight_req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
self
.
current_inflight_req
)
)
for
req
in
self
.
waiting_queue
:
for
req
in
self
.
waiting_queue
:
req
.
init_next_round_input
()
req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
res
=
adder
.
add_one_req
(
req
)
res
=
adder
.
add_one_req
(
req
)
if
(
if
(
not
res
not
res
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
7de60345
...
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
...
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
import
heapq
import
heapq
import
time
import
time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
import
torch
import
torch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment