Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7de60345
Unverified
Commit
7de60345
authored
Aug 11, 2024
by
Liangsheng Yin
Committed by
GitHub
Aug 11, 2024
Browse files
Fix the prefix indices (#1037)
parent
d84c5e70
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
15 deletions
+25
-15
python/sglang/srt/managers/policy_scheduler.py
python/sglang/srt/managers/policy_scheduler.py
+10
-5
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+9
-6
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+5
-3
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+1
-1
No files found.
python/sglang/srt/managers/policy_scheduler.py
View file @
7de60345
...
...
@@ -43,11 +43,14 @@ class PolicyScheduler:
def
calc_priority
(
self
,
waiting_queue
:
List
[
Req
]):
# Compute matched prefix length
for
r
in
waiting_queue
:
# NOTE: the prefix_indices must always be aligned with last_node
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
r
.
adjust_max_prefix_ids
()
)
prefix_computed
=
False
if
self
.
policy
in
[
"lpm"
,
"dfs-weight"
]:
for
r
in
waiting_queue
:
# NOTE: the prefix_indices must always be aligned with last_node
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
r
.
adjust_max_prefix_ids
()
)
prefix_computed
=
True
if
self
.
policy
==
"lpm"
:
# Longest Prefix Match
...
...
@@ -80,6 +83,8 @@ class PolicyScheduler:
else
:
raise
ValueError
(
f
"Unknown schedule_policy:
{
self
.
policy
}
"
)
return
prefix_computed
def
calc_weight
(
self
,
cur_node
:
TreeNode
,
node_to_weight
:
Dict
):
for
child
in
cur_node
.
children
.
values
():
self
.
calc_weight
(
child
,
node_to_weight
)
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
7de60345
...
...
@@ -18,9 +18,8 @@ limitations under the License.
import
logging
import
warnings
from
dataclasses
import
dataclass
from
typing
import
List
,
Union
from
typing
import
List
,
Optional
,
Union
import
numpy
as
np
import
torch
from
flashinfer.sampling
import
top_k_top_p_sampling_from_probs
...
...
@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib
from
sglang.global_config
import
global_config
from
sglang.srt.constrained
import
RegexGuide
from
sglang.srt.constrained.jump_forward
import
JumpForwardMap
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
...
...
@@ -164,8 +163,12 @@ class Req:
def
finished
(
self
)
->
bool
:
return
self
.
finished_reason
is
not
None
def
init_next_round_input
(
self
):
def
init_next_round_input
(
self
,
tree_cache
:
Optional
[
BasePrefixCache
]
=
None
):
self
.
fill_ids
=
self
.
origin_input_ids
+
self
.
output_ids
if
tree_cache
is
not
None
:
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
)
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
def
adjust_max_prefix_ids
(
self
):
...
...
@@ -312,7 +315,7 @@ class ScheduleBatch:
reqs
:
List
[
Req
]
req_to_token_pool
:
ReqToTokenPool
token_to_kv_pool
:
BaseTokenToKVPool
tree_cache
:
Rad
ixCache
tree_cache
:
BasePref
ixCache
# Batched arguments to model runner
input_ids
:
torch
.
Tensor
=
None
...
...
@@ -534,7 +537,7 @@ class ScheduleBatch:
residual_size
=
max
(
0
,
residual_size
)
self
.
tree_cache
.
evict
(
residual_size
,
self
.
token_to_kv_pool
.
free
)
req
.
prefix_indices
=
None
req
.
prefix_indices
=
[]
req
.
last_node
=
None
req
.
extend_input_len
=
0
...
...
python/sglang/srt/managers/tp_worker.py
View file @
7de60345
...
...
@@ -369,7 +369,7 @@ class ModelTpServer:
return
None
# Get priority queue
self
.
scheduler
.
calc_priority
(
self
.
waiting_queue
)
prefix_computed
=
self
.
scheduler
.
calc_priority
(
self
.
waiting_queue
)
adder
=
PrefillAdder
(
self
.
tree_cache
,
...
...
@@ -383,13 +383,15 @@ class ModelTpServer:
has_inflight
=
self
.
current_inflight_req
is
not
None
if
self
.
current_inflight_req
is
not
None
:
self
.
current_inflight_req
.
init_next_round_input
()
self
.
current_inflight_req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
)
for
req
in
self
.
waiting_queue
:
req
.
init_next_round_input
()
req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
res
=
adder
.
add_one_req
(
req
)
if
(
not
res
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
7de60345
...
...
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
import
heapq
import
time
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
import
torch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment