Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
10ef65ed
Unverified
Commit
10ef65ed
Authored Jan 07, 2026 by Nick Hill
Committed by GitHub on Jan 07, 2026
Browse files
[BugFix] Fix bad words with speculative decoding (#31908)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent
6170d47d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 26 additions and 22 deletions
+26
-22
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+13
-14
vllm/v1/sample/ops/bad_words.py
vllm/v1/sample/ops/bad_words.py
+13
-8
No files found.
tests/v1/sample/test_rejection_sampler.py
View file @
10ef65ed
...
...
@@ -691,9 +691,13 @@ def test_frequency_penalties(rejection_sampler):
def
test_bad_words
(
rejection_sampler
):
"""Test rejection sampling with bad words constraints"""
"""Test rejection sampling with bad words constraints.
This test applies bad words to non-consecutive requests (0 and 2, but not 1)
to verify correct logit indexing when iterating over requests with bad words.
"""
spec_tokens
=
[[
1
,
2
,
3
],
[
1
,
15
,
3
],
[
1
,
2
,
3
]]
output_tokens
=
[[
1
,
2
,
3
,
4
],
[
1
,
2
,
3
,
4
],
[
1
,
2
,
3
,
4
]]
output_tokens
=
[[
1
,
2
,
3
,
4
],
[
1
,
15
,
3
,
4
],
[
1
,
2
,
3
,
4
]]
logits
=
create_logits_tensor
(
output_tokens
,
token_idx_to_override
=
15
)
metadata
=
create_sampling_metadata
(
...
...
@@ -701,17 +705,9 @@ def test_bad_words(rejection_sampler):
output_token_ids
=
[[
2
],
[
3
],
[
4
]],
spec_token_ids
=
spec_tokens
,
bad_words_token_ids
=
{
0
:
[
[
2
,
]
],
1
:
[
[
2
,
]
],
# Do not apply bad words to the last request
0
:
[[
2
]],
# Request 1 has no bad words (to test non-consecutive request handling)
2
:
[[
2
]],
},
)
bonus_token_tensor
=
torch
.
tensor
(
...
...
@@ -726,8 +722,11 @@ def test_bad_words(rejection_sampler):
sampling_metadata
=
metadata
,
)
# Request 0: bad word [2] matches prefix, so token 2 is rejected -> 15
# Request 1: no bad words, all tokens match -> [1, 15, 3, 4]
# Request 2: bad word [2] matches prefix, so token 2 is rejected -> 15
expected
=
torch
.
tensor
(
[[
1
,
15
,
-
1
,
-
1
],
[
1
,
15
,
3
,
4
],
[
1
,
2
,
3
,
4
]],
[[
1
,
15
,
-
1
,
-
1
],
[
1
,
15
,
3
,
4
],
[
1
,
15
,
-
1
,
-
1
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
,
)
...
...
vllm/v1/sample/ops/bad_words.py
View file @
10ef65ed
...
...
@@ -42,11 +42,16 @@ def apply_bad_words_with_drafts(
num_draft_tokens
:
list
[
int
],
)
->
None
:
start_idx
=
0
for
i
,
bad_words_ids
in
bad_words_token_ids
.
items
():
for
draft_idx
in
range
(
num_draft_tokens
[
i
]):
_apply_bad_words_single_batch
(
logits
[
start_idx
+
draft_idx
],
bad_words_ids
,
past_tokens_ids
[
start_idx
+
draft_idx
],
)
start_idx
+=
num_draft_tokens
[
i
]
remaining
=
len
(
bad_words_token_ids
)
for
i
,
n
in
enumerate
(
num_draft_tokens
):
if
(
bad_words_ids
:
=
bad_words_token_ids
.
get
(
i
))
is
not
None
:
for
draft_idx
in
range
(
start_idx
,
start_idx
+
n
):
_apply_bad_words_single_batch
(
logits
[
draft_idx
],
bad_words_ids
,
past_tokens_ids
[
draft_idx
],
)
remaining
-=
1
if
not
remaining
:
break
start_idx
+=
n
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment