Unverified Commit cde0c750 authored by David del Río Medina's avatar David del Río Medina Committed by GitHub
Browse files

Replace assertions with ValueError exceptions (#14018)

* Replace assertions with ValueError exceptions

* Change length check for a more explicit one
parent 968ae57c
......@@ -85,9 +85,11 @@ class LogitsProcessorList(list):
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
assert all(
arg in kwargs for arg in list(function_args.keys())[2:]
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
raise ValueError(
f"Make sure that all the required parameters: {list(function_args.keys())} for "
f"{processor.__class__} are passed to the logits processor."
)
scores = processor(input_ids, scores, **kwargs)
else:
scores = processor(input_ids, scores)
......@@ -381,7 +383,8 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
self.static_bad_words_mask: Optional[torch.LongTensor] = None
for banned_token_seq in self.bad_words_id_length_greater_than_1:
assert len(banned_token_seq) > 0, f"Banned words token sequences {bad_words_ids} cannot have an empty list"
if len(banned_token_seq) == 0:
raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment