chenpangpang / transformers

Commit fe472b1d (unverified), authored Nov 14, 2023 by Joao Gante, committed by GitHub on Nov 14, 2023
Parent: 73bc0c9e

Generate: fix `ExponentialDecayLengthPenalty` doctest (#27485)

fix exponential doctest
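The doctest was flaky because sampling draws from a global RNG whose state advances with every draw, so seeding once at the top of the example does not make later `generate` calls reproducible. The diff below therefore moves the seeding, calling `set_seed(1)` immediately before each `generate`. A minimal illustration of the underlying behavior (plain `torch.rand` stands in for the sampling done inside `generate`; this sketch is mine, not part of the commit):

```python
import torch

from transformers import set_seed

set_seed(1)
a = torch.rand(3)  # first draw after seeding
b = torch.rand(3)  # differs from `a`: the RNG state has advanced

set_seed(1)
c = torch.rand(3)  # re-seeding restores the state, so `c` equals `a`

assert torch.equal(a, c)
assert not torch.equal(a, b)
```

Re-seeding with the same value before each call is also why all three expected outputs in the new doctest begin with the same continuation.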
Showing 1 changed file with 19 additions and 12 deletions.

src/transformers/generation/logits_process.py (+19, -12)
@@ -1327,22 +1327,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
     ```python
     >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
-    >>> set_seed(1)
 
     >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
     >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
     >>> text = "Just wanted to let you know, I"
     >>> inputs = tokenizer(text, return_tensors="pt")
 
-    >>> # Generate sequences without exponential penalty. We want short sentences, so we limit max_length=30
-    >>> # see that the answer tends to end abruptly
+    >>> # Let's consider that we want short sentences, so we limit `max_length=30`. However, we observe that the answer
+    >>> # tends to end abruptly.
+    >>> set_seed(1)
     >>> outputs = model.generate(**inputs, do_sample=True, temperature=0.9, max_length=30, pad_token_id=50256)
     >>> print(tokenizer.batch_decode(outputs)[0])
-    Just wanted to let you know, I'm not even a lawyer. I'm a man. I have no real knowledge of politics. I'm a
+    Just wanted to let you know, I received a link to an ebook, the book How To Start A Social Network which was
+    published in 2010. Although
 
-    >>> # Generate sequences with exponential penalty, we add the exponential_decay_length_penalty=(start_index, decay_factor)
-    >>> # We see that instead of cutting at max_tokens, the output comes to an end before (at 25 tokens) and with more meaning
-    >>> # What happens is that starting from `start_index` the EOS token score will be increased by decay_factor exponentially
+    >>> # To promote the appearance of the EOS token at the right time, we add the `exponential_decay_length_penalty =
+    >>> # (start_index, decay_factor)`. Instead of cutting at max_tokens, the output comes to an end before and usually
+    >>> # with more meaning. What happens is that starting from `start_index` the EOS token score will be increased
+    >>> # by `decay_factor` exponentially. However, if you set a high decay factor, you may also end up with abruptly
+    >>> # ending sequences.
+    >>> set_seed(1)
     >>> outputs = model.generate(
     ...     **inputs,
     ...     do_sample=True,
@@ -1352,19 +1356,22 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
     ...     exponential_decay_length_penalty=(15, 1.6),
     ... )
     >>> print(tokenizer.batch_decode(outputs)[0])
-    Just wanted to let you know, I've got a very cool t-shirt educating people on how to use the Internet<|endoftext|>
+    Just wanted to let you know, I received a link to an ebook, the book How To Start A Social Network
+    which<|endoftext|>
 
-    >>> # Generate sequences with smaller decay_factor, still improving the hard cutoff mid-sentence
+    >>> # With a small decay factor, you will have a higher chance of getting a meaningful sequence.
+    >>> set_seed(1)
     >>> outputs = model.generate(
     ...     **inputs,
     ...     do_sample=True,
     ...     temperature=0.9,
     ...     max_length=30,
     ...     pad_token_id=50256,
-    ...     exponential_decay_length_penalty=(15, 1.05),
+    ...     exponential_decay_length_penalty=(15, 1.01),
     ... )
     >>> print(tokenizer.batch_decode(outputs)[0])
-    Just wanted to let you know, I've been working on it for about 6 months and now it's in Alpha.<|endoftext|>
+    Just wanted to let you know, I received a link to an ebook, the book How To Start A Social Network which was
+    published in 2010.<|endoftext|>
     ```
     """
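For context on the behavior the rewritten comments describe: once generation passes `start_index`, the processor raises the EOS token's score by an amount that grows exponentially with each additional token, so the sequence becomes increasingly likely to end. The following is a minimal standalone sketch of that adjustment, not the library's exact implementation; the function name `boost_eos_score` and its signature are illustrative only.

```python
import torch


def boost_eos_score(
    scores: torch.FloatTensor, cur_len: int, eos_token_id: int, start_index: int, decay_factor: float
) -> torch.FloatTensor:
    """Exponentially increase the EOS score once generation passes `start_index`."""
    if cur_len > start_index:
        steps_past_start = cur_len - start_index
        # The boost scales as decay_factor ** steps_past_start; taking abs() of the
        # raw score keeps the adjustment an increase even when the logit is negative.
        eos_scores = scores[:, eos_token_id]
        scores[:, eos_token_id] = eos_scores + torch.abs(eos_scores) * (decay_factor**steps_past_start - 1)
    return scores
```

Under this sketch, the trade-off the diff mentions becomes concrete: five tokens past the start index, a decay factor of 1.6 multiplies |score| by about 9.5 (1.6^5 - 1), forcing an early and possibly abrupt EOS, while 1.01 adds only about 0.05x (1.01^5 - 1), gently nudging the model toward ending.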