chenpangpang / transformers · Commits

Unverified commit b8ac4d03, authored May 01, 2024 by Raushan Turganbay, committed by GitHub Apr 30, 2024

Fix generation doctests (#30263)

* fix doctest
* fix torch doctest
* make CI happy
* raise error
* make fixup

Parent: 2ecefc39
Showing 3 changed files with 16 additions and 15 deletions (+16 -15)
src/transformers/generation/candidate_generator.py (+10 -3)
src/transformers/generation/tf_utils.py (+3 -3)
src/transformers/generation/utils.py (+3 -9)
src/transformers/generation/candidate_generator.py

@@ -19,12 +19,12 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 import torch
 
 from ..cache_utils import DynamicCache
+from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
 
 
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
     from .configuration_utils import GenerationConfig
-    from .logits_process import LogitsProcessorList
 
 
 class CandidateGenerator:
@@ -94,9 +94,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
         input_ids: torch.LongTensor,
         assistant_model: "PreTrainedModel",
         generation_config: "GenerationConfig",
+        logits_processor: "LogitsProcessorList",
         model_kwargs: Dict,
         inputs_tensor: Optional[torch.Tensor] = None,
-        logits_processor: "LogitsProcessorList" = None,
     ):
         # Make sure all data at the same device as assistant model
         device = assistant_model.device
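With this hunk, `logits_processor` moves from a trailing keyword argument with a `None` default to a required parameter placed before `model_kwargs`, so call sites must pass it explicitly (the updated call site in `utils.py` appears below). A minimal construction sketch under the new signature, mirroring the doctest further down; the `gpt2`/`distilgpt2` checkpoints are illustrative, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from transformers.generation.candidate_generator import AssistedCandidateGenerator

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
input_ids = tokenizer("It might be possible to", return_tensors="pt").input_ids

# logits_processor must now be supplied explicitly; an empty list is fine,
# since __init__ (next hunk) falls back to LogitsProcessorList() on None.
candidate_generator = AssistedCandidateGenerator(
    input_ids=input_ids,
    assistant_model=assistant_model,
    generation_config=model.generation_config,
    logits_processor=LogitsProcessorList(),
    model_kwargs={},
)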
@@ -145,15 +145,22 @@ class AssistedCandidateGenerator(CandidateGenerator):
             self.input_ids_key = "input_ids"
 
         # Prepare generation-related options.
-        self.logits_processor = logits_processor
+        self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         self.generation_config = copy.deepcopy(generation_config)
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True
 
         # avoid unnecessary warnings that min_length is larger than max_new_tokens
+        # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
         self.main_model_min_length = self.generation_config.min_length
         self.generation_config.min_length = 0
         self.generation_config.min_new_tokens = None
+        for processor in self.logits_processor:
+            if type(processor) == MinLengthLogitsProcessor:
+                raise ValueError(
+                    "Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
+                    "Please pass in `min_length` into `.generate()` instead"
+                )
 
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
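The new loop fails fast: a `MinLengthLogitsProcessor` passed into assisted generation now raises immediately instead of being silently neutralized by the `min_length = 0` override just above it. A hedged sketch of the behavior this adds (checkpoints again illustrative):

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)
from transformers.generation.candidate_generator import AssistedCandidateGenerator

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
input_ids = tokenizer("It might be possible to", return_tensors="pt").input_ids

bad = LogitsProcessorList(
    [MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id)]
)

try:
    AssistedCandidateGenerator(
        input_ids=input_ids,
        assistant_model=assistant_model,
        generation_config=model.generation_config,
        logits_processor=bad,  # triggers the new ValueError in __init__
        model_kwargs={},
    )
except ValueError as err:
    print(err)  # directs the caller to pass `min_length` to .generate() instead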
src/transformers/generation/tf_utils.py

@@ -528,9 +528,9 @@ class TFGenerationMixin:
         >>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
         ...     # | token | token string | logits | probability
         ...     print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
-        |   262 | the      | -1.413 | 24.33%
+        |   262 | the      | -1.414 | 24.33%
         |  1110 | day      | -2.609 | 7.36%
-        |   618 | when     | -2.009 | 13.41%
+        |   618 | when     | -2.010 | 13.40%
         |   356 | we       | -1.859 | 15.58%
         |   460 | can      | -2.508 | 8.14%
@@ -549,7 +549,7 @@ class TFGenerationMixin:
         >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
         >>> # Tip: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the
         >>> # use case, you might want to recompute it with `normalize_logits=True`.
-        >>> output_length = input_length + np.sum(transition_scores.numpy() < 0, axis=1)
+        >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1)
         >>> length_penalty = model.generation_config.length_penalty
         >>> reconstructed_scores = np.sum(transition_scores, axis=1) / (output_length**length_penalty)
         >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
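The fix drops `input_length` from `output_length`: `transition_scores` covers only generated tokens here, so counting the prompt as well skewed the per-token normalization and broke the doctest's `np.allclose` check. A self-contained numpy sketch of the recomputation, with made-up scores rather than the doctest's values:

import numpy as np

# Toy transition scores for 2 outputs, up to 4 generated tokens each.
# Padding positions carry score 0; real tokens carry log-probs < 0.
transition_scores = np.array(
    [[-1.4, -2.6, -2.0, -1.9],
     [-0.7, -1.2,  0.0,  0.0]]  # second sequence finished after 2 tokens
)

# Count only generated (non-padding) tokens: exactly what the fixed line does.
output_length = np.sum(transition_scores < 0, axis=1)  # -> [4, 2]

length_penalty = 1.0  # illustrative; read model.generation_config.length_penalty in practice
reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
print(reconstructed_scores)  # mean log-prob per generated token: [-1.975 -0.95]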
src/transformers/generation/utils.py

@@ -705,9 +705,9 @@ class GenerationMixin:
             input_ids=input_ids,
             assistant_model=assistant_model,
             generation_config=generation_config,
+            logits_processor=logits_processor,
             model_kwargs=model_kwargs,
             inputs_tensor=inputs_tensor,
-            logits_processor=logits_processor,
         )
         return candidate_generator
@@ -4601,24 +4601,18 @@ class GenerationMixin:
         >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
         >>> input_prompt = "It might be possible to"
         >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
-        >>> # instantiate logits processors
-        >>> logits_processor = LogitsProcessorList(
-        ...     [
-        ...         MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
-        ...     ]
-        ... )
+        >>> model.generation_config.min_length = 10
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
         >>> candidate_generator = AssistedCandidateGenerator(
         ...     input_ids=input_ids,
         ...     assistant_model=assistant_model,
         ...     generation_config=model.generation_config,
+        ...     logits_processor=logits_processor,
         ...     model_kwargs={},
         ... )
         >>> outputs = model._assisted_decoding(
         ...     input_ids,
         ...     candidate_generator=candidate_generator,
-        ...     logits_processor=logits_processor,
         ...     stopping_criteria=stopping_criteria,
         ... )
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
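The doctest now routes the length constraint through `generation_config.min_length` instead of a hand-built `MinLengthLogitsProcessor`, matching the new guard in `candidate_generator.py`. The same pattern through the public API, as a hedged sketch (checkpoints illustrative, mirroring the doctest setup rather than taken verbatim from the diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
model.generation_config.pad_token_id = model.generation_config.eos_token_id

inputs = tokenizer("It might be possible to", return_tensors="pt")

# After this commit, pass min_length to .generate() (or set it on the
# generation config) instead of building a MinLengthLogitsProcessor,
# which assisted generation now rejects.
outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,
    min_length=10,
    max_length=20,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))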