Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7de60345
"tests/vscode:/vscode.git/clone" did not exist on "af48bf200860d8b83fe3be92b2d7ae556a3b4111"
Unverified
Commit
7de60345
authored
Aug 11, 2024
by
Liangsheng Yin
Committed by
GitHub
Aug 11, 2024
Browse files
Fix the prefix indices (#1037)
parent
d84c5e70
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
15 deletions
+25
-15
python/sglang/srt/managers/policy_scheduler.py
python/sglang/srt/managers/policy_scheduler.py
+10
-5
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+9
-6
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+5
-3
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+1
-1
No files found.
python/sglang/srt/managers/policy_scheduler.py
View file @
7de60345
...
@@ -43,11 +43,14 @@ class PolicyScheduler:
...
@@ -43,11 +43,14 @@ class PolicyScheduler:
def
calc_priority
(
self
,
waiting_queue
:
List
[
Req
]):
def
calc_priority
(
self
,
waiting_queue
:
List
[
Req
]):
# Compute matched prefix length
# Compute matched prefix length
prefix_computed
=
False
if
self
.
policy
in
[
"lpm"
,
"dfs-weight"
]:
for
r
in
waiting_queue
:
for
r
in
waiting_queue
:
# NOTE: the prefix_indices must always be aligned with last_node
# NOTE: the prefix_indices must always be aligned with last_node
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
r
.
adjust_max_prefix_ids
()
rid
=
r
.
rid
,
key
=
r
.
adjust_max_prefix_ids
()
)
)
prefix_computed
=
True
if
self
.
policy
==
"lpm"
:
if
self
.
policy
==
"lpm"
:
# Longest Prefix Match
# Longest Prefix Match
...
@@ -80,6 +83,8 @@ class PolicyScheduler:
...
@@ -80,6 +83,8 @@ class PolicyScheduler:
else
:
else
:
raise
ValueError
(
f
"Unknown schedule_policy:
{
self
.
policy
}
"
)
raise
ValueError
(
f
"Unknown schedule_policy:
{
self
.
policy
}
"
)
return
prefix_computed
def
calc_weight
(
self
,
cur_node
:
TreeNode
,
node_to_weight
:
Dict
):
def
calc_weight
(
self
,
cur_node
:
TreeNode
,
node_to_weight
:
Dict
):
for
child
in
cur_node
.
children
.
values
():
for
child
in
cur_node
.
children
.
values
():
self
.
calc_weight
(
child
,
node_to_weight
)
self
.
calc_weight
(
child
,
node_to_weight
)
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
7de60345
...
@@ -18,9 +18,8 @@ limitations under the License.
...
@@ -18,9 +18,8 @@ limitations under the License.
import
logging
import
logging
import
warnings
import
warnings
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Union
from
typing
import
List
,
Optional
,
Union
import
numpy
as
np
import
torch
import
torch
from
flashinfer.sampling
import
top_k_top_p_sampling_from_probs
from
flashinfer.sampling
import
top_k_top_p_sampling_from_probs
...
@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib
...
@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
from
sglang.srt.constrained
import
RegexGuide
from
sglang.srt.constrained
import
RegexGuide
from
sglang.srt.constrained.jump_forward
import
JumpForwardMap
from
sglang.srt.constrained.jump_forward
import
JumpForwardMap
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
...
@@ -164,8 +163,12 @@ class Req:
...
@@ -164,8 +163,12 @@ class Req:
def
finished
(
self
)
->
bool
:
def
finished
(
self
)
->
bool
:
return
self
.
finished_reason
is
not
None
return
self
.
finished_reason
is
not
None
def
init_next_round_input
(
self
):
def
init_next_round_input
(
self
,
tree_cache
:
Optional
[
BasePrefixCache
]
=
None
):
self
.
fill_ids
=
self
.
origin_input_ids
+
self
.
output_ids
self
.
fill_ids
=
self
.
origin_input_ids
+
self
.
output_ids
if
tree_cache
is
not
None
:
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
)
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
def
adjust_max_prefix_ids
(
self
):
def
adjust_max_prefix_ids
(
self
):
...
@@ -312,7 +315,7 @@ class ScheduleBatch:
...
@@ -312,7 +315,7 @@ class ScheduleBatch:
reqs
:
List
[
Req
]
reqs
:
List
[
Req
]
req_to_token_pool
:
ReqToTokenPool
req_to_token_pool
:
ReqToTokenPool
token_to_kv_pool
:
BaseTokenToKVPool
token_to_kv_pool
:
BaseTokenToKVPool
tree_cache
:
Rad
ixCache
tree_cache
:
BasePref
ixCache
# Batched arguments to model runner
# Batched arguments to model runner
input_ids
:
torch
.
Tensor
=
None
input_ids
:
torch
.
Tensor
=
None
...
@@ -534,7 +537,7 @@ class ScheduleBatch:
...
@@ -534,7 +537,7 @@ class ScheduleBatch:
residual_size
=
max
(
0
,
residual_size
)
residual_size
=
max
(
0
,
residual_size
)
self
.
tree_cache
.
evict
(
residual_size
,
self
.
token_to_kv_pool
.
free
)
self
.
tree_cache
.
evict
(
residual_size
,
self
.
token_to_kv_pool
.
free
)
req
.
prefix_indices
=
None
req
.
prefix_indices
=
[]
req
.
last_node
=
None
req
.
last_node
=
None
req
.
extend_input_len
=
0
req
.
extend_input_len
=
0
...
...
python/sglang/srt/managers/tp_worker.py
View file @
7de60345
...
@@ -369,7 +369,7 @@ class ModelTpServer:
...
@@ -369,7 +369,7 @@ class ModelTpServer:
return
None
return
None
# Get priority queue
# Get priority queue
self
.
scheduler
.
calc_priority
(
self
.
waiting_queue
)
prefix_computed
=
self
.
scheduler
.
calc_priority
(
self
.
waiting_queue
)
adder
=
PrefillAdder
(
adder
=
PrefillAdder
(
self
.
tree_cache
,
self
.
tree_cache
,
...
@@ -383,13 +383,15 @@ class ModelTpServer:
...
@@ -383,13 +383,15 @@ class ModelTpServer:
has_inflight
=
self
.
current_inflight_req
is
not
None
has_inflight
=
self
.
current_inflight_req
is
not
None
if
self
.
current_inflight_req
is
not
None
:
if
self
.
current_inflight_req
is
not
None
:
self
.
current_inflight_req
.
init_next_round_input
()
self
.
current_inflight_req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
=
adder
.
add_inflight_req
(
self
.
current_inflight_req
self
.
current_inflight_req
)
)
for
req
in
self
.
waiting_queue
:
for
req
in
self
.
waiting_queue
:
req
.
init_next_round_input
()
req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
res
=
adder
.
add_one_req
(
req
)
res
=
adder
.
add_one_req
(
req
)
if
(
if
(
not
res
not
res
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
7de60345
...
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
...
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
import
heapq
import
heapq
import
time
import
time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
import
torch
import
torch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment