chenpangpang / transformers · Commits

Commit a767276f (unverified)
Authored Nov 02, 2021 by Li-Huai (Allan) Lin; committed by GitHub on Nov 02, 2021
Parent: e20faa6f

Fix generation docstring (#14216)

* Fix generation docstring
* Style
Showing 2 changed files, with 4 additions and 4 deletions:

  src/transformers/generation_utils.py                      +2 −2
  src/transformers/models/gpt2/tokenization_gpt2_fast.py    +2 −2
src/transformers/generation_utils.py

@@ -849,11 +849,11 @@ class GenerationMixin:
 ...
             >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
             >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

-            >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+            >>> tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
             >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
             >>> input_context = "My cute dog"
             >>> # get tokens of words that should not be generated
-            >>> bad_words_ids = [tokenizer(bad_word, add_prefix_space=True).input_ids for bad_word in ["idiot", "stupid", "shut up"]]
+            >>> bad_words_ids = tokenizer(["idiot", "stupid", "shut up"], add_prefix_space=True).input_ids
             >>> # encode input context
             >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
             >>> # generate sequences without allowing bad_words to be generated
 ...
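For reference, here is the corrected docstring example assembled into a runnable script. This is a minimal sketch, assuming transformers and torch are installed and the public "gpt2" checkpoint is used; use_fast=False selects the slow GPT2Tokenizer, which is the one that accepts add_prefix_space as a per-call argument.

# Minimal sketch of the corrected example (assumes transformers + torch).
# use_fast=False selects the slow GPT2Tokenizer, which accepts
# add_prefix_space per call; the fast tokenizer only accepts it at init.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_context = "My cute dog"
# Batch-encode the banned phrases; add_prefix_space=True produces the
# ids these words have when they occur mid-sentence, after a space.
bad_words_ids = tokenizer(["idiot", "stupid", "shut up"], add_prefix_space=True).input_ids

input_ids = tokenizer(input_context, return_tensors="pt").input_ids
# generate() will not emit any of the banned token sequences.
outputs = model.generate(input_ids=input_ids, max_length=20, bad_words_ids=bad_words_ids)
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

Each entry of bad_words_ids is itself a list of token ids, so a multi-token phrase like "shut up" is banned as a whole sequence rather than as individual tokens.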
src/transformers/models/gpt2/tokenization_gpt2_fast.py

@@ -84,8 +84,8 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
 ...
         >>> tokenizer(" Hello world")['input_ids']
         [18435, 995]

-    You can get around that behavior by passing ``add_prefix_space=True`` when instantiating this tokenizer or when you
-    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+    You can get around that behavior by passing ``add_prefix_space=True`` when instantiating this tokenizer, but since
+    the model was not pretrained this way, it might yield a decrease in performance.

     .. note::
 ...
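The behavior this docstring describes is easy to verify directly. A minimal sketch, assuming transformers is installed; the expected outputs in the comments follow the surrounding docstring examples.

# GPT-2's byte-level BPE treats a leading space as part of the word,
# so "Hello" and " Hello" encode to different token ids.
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
print(tokenizer("Hello world")["input_ids"])   # [15496, 995]
print(tokenizer(" Hello world")["input_ids"])  # [18435, 995]

# With the fast tokenizer, add_prefix_space can only be set when the
# tokenizer is instantiated, not per call; that is the distinction
# this commit's docstring fix makes explicit.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", add_prefix_space=True)
print(tokenizer("Hello world")["input_ids"])   # now [18435, 995]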