Unverified Commit c2ed069b authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix mixed penalties batch with async scheduling (#27910)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent af6e19f5
......@@ -21,6 +21,14 @@ def apply_all_penalties(
"""
_, vocab_size = logits.shape
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device)
# In the async scheduling case, rows that won't have penalties applied may contain
# -1 placeholder token ids. We must replace these with valid token ids so that the
# scatter done in apply_penalties is valid.
# NOTE(nick): The penalties implementation is currently quite inefficient and
# will be reworked anyhow.
output_tokens_t.masked_fill_(output_tokens_t == -1, vocab_size)
return apply_penalties(
logits,
prompt_token_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment