"docs/source/vscode:/vscode.git/clone" did not exist on "7fc80724dace3e27fd9d540228fa4d95fbf94970"
Unverified commit 8f5d62fd authored by Ngo Quang Huy, committed by GitHub

Fix `bad_words_ids` not working with sentencepiece-based tokenizers (#15343)



* Fix `bad_words_ids` not working with sentencepiece-based tokenizers

* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 06107541
@@ -896,7 +896,8 @@ class GenerationMixin:
                 `decoder_input_ids`.
             bad_words_ids(`List[List[int]]`, *optional*):
                 List of token ids that are not allowed to be generated. In order to get the tokens of the words that
-                should not appear in the generated text, use `tokenizer(bad_word, add_prefix_space=True).input_ids`.
+                should not appear in the generated text, use `tokenizer(bad_word, add_prefix_space=True,
+                add_special_tokens=False).input_ids`.
             num_return_sequences(`int`, *optional*, defaults to 1):
                 The number of independently computed returned sequences for each element in the batch.
             max_time(`float`, *optional*, defaults to None):
@@ -1026,7 +1027,9 @@ class GenerationMixin:
         >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
         >>> input_context = "My cute dog"
         >>> # get tokens of words that should not be generated
-        >>> bad_words_ids = tokenizer(["idiot", "stupid", "shut up"], add_prefix_space=True).input_ids
+        >>> bad_words_ids = tokenizer(
+        ...     ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False
+        ... ).input_ids
         >>> # encode input context
         >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
         >>> # generate sequences without allowing bad_words to be generated
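Why the extra flag matters: sentencepiece-based tokenizers (T5, ALBERT, Pegasus, ...) append special tokens such as `</s>` to every encoded string, so without `add_special_tokens=False` each banned-word id list picks up a token that was never part of the word, and `bad_words_ids` fails to block the intended words. A minimal sketch of the difference, assuming a sentencepiece checkpoint such as "t5-small" (the ids shown in the comments are illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # sentencepiece-based

    # Default: add_special_tokens=True appends T5's </s> (id 1) to every entry,
    # polluting the banned-word id lists with a token the word never contained.
    polluted = tokenizer(["idiot", "stupid"]).input_ids

    # The fixed recipe keeps only the ids of the words themselves.
    clean = tokenizer(["idiot", "stupid"], add_special_tokens=False).input_ids

    print(polluted)  # e.g. [[..., 1], [..., 1]] -- trailing 1 is </s>
    print(clean)     # the same lists without the trailing special token

GPT-2's byte-level BPE tokenizer inserts no special tokens by default, which is why the doctest above happened to work before this fix; passing `add_special_tokens=False` explicitly makes the recipe safe for every tokenizer.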