Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
47f20da2
"tutorials/vscode:/vscode.git/clone" did not exist on "53b9a4bdbc36e5253adfbb780dacccffa66c4fb7"
Unverified
Commit
47f20da2
authored
Sep 01, 2024
by
Liangsheng Yin
Committed by
GitHub
Sep 01, 2024
Browse files
Fix regex mask (#1296)
parent
4a9f8ea4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 6 additions and 6 deletions
+6
-6
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+1
-1
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+5
-5
No files found.
python/sglang/srt/layers/sampler.py
View file @
47f20da2
...
@@ -63,7 +63,7 @@ class Sampler(CustomOp):
         logits.add_(sampling_info.logit_bias)
     if sampling_info.vocab_mask is not None:
-        logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
+        logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
     logits = self._apply_penalties(logits, sampling_info)
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
47f20da2
...
@@ -154,15 +154,15 @@ class SamplingBatchInfo:
         self.vocab_mask = None
         if has_regex:
+            self.vocab_mask = torch.zeros(
+                bs, self.vocab_size, dtype=torch.bool, device=device
+            )
             for i, req in enumerate(reqs):
                 if req.regex_fsm is not None:
-                    if self.vocab_mask is None:
-                        self.vocab_mask = torch.zeros(
-                            bs, self.vocab_size, dtype=torch.bool, device=device
-                        )
+                    self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
                         req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
-                    ] = 1
+                    ] = 0

     def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment