Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
82e360b7
Unverified
Commit
82e360b7
authored
Oct 18, 2022
by
Yuta Koreeda
Committed by
GitHub
Oct 17, 2022
Browse files
Fixed the docstring and type hint for forced_decoder_ids option in Ge… (#19640)
parent
f2ecb9ee
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
18 deletions
+24
-18
src/transformers/generation_logits_process.py
src/transformers/generation_logits_process.py
+4
-3
src/transformers/generation_tf_logits_process.py
src/transformers/generation_tf_logits_process.py
+4
-3
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+10
-7
src/transformers/generation_utils.py
src/transformers/generation_utils.py
+6
-5
No files found.
src/transformers/generation_logits_process.py
View file @
82e360b7
...
...
@@ -735,10 +735,11 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
 class ForceTokensLogitsProcessor(LogitsProcessor):
-    r"""This processor can be used to force a list of tokens. The processor will set their log probs to `inf` so that they
-    are sampled at their corresponding index."""
+    r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
+    indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are
+    sampled at their corresponding index."""

-    def __init__(self, force_token_map):
+    def __init__(self, force_token_map: List[List[int]]):
         self.force_token_map = dict(force_token_map)

     def __call__(self, input_ids, scores):
...
...
src/transformers/generation_tf_logits_process.py
View file @
82e360b7
...
...
@@ -547,10 +547,11 @@ class TFSuppressTokensLogitsProcessor(TFLogitsProcessor):
 class TFForceTokensLogitsProcessor(TFLogitsProcessor):
-    r"""This processor can be used to force a list of tokens. The processor will set their log probs to `0` and all
-    other tokens to `-inf` so that they are sampled at their corresponding index."""
+    r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
+    indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to
+    `-inf` so that they are sampled at their corresponding index."""

-    def __init__(self, force_token_map):
+    def __init__(self, force_token_map: List[List[int]]):
         force_token_map = dict(force_token_map)
         # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
         # index of the array corresponds to the index of the token to be forced, for XLA compatibility.
...
...
src/transformers/generation_tf_utils.py
View file @
82e360b7
...
...
@@ -406,7 +406,7 @@ class TFGenerationMixin:
         forced_eos_token_id=None,
         suppress_tokens: Optional[List[int]] = None,
         begin_suppress_tokens: Optional[List[int]] = None,
-        forced_decoder_ids: Optional[List[int]] = None,
+        forced_decoder_ids: Optional[List[List[int]]] = None,
         **model_kwargs,
     ) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
         r"""
...
...
@@ -506,8 +506,10 @@ class TFGenerationMixin:
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of tokens that will be forced as beginning tokens, before sampling.
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of pairs of integers which indicates a mapping from generation indices to token indices that
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
be a token of index 123.
model_specific_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model.
...
...
@@ -1493,9 +1495,10 @@ class TFGenerationMixin:
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of tokens that will be forced as beginning tokens.
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of pairs of integers which indicates a mapping from generation indices to token indices that
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
be a token of index 123.
model_kwargs:
Additional model specific kwargs will be forwarded to the `call` function of the model.
...
...
@@ -2147,7 +2150,7 @@ class TFGenerationMixin:
         forced_eos_token_id: int,
         suppress_tokens: Optional[List[int]] = None,
         begin_suppress_tokens: Optional[List[int]] = None,
-        forced_decoder_ids: Optional[List[int]] = None,
+        forced_decoder_ids: Optional[List[List[int]]] = None,
     ) -> TFLogitsProcessorList:
         """
         This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
...
...
src/transformers/generation_utils.py
View file @
82e360b7
...
...
@@ -696,7 +696,7 @@ class GenerationMixin:
         renormalize_logits: Optional[bool],
         suppress_tokens: Optional[List[int]] = None,
         begin_suppress_tokens: Optional[List[int]] = None,
-        forced_decoder_ids: Optional[List[int]] = None,
+        forced_decoder_ids: Optional[List[List[int]]] = None,
     ) -> LogitsProcessorList:
         """
         This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
...
...
@@ -956,7 +956,7 @@ class GenerationMixin:
         exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
         suppress_tokens: Optional[List[int]] = None,
         begin_suppress_tokens: Optional[List[int]] = None,
-        forced_decoder_ids: Optional[List[int]] = None,
+        forced_decoder_ids: Optional[List[List[int]]] = None,
         **model_kwargs,
     ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
         r"""
...
...
@@ -1121,9 +1121,10 @@ class GenerationMixin:
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of tokens that will be forced as beginning tokens, before sampling.
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of pairs of integers which indicates a mapping from generation indices to token indices that
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
be a token of index 123.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment