"tests/models/vscode:/vscode.git/clone" did not exist on "50378cbf6c1fd8717a74b36c352f57f9a73e7282"
Unverified commit 3c12e3c1, authored by Lysandre Debut and committed by GitHub
Browse files

Fix overflowing bad word ids (#10889)

* Remove bad word IDs that overflow the vocabulary size

* Log an error when an invalid bad word ID is ignored
parent 1f5ea9e0
......@@ -22,6 +22,10 @@ import numpy as np
import torch
from .file_utils import add_start_docstrings
from .utils.logging import get_logger
logger = get_logger(__name__)
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
......@@ -417,7 +421,14 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
banned_mask_list = []
for idx, batch_banned_tokens in enumerate(banned_tokens):
for token in batch_banned_tokens:
# Eliminates invalid bad word IDs that are over the vocabulary size.
if token <= scores.shape[1]:
banned_mask_list.append([idx, token])
else:
logger.error(
f"An invalid bad word ID is defined: {token}. This ID is not contained in the"
f"vocabulary, and is therefore ignored."
)
if not banned_mask_list:
return scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment