Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b5648353
Unverified
Commit
b5648353
authored
Oct 02, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 02, 2024
Browse files
[Fix] do not maintain regex_fsm in SamplingBatchInfo (#1555)
parent
2c7d0a5b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
10 deletions
+11
-10
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+11
-10
No files found.
python/sglang/srt/sampling/sampling_batch_info.py
View file @
b5648353
...
...
@@ -59,7 +59,6 @@ class SamplingBatchInfo:
[
r
.
sampling_params
.
min_p
for
r
in
reqs
],
dtype
=
torch
.
float
)
ret
.
regex_fsms
=
[
r
.
regex_fsm
for
r
in
reqs
]
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
ret
.
need_min_p_sampling
=
any
(
r
.
sampling_params
.
min_p
>
0
for
r
in
reqs
)
...
...
@@ -85,6 +84,10 @@ class SamplingBatchInfo:
# Handle logit bias but only allocate when needed
ret
.
logit_bias
=
None
# This is only for regex_fsm. We noticed a regression if we maintain the list of regex_fsm
# in SamplingBatchInfo, so we keep it here.
ret
.
schedule_batch
=
batch
return
ret
def
__len__
(
self
):
...
...
@@ -110,18 +113,20 @@ class SamplingBatchInfo:
self
.
linear_penalties
=
penalizer
.
apply
(
self
.
linear_penalties
)
def
update_regex_vocab_mask
(
self
):
has_regex
=
any
(
req
.
regex_fsm
is
not
None
for
req
in
self
.
schedule_batch
.
reqs
)
# Reset the vocab mask
self
.
vocab_mask
=
None
if
any
(
regex_fsm
is
not
None
for
regex_fsm
in
self
.
regex_fsms
)
:
if
has_regex
:
self
.
vocab_mask
=
torch
.
zeros
(
len
(
self
.
regex_fsm
s
),
self
.
vocab_size
,
dtype
=
torch
.
bool
,
device
=
"cuda"
len
(
self
.
temperature
s
),
self
.
vocab_size
,
dtype
=
torch
.
bool
,
device
=
"cuda"
)
for
i
,
re
gex_fsm
in
enumerate
(
self
.
regex_fsm
s
):
if
regex_fsm
is
not
None
:
for
i
,
re
q
in
enumerate
(
self
.
schedule_batch
.
req
s
):
if
req
.
regex_fsm
is
not
None
:
self
.
vocab_mask
[
i
].
fill_
(
1
)
self
.
vocab_mask
[
i
][
regex_fsm
.
get_next_instruction
(
self
.
regex_fsm_state
s
[
i
]
).
tokens
req
.
regex_fsm
.
get_next_instruction
(
req
.
regex_fsm_state
).
tokens
]
=
0
def
filter_batch
(
self
,
unfinished_indices
:
List
[
int
],
new_indices
:
torch
.
Tensor
):
...
...
@@ -138,8 +143,6 @@ class SamplingBatchInfo:
if
value
is
not
None
:
# logit_bias can be None
setattr
(
self
,
item
,
value
[
new_indices
])
self
.
regex_fsms
=
[
self
.
regex_fsms
[
i
]
for
i
in
new_indices
]
@
staticmethod
def
merge_bias_tensor
(
lhs
:
torch
.
Tensor
,
rhs
:
torch
.
Tensor
,
bs1
:
int
,
bs2
:
int
,
default
:
int
=
0
...
...
@@ -176,5 +179,3 @@ class SamplingBatchInfo:
self
.
logit_bias
=
SamplingBatchInfo
.
merge_bias_tensor
(
self
.
logit_bias
,
other
.
logit_bias
,
len
(
self
),
len
(
other
)
)
self
.
regex_fsms
.
extend
(
other
.
regex_fsms
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment