Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
317631ca
Unverified
Commit
317631ca
authored
Oct 02, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 02, 2024
Browse files
[Fix] Move ScheduleBatch out of SamplingInfo (#1556)
parent
b5648353
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
10 deletions
+19
-10
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+15
-2
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+4
-8
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
317631ca
...
...
@@ -423,10 +423,14 @@ class ScheduleBatch:
# Stream
has_stream
:
bool
=
False
# Has regex
has_regex
:
bool
=
False
@
classmethod
def
init_new
(
cls
,
reqs
,
req_to_token_pool
,
token_to_kv_pool
,
tree_cache
):
return_logprob
=
any
(
req
.
return_logprob
for
req
in
reqs
)
has_stream
=
any
(
req
.
stream
for
req
in
reqs
)
has_regex
=
any
(
req
.
regex_fsm
for
req
in
reqs
)
return
cls
(
reqs
=
reqs
,
...
...
@@ -435,6 +439,7 @@ class ScheduleBatch:
tree_cache
=
tree_cache
,
return_logprob
=
return_logprob
,
has_stream
=
has_stream
,
has_regex
=
has_regex
,
)
def
batch_size
(
self
):
...
...
@@ -750,7 +755,9 @@ class ScheduleBatch:
]
else
:
self
.
top_logprobs_nums
=
None
self
.
has_stream
=
any
(
req
.
stream
for
req
in
self
.
reqs
)
self
.
has_regex
=
any
(
req
.
regex_fsm
for
req
in
self
.
reqs
)
self
.
sampling_info
.
filter_batch
(
unfinished_indices
,
new_indices
)
...
...
@@ -771,9 +778,11 @@ class ScheduleBatch:
self
.
top_logprobs_nums
.
extend
([
0
]
*
len
(
other
.
reqs
))
elif
other
.
return_logprob
:
self
.
top_logprobs_nums
=
[
0
]
*
len
(
self
.
reqs
)
+
other
.
top_logprobs_nums
self
.
has_stream
=
any
(
req
.
stream
for
req
in
self
.
reqs
)
self
.
reqs
.
extend
(
other
.
reqs
)
self
.
return_logprob
=
self
.
return_logprob
or
other
.
return_logprob
self
.
has_stream
=
self
.
has_stream
or
other
.
has_stream
self
.
has_regex
=
self
.
has_regex
or
other
.
has_regex
def
get_model_worker_batch
(
self
):
if
self
.
forward_mode
.
is_decode
():
...
...
@@ -787,7 +796,11 @@ class ScheduleBatch:
image_inputs
=
[
r
.
image_inputs
for
r
in
self
.
reqs
]
lora_paths
=
[
req
.
lora_path
for
req
in
self
.
reqs
]
self
.
sampling_info
.
regex_fsm_states
=
[
req
.
regex_fsm_state
for
req
in
self
.
reqs
]
if
self
.
has_regex
:
self
.
sampling_info
.
regex_fsms
=
[
req
.
regex_fsm
for
req
in
self
.
reqs
]
self
.
sampling_info
.
regex_fsm_states
=
[
req
.
regex_fsm_state
for
req
in
self
.
reqs
]
return
ModelWorkerBatch
(
forward_mode
=
self
.
forward_mode
,
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
317631ca
...
...
@@ -84,10 +84,6 @@ class SamplingBatchInfo:
# Handle logit bias but only allocate when needed
ret
.
logit_bias
=
None
# This is only for regex_fsm. We notice a regression if we maintain the list of regex_fsm
# in SamplingBatchInfo, so we keep it here.
ret
.
schedule_batch
=
batch
return
ret
def
__len__
(
self
):
...
...
@@ -113,7 +109,7 @@ class SamplingBatchInfo:
self
.
linear_penalties
=
penalizer
.
apply
(
self
.
linear_penalties
)
def
update_regex_vocab_mask
(
self
):
has_regex
=
any
(
req
.
regex_fsm
is
not
None
for
req
in
self
.
schedule_batch
.
req
s
)
has_regex
=
self
.
regex_fsm
s
and
any
(
regex_fsm
for
regex_fsm
in
self
.
regex_fsm
s
)
# Reset the vocab mask
self
.
vocab_mask
=
None
...
...
@@ -122,11 +118,11 @@ class SamplingBatchInfo:
self
.
vocab_mask
=
torch
.
zeros
(
len
(
self
.
temperatures
),
self
.
vocab_size
,
dtype
=
torch
.
bool
,
device
=
"cuda"
)
for
i
,
re
q
in
enumerate
(
self
.
schedule_batch
.
req
s
):
if
req
.
regex_fsm
is
not
None
:
for
i
,
re
gex_fsm
in
enumerate
(
self
.
regex_fsm
s
):
if
regex_fsm
is
not
None
:
self
.
vocab_mask
[
i
].
fill_
(
1
)
self
.
vocab_mask
[
i
][
req
.
regex_fsm
.
get_next_instruction
(
req
.
regex_fsm_state
).
tokens
regex_fsm
.
get_next_instruction
(
self
.
regex_fsm_state
s
[
i
]
).
tokens
]
=
0
def
filter_batch
(
self
,
unfinished_indices
:
List
[
int
],
new_indices
:
torch
.
Tensor
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment