Unverified Commit cde0c750 authored by David del Río Medina's avatar David del Río Medina Committed by GitHub
Browse files

Replace assertions with ValueError exceptions (#14018)

* Replace assertions with ValueError exceptions

* Change length check for a more explicit one
parent 968ae57c
...@@ -85,9 +85,11 @@ class LogitsProcessorList(list): ...@@ -85,9 +85,11 @@ class LogitsProcessorList(list):
for processor in self: for processor in self:
function_args = inspect.signature(processor.__call__).parameters function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2: if len(function_args) > 2:
assert all( if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
arg in kwargs for arg in list(function_args.keys())[2:] raise ValueError(
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor." f"Make sure that all the required parameters: {list(function_args.keys())} for "
f"{processor.__class__} are passed to the logits processor."
)
scores = processor(input_ids, scores, **kwargs) scores = processor(input_ids, scores, **kwargs)
else: else:
scores = processor(input_ids, scores) scores = processor(input_ids, scores)
...@@ -381,7 +383,8 @@ class NoBadWordsLogitsProcessor(LogitsProcessor): ...@@ -381,7 +383,8 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
self.static_bad_words_mask: Optional[torch.LongTensor] = None self.static_bad_words_mask: Optional[torch.LongTensor] = None
for banned_token_seq in self.bad_words_id_length_greater_than_1: for banned_token_seq in self.bad_words_id_length_greater_than_1:
assert len(banned_token_seq) > 0, f"Banned words token sequences {bad_words_ids} cannot have an empty list" if len(banned_token_seq) == 0:
raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0: if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment