chenpangpang / transformers / Commits

Unverified commit bc53fc62, authored Jan 05, 2023 by Joao Gante, committed by GitHub on Jan 05, 2023

Generate: FLAX uses `GenerationConfig` as the basis for `.generate()` parametrization (#21007)
parent 4f1c9d16

Showing 2 changed files with 218 additions and 178 deletions (+218 -178)

    src/transformers/generation/flax_utils.py    +183 -177
    src/transformers/modeling_flax_utils.py      +35  -1
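In practical terms, this commit moves Flax's `.generate()` from a long list of per-call keyword arguments (each falling back to `model.config`) to a single `GenerationConfig` object, with loose kwargs applied as ad hoc overrides. Below is a minimal sketch of the resulting calling pattern, using an illustrative `gpt2` checkpoint in the spirit of the docstring examples added in this commit; it is not part of the diff itself.

```python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("The dog", return_tensors="np").input_ids

# Explicit base parametrization: pass a GenerationConfig
generation_config = GenerationConfig(max_length=20, do_sample=True, top_k=30)
outputs = model.generate(input_ids, generation_config=generation_config)

# Ad hoc overrides: kwargs matching GenerationConfig attributes update a copy of
# model.generation_config (the model's default) before generation runs
outputs = model.generate(input_ids, max_length=20, do_sample=True, top_k=30)

print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
```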
src/transformers/generation/flax_utils.py
@@ -15,6 +15,7 @@
 # limitations under the License.
+import copy
 import inspect
 import warnings
 from functools import partial
@@ -33,6 +34,7 @@ from ..models.auto import (
     FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
 from ..utils import ModelOutput, logging
+from .configuration_utils import GenerationConfig
 from .flax_logits_process import (
     FlaxForcedBOSTokenLogitsProcessor,
     FlaxForcedEOSTokenLogitsProcessor,
@@ -136,6 +138,11 @@ class FlaxGenerationMixin:
             `do_sample=False`.
     """

+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        raise NotImplementedError(
+            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
+        )
+
     @staticmethod
     def _run_loop_in_debug(cond_fn, body_fn, init_state):
         """
@@ -171,7 +178,7 @@ class FlaxGenerationMixin:
         Confirms that the model class is compatible with generation. If not, raises an exception that points to the
         right class to use.
         """
-        if not hasattr(self, "prepare_inputs_for_generation"):
+        if not self.can_generate():
             generate_compatible_mappings = [
                 FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
                 FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
@@ -211,27 +218,11 @@ class FlaxGenerationMixin:
     def generate(
         self,
         input_ids: jnp.ndarray,
-        max_length: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        bos_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        decoder_start_token_id: Optional[int] = None,
-        do_sample: Optional[bool] = None,
+        generation_config: Optional[GenerationConfig] = None,
         prng_key: Optional[jnp.ndarray] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        temperature: Optional[float] = None,
-        num_beams: Optional[int] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        min_length: Optional[int] = None,
-        forced_bos_token_id: Optional[int] = None,
-        forced_eos_token_id: Optional[int] = None,
-        length_penalty: Optional[float] = None,
-        early_stopping: Optional[bool] = None,
         trace: bool = True,
         params: Optional[Dict[str, jnp.ndarray]] = None,
-        **model_kwargs,
+        **kwargs,
     ):
         r"""
         Generates sequences of token ids for models with a language modeling head. The method supports the following
@@ -246,100 +237,151 @@ class FlaxGenerationMixin:
         <Tip warning={true}>

-        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
-        defined in the model's config (`config.json`) which in turn defaults to the
-        [`~modeling_utils.PretrainedConfig`] of the model.
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

-        </Tip>
+        For a complete overview of generate, check the [following
+        guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).

-        Most of these parameters are explained in more detail in [this blog
-        post](https://huggingface.co/blog/how-to-generate).
+        </Tip>

         Parameters:
             input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
                 The sequence used as a prompt for the generation.
-            max_length (`int`, *optional*, defaults to `model.config.max_length`):
-                The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
-                `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
-                the prompt.
-            max_new_tokens (`int`, *optional*):
-                The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
-            do_sample (`bool`, *optional*, defaults to `False`):
-                Whether or not to use sampling ; use greedy decoding otherwise.
-            temperature (`float`, *optional*, defaults to 1.0):
-                The value used to module the next token probabilities.
-            top_k (`int`, *optional*, defaults to 50):
-                The number of highest probability vocabulary tokens to keep for top-k-filtering.
-            top_p (`float`, *optional*, defaults to 1.0):
-                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
-                are kept for generation.
-            pad_token_id (`int`, *optional*):
-                The id of the *padding* token.
-            bos_token_id (`int`, *optional*):
-                The id of the *beginning-of-sequence* token.
-            eos_token_id (`int`, *optional*):
-                The id of the *end-of-sequence* token.
-            num_beams (`int`, *optional*, defaults to 1):
-                Number of beams for beam search. 1 means no beam search.
-            decoder_start_token_id (`int`, *optional*):
-                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which had the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
             trace (`bool`, *optional*, defaults to `True`):
                 Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
                 considerably slower runtime.
             params (`Dict[str, jnp.ndarray]`, *optional*):
                 Optionally the model parameters can be passed. Can be useful for parallelized generation.
-            model_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
-                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
-                should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part.
+            kwargs:
+                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

         Return:
             [`~utils.ModelOutput`].

         Examples:

+        Greedy decoding, using the default generation configuration and ad hoc modifications:
+
         ```python
         >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

-        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-        >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
-        >>> input_context = "The dog"
-        >>> # encode input context
-        >>> inputs = tokenizer(input_context, return_tensors="np")
-        >>> # generate candidates using sampling
-        >>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="np").input_ids
+
+        >>> # Generate up to 30 tokens
+        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
         >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
+        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
         ```
+
+        Multinomial sampling, modifying an existing generation configuration:
+
+        ```python
+        >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, GenerationConfig
+        >>> import numpy as np
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="np").input_ids
+
+        >>> # Sample up to 30 tokens
+        >>> generation_config = GenerationConfig.from_pretrained("gpt2")
+        >>> generation_config.max_length = 30
+        >>> generation_config.do_sample = True
+        >>> outputs = model.generate(
+        ...     input_ids, generation_config=generation_config, prng_key=np.asarray([0, 0], dtype=np.uint32)
+        ... )
+        >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
+        ['Today I believe we can finally get a change in that system. The way I saw it was this: a few years ago, this company would not']
+        ```
+
+        Beam-search decoding, using a freshly initialized generation configuration:
+
+        ```python
+        >>> from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM, GenerationConfig
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+        >>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+
+        >>> sentence = "Paris is one of the densest populated areas in Europe."
+        >>> input_ids = tokenizer(sentence, return_tensors="np").input_ids
+
+        >>> generation_config = GenerationConfig(
+        ...     max_length=64,
+        ...     num_beams=5,
+        ...     bos_token_id=0,
+        ...     eos_token_id=0,
+        ...     decoder_start_token_id=58100,
+        ...     pad_token_id=58100,
+        ...     bad_words_ids=[[58100]],
+        ... )
+        >>> outputs = model.generate(input_ids, generation_config=generation_config)
+        >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
+        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
-        # Validate the `.generate()` call
+        # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
+
+        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+        if generation_config is None:
+            # legacy: users may modify the model configuration to control generation -- update the generation config
+            # model attribute accordingly, if it was created from the model config
+            if self.generation_config._from_model_config:
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    warnings.warn(
+                        "You have modified the pretrained model configuration to control generation. This is a"
+                        " deprecated strategy to control generation and will be removed soon, in a future version."
+                        " Please use a generation configuration file (see"
+                        " https://huggingface.co/docs/transformers/main_classes/text_generation)"
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
+
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
         self._validate_model_kwargs(model_kwargs.copy())

         # set init values
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
-        )
         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

-        if pad_token_id is None and eos_token_id is not None:
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
             if model_kwargs.get("attention_mask") is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
                     "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                 )
+            eos_token_id = generation_config.eos_token_id
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            pad_token_id = eos_token_id
+            generation_config.pad_token_id = eos_token_id

-        if decoder_start_token_id is None and self.config.is_encoder_decoder:
+        if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
             raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")

         # decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
         if not self.config.is_encoder_decoder and not trace:
-            if pad_token_id is not None and jnp.sum(input_ids[:, -1] == pad_token_id) > 0:
+            if (
+                generation_config.pad_token_id is not None
+                and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
+            ):
                 logger.warning(
                     "A decoder-only architecture is being used, but right-padding was detected! For correct "
                     "generation results, please set `padding_side='left'` when initializing the tokenizer."
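The body above resolves parameters in three steps: pick the base `GenerationConfig` (the argument, else `model.generation_config`, recreated from `model.config` in the legacy case), deep-copy it, then fold `**kwargs` into the copy via `GenerationConfig.update`, which hands back everything it did not consume as model kwargs. A small standalone sketch of that split follows; the checkpoint name and the `attention_mask` key are illustrative only.

```python
import copy

from transformers import GenerationConfig

base = GenerationConfig.from_pretrained("gpt2")  # stand-in for model.generation_config
generation_config = copy.deepcopy(base)

# Keys that exist on GenerationConfig update the copy in place; everything else is returned
model_kwargs = generation_config.update(max_new_tokens=16, do_sample=True, attention_mask=None)

print(generation_config.max_new_tokens, generation_config.do_sample)  # 16 True
print(model_kwargs)  # {'attention_mask': None} -- forwarded to the model's forward pass
print(base.max_new_tokens)  # unchanged: generate() only mutates its own copy
```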
@@ -350,71 +392,62 @@ class FlaxGenerationMixin:
         if model_kwargs.get("encoder_outputs") is None:
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
             # prepare decoder_input_ids for generation
-            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * generation_config.decoder_start_token_id

         # Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
-        if max_length is None and max_new_tokens is None:
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
-                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to "
-                f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
-                "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
-                "using `max_new_tokens` to control the maximum length of the generation.",
+                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
+                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
+                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
-        elif max_length is None and max_new_tokens is not None:
-            max_length = max_new_tokens + input_ids_seq_length
-        elif max_length is not None and max_new_tokens is not None:
+        elif has_default_max_length and generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+        elif not has_default_max_length and generation_config.max_new_tokens is not None:
             raise ValueError(
                 "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
                 " limit to the generated output length. Remove one of those arguments. Please refer to the"
                 " documentation for more information. "
                 "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
             )
-        # default to config if still None
-        max_length = max_length if max_length is not None else self.config.max_length
-        min_length = min_length if min_length is not None else self.config.min_length

-        if min_length is not None and min_length > max_length:
+        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
-                f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
-                f"length ({max_length})"
+                f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
+                f" the maximum length ({generation_config.max_length})"
             )
-        if input_ids_seq_length >= max_length:
+        if input_ids_seq_length >= generation_config.max_length:
             input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
             logger.warning(
                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                f" {max_length}. This can lead to unexpected behavior. You should consider increasing "
-                "`max_new_tokens`."
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
             )

-        do_sample = do_sample if do_sample is not None else self.config.do_sample
-        num_beams = num_beams if num_beams is not None else self.config.num_beams
+        logits_processor = self._get_logits_processor(generation_config=generation_config)

-        if not do_sample and num_beams == 1:
-            logits_processor = self._get_logits_processor(
-                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
-            )
+        if not generation_config.do_sample and generation_config.num_beams == 1:
             return self._greedy_search(
                 input_ids,
-                max_length,
-                pad_token_id,
-                eos_token_id,
+                generation_config.max_length,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
                 logits_processor=logits_processor,
                 trace=trace,
                 params=params,
                 model_kwargs=model_kwargs,
             )
-        elif do_sample and num_beams == 1:
-            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
-            logits_processor = self._get_logits_processor(
-                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
-            )
+        elif generation_config.do_sample and generation_config.num_beams == 1:
+            logits_warper = self._get_logits_warper(generation_config=generation_config)
             return self._sample(
                 input_ids,
-                max_length,
-                pad_token_id,
-                eos_token_id,
+                generation_config.max_length,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
                 prng_key,
                 logits_warper=logits_warper,
                 logits_processor=logits_processor,
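The length handling above keeps the semantics of the old keyword arguments but stores the result on the config copy: when only `max_new_tokens` is supplied, the effective `max_length` becomes the prompt length plus the new-token budget. A toy illustration of that arithmetic, with made-up values:

```python
input_ids_seq_length = 8                   # prompt length, i.e. input_ids.shape[-1]
max_new_tokens = 30                        # user override folded into generation_config
max_length = max_new_tokens + input_ids_seq_length

print(max_length)                          # 38: total sequence budget, prompt included
assert input_ids_seq_length < max_length   # otherwise generate() warns that nothing new fits
```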
@@ -422,31 +455,27 @@ class FlaxGenerationMixin:
                 params=params,
                 model_kwargs=model_kwargs,
             )
-        elif not do_sample and num_beams > 1:
+        elif not generation_config.do_sample and generation_config.num_beams > 1:
             # broadcast input_ids & encoder_outputs
-            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
+            input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)

             if "encoder_outputs" in model_kwargs:
                 model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
+                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
                 )

             if "attention_mask" in model_kwargs:
                 model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                    model_kwargs["attention_mask"], num_beams=num_beams
-                )
-
-            logits_processor = self._get_logits_processor(
-                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
+                    model_kwargs["attention_mask"], num_beams=generation_config.num_beams
                 )

             return self._beam_search(
                 input_ids,
-                max_length,
-                pad_token_id,
-                eos_token_id,
-                length_penalty=length_penalty,
-                early_stopping=early_stopping,
+                generation_config.max_length,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
+                length_penalty=generation_config.length_penalty,
+                early_stopping=generation_config.early_stopping,
                 logits_processor=logits_processor,
                 trace=trace,
                 params=params,
@@ -455,67 +484,44 @@ class FlaxGenerationMixin:
         else:
             raise NotImplementedError("`Beam sampling is currently not implemented.")

-    def _get_logits_warper(
-        self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
-    ) -> FlaxLogitsProcessorList:
+    def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
         """
         This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
         instances used for multinomial sampling.
         """
-
-        # init warp parameters
-        top_k = top_k if top_k is not None else self.config.top_k
-        top_p = top_p if top_p is not None else self.config.top_p
-        temperature = temperature if temperature is not None else self.config.temperature
-        # instantiate warpers list
         warpers = FlaxLogitsProcessorList()

-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        if temperature is not None and temperature != 1.0:
-            warpers.append(FlaxTemperatureLogitsWarper(temperature))
-        if top_k is not None and top_k != 0:
-            warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
-        if top_p is not None and top_p < 1.0:
-            warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
+        if generation_config.temperature is not None and generation_config.temperature != 1.0:
+            warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature))
+        if generation_config.top_k is not None and generation_config.top_k != 0:
+            warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
+        if generation_config.top_p is not None and generation_config.top_p < 1.0:
+            warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))

         return warpers

-    def _get_logits_processor(
-        self,
-        no_repeat_ngram_size: int,
-        min_length: int,
-        max_length: int,
-        eos_token_id: int,
-        forced_bos_token_id: int,
-        forced_eos_token_id: int,
-    ) -> FlaxLogitsProcessorList:
+    def _get_logits_processor(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
         """
         This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
         instances used to modify the scores of the language model head.
         """
         processors = FlaxLogitsProcessorList()

-        # init warp parameters
-        no_repeat_ngram_size = (
-            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
-        )
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        forced_bos_token_id = (
-            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
-        )
-        forced_eos_token_id = (
-            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
-        )
-
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        if min_length is not None and eos_token_id is not None and min_length > -1:
-            processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
-        if forced_bos_token_id is not None:
-            processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
-        if forced_eos_token_id is not None:
-            processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
+        if (
+            generation_config.min_length is not None
+            and generation_config.eos_token_id is not None
+            and generation_config.min_length > -1
+        ):
+            processors.append(
+                FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)
+            )
+        if generation_config.forced_bos_token_id is not None:
+            processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
+        if generation_config.forced_eos_token_id is not None:
+            processors.append(
+                FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
+            )

         return processors

     def _greedy_search(
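Both helpers now read everything from the `GenerationConfig` instead of individual arguments. As a hedged sketch (the parameter values below are chosen purely for illustration), the conditions above would populate the warper list as follows:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=0.9)

# With these values, _get_logits_warper(generation_config=generation_config) appends, in order:
#   FlaxTemperatureLogitsWarper(0.7)  -- temperature is set and != 1.0
#   FlaxTopKLogitsWarper(top_k=50)    -- top_k is set and != 0
#   FlaxTopPLogitsWarper(top_p=0.9)   -- top_p is set and < 1.0
# With the library defaults (temperature=1.0, top_k=50, top_p=1.0) only the top-k warper is added.
```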
@@ -530,9 +536,9 @@ class FlaxGenerationMixin:
         model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         # init values
-        max_length = max_length if max_length is not None else self.config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id

         batch_size, cur_len = input_ids.shape
@@ -618,9 +624,9 @@ class FlaxGenerationMixin:
         model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         # init values
-        max_length = max_length if max_length is not None else self.config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

         batch_size, cur_len = input_ids.shape
@@ -716,7 +722,7 @@ class FlaxGenerationMixin:
     ):
         """
         This beam search function is heavily inspired by Flax's official example:
-        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
+        https://github.com/google/flax/blob/main/examples/wmt/decode.py
         """

         def flatten_beam_dim(tensor):
@@ -751,11 +757,11 @@ class FlaxGenerationMixin:
             return jax.tree_util.tree_map(gather_fn, nested)

         # init values
-        max_length = max_length if max_length is not None else self.config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
+        early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping

         batch_size, num_beams, cur_len = input_ids.shape
src/transformers/modeling_flax_utils.py
@@ -33,7 +33,7 @@ from jax.random import PRNGKey
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation import FlaxGenerationMixin
+from .generation import FlaxGenerationMixin, GenerationConfig
 from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
 from .utils import (
     FLAX_WEIGHTS_INDEX_NAME,
@@ -199,6 +199,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         self.key = PRNGKey(seed)
         self.dtype = dtype
         self.input_shape = input_shape
+        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

         # To check if the model was intialized automatically.
         self._is_initialized = _do_init
@@ -467,6 +468,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         # the state dict is unflattened to the match the format of model.params
         return unflatten_dict(state_sharded_dict, sep="/")

+    def can_generate(self) -> bool:
+        """
+        Returns whether this model can generate sequences with `.generate()`. Returns:
+            `bool`: Whether this model can generate sequences with `.generate()`.
+        """
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
+        if "GenerationMixin" in str(self.prepare_inputs_for_generation):
+            return False
+        return True
+
     @classmethod
     def from_pretrained(
         cls,
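The detection in `can_generate()` relies on the repr of a bound method naming the class that defined it, so a model that never overrides the mixin's `prepare_inputs_for_generation` is recognizable. A standalone sketch of the trick follows; the class names here are hypothetical and not repo code.

```python
class GenerationMixin:
    def prepare_inputs_for_generation(self, *args, **kwargs):
        raise NotImplementedError


class PlainModel(GenerationMixin):      # does not override -> cannot generate
    pass


class CausalLMModel(GenerationMixin):   # overrides -> can generate
    def prepare_inputs_for_generation(self, *args, **kwargs):
        return {}


# A bound method's repr includes the qualified name of the defining class, e.g.
# "<bound method GenerationMixin.prepare_inputs_for_generation of <PlainModel object ...>>"
print("GenerationMixin" in str(PlainModel().prepare_inputs_for_generation))    # True  -> can_generate() is False
print("GenerationMixin" in str(CausalLMModel().prepare_inputs_for_generation)) # False -> can_generate() is True
```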
@@ -940,6 +951,29 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                     "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
                 )

+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _from_auto=from_auto_class,
+                    _from_pipeline=from_pipeline,
+                    **kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
+
         if _do_init:
             # set correct parameters
             model.params = unflatten_dict(state)
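With this fallback, a checkpoint that ships a `generation_config.json` takes priority over the defaults derived from `config.json` in `__init__`. A hedged sketch of exercising that priority with a local directory follows; the path and parameter values are illustrative only.

```python
from transformers import FlaxAutoModelForCausalLM, GenerationConfig

model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
model.save_pretrained("./my-flax-gpt2")

# Persist a custom generation config next to the weights; from_pretrained will pick it up
GenerationConfig(max_new_tokens=64, do_sample=True, top_p=0.9).save_pretrained("./my-flax-gpt2")

reloaded = FlaxAutoModelForCausalLM.from_pretrained("./my-flax-gpt2")
print(reloaded.generation_config.max_new_tokens)  # 64, loaded from generation_config.json
# Without that file, the `except OSError` branch above keeps the config built from config.json
```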