Unverified Commit df55c2b9 authored by Will Frey's avatar Will Frey Committed by GitHub
Browse files

Update typing in generation_logits_process.py (#12901)

While `Iterable[Iterable[int]]` is a nicer annotation (it's covariant!), the defensive statements parsing out `bad_words_ids` in `__init__(...)` force the caller to pass in `List[List[int]]`. I've changed the annotation to make that clear.
parent c164064e
......@@ -354,7 +354,7 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
The id of the `end-of-sequence` token.
"""
def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int):
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-emtpy list, but is {bad_words_ids}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment