Unverified Commit 10ef65ed authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix bad words with speculative decoding (#31908)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent 6170d47d
...@@ -691,9 +691,13 @@ def test_frequency_penalties(rejection_sampler): ...@@ -691,9 +691,13 @@ def test_frequency_penalties(rejection_sampler):
def test_bad_words(rejection_sampler): def test_bad_words(rejection_sampler):
"""Test rejection sampling with bad words constraints""" """Test rejection sampling with bad words constraints.
This test applies bad words to non-consecutive requests (0 and 2, but not 1)
to verify correct logit indexing when iterating over requests with bad words.
"""
spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]] spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] output_tokens = [[1, 2, 3, 4], [1, 15, 3, 4], [1, 2, 3, 4]]
logits = create_logits_tensor(output_tokens, token_idx_to_override=15) logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
metadata = create_sampling_metadata( metadata = create_sampling_metadata(
...@@ -701,17 +705,9 @@ def test_bad_words(rejection_sampler): ...@@ -701,17 +705,9 @@ def test_bad_words(rejection_sampler):
output_token_ids=[[2], [3], [4]], output_token_ids=[[2], [3], [4]],
spec_token_ids=spec_tokens, spec_token_ids=spec_tokens,
bad_words_token_ids={ bad_words_token_ids={
0: [ 0: [[2]],
[ # Request 1 has no bad words (to test non-consecutive request handling)
2, 2: [[2]],
]
],
1: [
[
2,
]
],
# Do not apply bad words to the last request
}, },
) )
bonus_token_tensor = torch.tensor( bonus_token_tensor = torch.tensor(
...@@ -726,8 +722,11 @@ def test_bad_words(rejection_sampler): ...@@ -726,8 +722,11 @@ def test_bad_words(rejection_sampler):
sampling_metadata=metadata, sampling_metadata=metadata,
) )
# Request 0: bad word [2] matches prefix, so token 2 is rejected -> 15
# Request 1: no bad words, all tokens match -> [1, 15, 3, 4]
# Request 2: bad word [2] matches prefix, so token 2 is rejected -> 15
expected = torch.tensor( expected = torch.tensor(
[[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]], [[1, 15, -1, -1], [1, 15, 3, 4], [1, 15, -1, -1]],
dtype=torch.int, dtype=torch.int,
device=logits.device, device=logits.device,
) )
......
...@@ -42,11 +42,16 @@ def apply_bad_words_with_drafts( ...@@ -42,11 +42,16 @@ def apply_bad_words_with_drafts(
num_draft_tokens: list[int], num_draft_tokens: list[int],
) -> None: ) -> None:
start_idx = 0 start_idx = 0
for i, bad_words_ids in bad_words_token_ids.items(): remaining = len(bad_words_token_ids)
for draft_idx in range(num_draft_tokens[i]): for i, n in enumerate(num_draft_tokens):
_apply_bad_words_single_batch( if (bad_words_ids := bad_words_token_ids.get(i)) is not None:
logits[start_idx + draft_idx], for draft_idx in range(start_idx, start_idx + n):
bad_words_ids, _apply_bad_words_single_batch(
past_tokens_ids[start_idx + draft_idx], logits[draft_idx],
) bad_words_ids,
start_idx += num_draft_tokens[i] past_tokens_ids[draft_idx],
)
remaining -= 1
if not remaining:
break
start_idx += n
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment