chenpangpang / transformers / Commits / bc53fc62

Unverified commit bc53fc62, authored Jan 05, 2023 by Joao Gante, committed by GitHub on Jan 05, 2023
Generate: FLAX uses `GenerationConfig` as the basis for `.generate()` parametrization (#21007)
parent 4f1c9d16
Showing 2 changed files with 218 additions and 178 deletions (+218 -178)
src/transformers/generation/flax_utils.py (+183 -177)
src/transformers/modeling_flax_utils.py (+35 -1)
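As a quick orientation before the diffs below, here is a hedged sketch of what the parametrization change means for callers. The checkpoint name (`"gpt2"`) and the prompt are illustrative, not taken from the commit; the two call styles mirror the old per-argument interface and the new `GenerationConfig`-driven one that this commit introduces for Flax.

```python
# Sketch of the change in `.generate()` parametrization (illustrative checkpoint/prompt).
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Today I believe we can finally", return_tensors="np").input_ids

# Ad hoc kwargs still work: they now update a copy of the model's generation config.
out_ad_hoc = model.generate(input_ids, max_length=30, do_sample=False)

# New style: pass an explicit GenerationConfig as the basis for the call.
generation_config = GenerationConfig(max_length=30, do_sample=False)
out_config = model.generate(input_ids, generation_config=generation_config)
```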
src/transformers/generation/flax_utils.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
+import copy
 import inspect
 import warnings
 from functools import partial
@@ -33,6 +34,7 @@ from ..models.auto import (
     FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
 from ..utils import ModelOutput, logging
+from .configuration_utils import GenerationConfig
 from .flax_logits_process import (
     FlaxForcedBOSTokenLogitsProcessor,
     FlaxForcedEOSTokenLogitsProcessor,
@@ -136,6 +138,11 @@ class FlaxGenerationMixin:
           `do_sample=False`.
     """
 
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        raise NotImplementedError(
+            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
+        )
+
     @staticmethod
     def _run_loop_in_debug(cond_fn, body_fn, init_state):
         """
@@ -171,7 +178,7 @@ class FlaxGenerationMixin:
         Confirms that the model class is compatible with generation. If not, raises an exception that points to the
         right class to use.
         """
-        if not hasattr(self, "prepare_inputs_for_generation"):
+        if not self.can_generate():
             generate_compatible_mappings = [
                 FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
                 FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
@@ -211,27 +218,11 @@ class FlaxGenerationMixin:
     def generate(
         self,
         input_ids: jnp.ndarray,
-        max_length: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        bos_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        decoder_start_token_id: Optional[int] = None,
-        do_sample: Optional[bool] = None,
+        generation_config: Optional[GenerationConfig] = None,
         prng_key: Optional[jnp.ndarray] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        temperature: Optional[float] = None,
-        num_beams: Optional[int] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        min_length: Optional[int] = None,
-        forced_bos_token_id: Optional[int] = None,
-        forced_eos_token_id: Optional[int] = None,
-        length_penalty: Optional[float] = None,
-        early_stopping: Optional[bool] = None,
         trace: bool = True,
         params: Optional[Dict[str, jnp.ndarray]] = None,
-        **model_kwargs,
+        **kwargs,
     ):
         r"""
         Generates sequences of token ids for models with a language modeling head. The method supports the following
@@ -246,100 +237,151 @@ class FlaxGenerationMixin:
         <Tip warning={true}>
 
-        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
-        defined in the model's config (`config.json`) which in turn defaults to the
-        [`~modeling_utils.PretrainedConfig`] of the model.
-
-        </Tip>
-
-        Most of these parameters are explained in more detail in [this blog
-        post](https://huggingface.co/blog/how-to-generate).
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+        For a complete overview of generate, check the [following
+        guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
+
+        </Tip>
 
         Parameters:
             input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
                 The sequence used as a prompt for the generation.
-            max_length (`int`, *optional*, defaults to `model.config.max_length`):
-                The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
-                `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
-                the prompt.
-            max_new_tokens (`int`, *optional*):
-                The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
-            do_sample (`bool`, *optional*, defaults to `False`):
-                Whether or not to use sampling ; use greedy decoding otherwise.
-            temperature (`float`, *optional*, defaults to 1.0):
-                The value used to module the next token probabilities.
-            top_k (`int`, *optional*, defaults to 50):
-                The number of highest probability vocabulary tokens to keep for top-k-filtering.
-            top_p (`float`, *optional*, defaults to 1.0):
-                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
-                are kept for generation.
-            pad_token_id (`int`, *optional*):
-                The id of the *padding* token.
-            bos_token_id (`int`, *optional*):
-                The id of the *beginning-of-sequence* token.
-            eos_token_id (`int`, *optional*):
-                The id of the *end-of-sequence* token.
-            num_beams (`int`, *optional*, defaults to 1):
-                Number of beams for beam search. 1 means no beam search.
-            decoder_start_token_id (`int`, *optional*):
-                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which had the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
             trace (`bool`, *optional*, defaults to `True`):
                 Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
                 considerably slower runtime.
             params (`Dict[str, jnp.ndarray]`, *optional*):
                 Optionally the model parameters can be passed. Can be useful for parallelized generation.
-            model_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
-                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
-                should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part.
+            kwargs:
+                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
 
         Return:
             [`~utils.ModelOutput`].
 
         Examples:
 
+        Greedy decoding, using the default generation configuration and ad hoc modifications:
+
         ```python
         >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
 
-        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-        >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
-        >>> input_context = "The dog"
-        >>> # encode input context
-        >>> inputs = tokenizer(input_context, return_tensors="np")
-        >>> # generate candidates using sampling
-        >>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="np").input_ids
+
+        >>> # Generate up to 30 tokens
+        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
         >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
+        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
+        ```
+
+        Multinomial sampling, modifying an existing generation configuration:
+
+        ```python
+        >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, GenerationConfig
+        >>> import numpy as np
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="np").input_ids
+
+        >>> # Sample up to 30 tokens
+        >>> generation_config = GenerationConfig.from_pretrained("gpt2")
+        >>> generation_config.max_length = 30
+        >>> generation_config.do_sample = True
+        >>> outputs = model.generate(
+        ...     input_ids, generation_config=generation_config, prng_key=np.asarray([0, 0], dtype=np.uint32)
+        ... )
+        >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
+        ['Today I believe we can finally get a change in that system. The way I saw it was this: a few years ago, this company would not']
+        ```
+
+        Beam-search decoding, using a freshly initialized generation configuration:
+
+        ```python
+        >>> from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM, GenerationConfig
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+        >>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+
+        >>> sentence = "Paris is one of the densest populated areas in Europe."
+        >>> input_ids = tokenizer(sentence, return_tensors="np").input_ids
+
+        >>> generation_config = GenerationConfig(
+        ...     max_length=64,
+        ...     num_beams=5,
+        ...     bos_token_id=0,
+        ...     eos_token_id=0,
+        ...     decoder_start_token_id=58100,
+        ...     pad_token_id=58100,
+        ...     bad_words_ids=[[58100]],
+        ... )
+
+        >>> outputs = model.generate(input_ids, generation_config=generation_config)
+        >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
+        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
-        # Validate the `.generate()` call
+        # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
+
+        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+        if generation_config is None:
+            # legacy: users may modify the model configuration to control generation -- update the generation config
+            # model attribute accordingly, if it was created from the model config
+            if self.generation_config._from_model_config:
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    warnings.warn(
+                        "You have modified the pretrained model configuration to control generation. This is a"
+                        " deprecated strategy to control generation and will be removed soon, in a future version."
+                        " Please use a generation configuration file (see"
+                        " https://huggingface.co/docs/transformers/main_classes/text_generation)"
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
+
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
         self._validate_model_kwargs(model_kwargs.copy())
 
         # set init values
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
-        )
         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
 
-        if pad_token_id is None and eos_token_id is not None:
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
             if model_kwargs.get("attention_mask") is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
                     "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                 )
+            eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            pad_token_id = eos_token_id
+            generation_config.pad_token_id = eos_token_id
 
-        if decoder_start_token_id is None and self.config.is_encoder_decoder:
+        if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
             raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
 
         # decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
         if not self.config.is_encoder_decoder and not trace:
-            if pad_token_id is not None and jnp.sum(input_ids[:, -1] == pad_token_id) > 0:
+            if (
                generation_config.pad_token_id is not None
+                and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
+            ):
                 logger.warning(
                     "A decoder-only architecture is being used, but right-padding was detected! For correct "
                     "generation results, please set `padding_side='left'` when initializing the tokenizer."
@@ -350,71 +392,62 @@ class FlaxGenerationMixin:
             if model_kwargs.get("encoder_outputs") is None:
                 model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
             # prepare decoder_input_ids for generation
-            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * generation_config.decoder_start_token_id
 
         # Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
-        if max_length is None and max_new_tokens is None:
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
-                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to "
-                f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
-                "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
-                "using `max_new_tokens` to control the maximum length of the generation.",
+                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
+                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
+                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
-        elif max_length is None and max_new_tokens is not None:
-            max_length = max_new_tokens + input_ids_seq_length
-        elif max_length is not None and max_new_tokens is not None:
+        elif has_default_max_length and generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+        elif not has_default_max_length and generation_config.max_new_tokens is not None:
             raise ValueError(
                 "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
                 " limit to the generated output length. Remove one of those arguments. Please refer to the"
                 " documentation for more information. "
                 "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
             )
-        # default to config if still None
-        max_length = max_length if max_length is not None else self.config.max_length
-        min_length = min_length if min_length is not None else self.config.min_length
 
-        if min_length is not None and min_length > max_length:
+        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
-                f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
-                f"length ({max_length})"
+                f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
+                f" the maximum length ({generation_config.max_length})"
             )
-        if input_ids_seq_length >= max_length:
+        if input_ids_seq_length >= generation_config.max_length:
             input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
             logger.warning(
                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                f" {max_length}. This can lead to unexpected behavior. You should consider increasing "
-                "`max_new_tokens`."
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
             )
 
-        do_sample = do_sample if do_sample is not None else self.config.do_sample
-        num_beams = num_beams if num_beams is not None else self.config.num_beams
+        logits_processor = self._get_logits_processor(generation_config=generation_config)
 
-        if not do_sample and num_beams == 1:
-            logits_processor = self._get_logits_processor(
-                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
-            )
+        if not generation_config.do_sample and generation_config.num_beams == 1:
             return self._greedy_search(
                 input_ids,
-                max_length,
-                pad_token_id,
-                eos_token_id,
+                generation_config.max_length,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
                 logits_processor=logits_processor,
                 trace=trace,
                 params=params,
                 model_kwargs=model_kwargs,
             )
-        elif do_sample and num_beams == 1:
-            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
-            logits_processor = self._get_logits_processor(
-                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
-            )
+        elif generation_config.do_sample and generation_config.num_beams == 1:
+            logits_warper = self._get_logits_warper(generation_config=generation_config)
             return self._sample(
                 input_ids,
-                max_length,
-                pad_token_id,
-                eos_token_id,
+                generation_config.max_length,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
                 prng_key,
                 logits_warper=logits_warper,
                 logits_processor=logits_processor,
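The hunk above also carries the length-resolution rule: when only `max_new_tokens` is supplied, the effective `max_length` becomes the prompt length plus `max_new_tokens`; setting both is an error; setting neither falls back to the (deprecated) config default with a warning. A toy sketch of that rule, with a hypothetical helper name not present in the diff:

```python
# Toy illustration of the length resolution logic (helper name is hypothetical).
def resolve_max_length(prompt_len, config_max_length, max_length=None, max_new_tokens=None):
    if max_length is not None and max_new_tokens is not None:
        raise ValueError("set either `max_length` or `max_new_tokens`, not both")
    if max_new_tokens is not None:
        return max_new_tokens + prompt_len   # prompt tokens count toward max_length
    if max_length is not None:
        return max_length
    return config_max_length                 # deprecated fallback; the real code warns here

assert resolve_max_length(7, 20) == 20
assert resolve_max_length(7, 20, max_new_tokens=30) == 37
assert resolve_max_length(7, 20, max_length=25) == 25
```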
@@ -422,31 +455,27 @@ class FlaxGenerationMixin:
                 params=params,
                 model_kwargs=model_kwargs,
             )
-        elif not do_sample and num_beams > 1:
+        elif not generation_config.do_sample and generation_config.num_beams > 1:
             # broadcast input_ids & encoder_outputs
-            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
+            input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
 
             if "encoder_outputs" in model_kwargs:
                 model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
+                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
                 )
 
             if "attention_mask" in model_kwargs:
                 model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                    model_kwargs["attention_mask"], num_beams=num_beams
+                    model_kwargs["attention_mask"], num_beams=generation_config.num_beams
                 )
 
-            logits_processor = self._get_logits_processor(
-                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
-            )
-
             return self._beam_search(
                 input_ids,
-                max_length,
-                pad_token_id,
-                eos_token_id,
-                length_penalty=length_penalty,
-                early_stopping=early_stopping,
+                generation_config.max_length,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
+                length_penalty=generation_config.length_penalty,
+                early_stopping=generation_config.early_stopping,
                 logits_processor=logits_processor,
                 trace=trace,
                 params=params,
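Taken together, the two hunks above dispatch to greedy search, multinomial sampling, or beam search purely from `generation_config.do_sample` and `generation_config.num_beams`. A hedged sketch of that selection rule (the helper function is illustrative, not part of the library):

```python
# Sketch of the decoding-mode dispatch implemented by the rewritten `generate()`.
from transformers import GenerationConfig

def decoding_mode(cfg: GenerationConfig) -> str:
    if not cfg.do_sample and cfg.num_beams == 1:
        return "greedy_search"
    if cfg.do_sample and cfg.num_beams == 1:
        return "sample"
    if not cfg.do_sample and cfg.num_beams > 1:
        return "beam_search"
    return "beam sampling (not implemented for Flax)"

assert decoding_mode(GenerationConfig()) == "greedy_search"
assert decoding_mode(GenerationConfig(do_sample=True)) == "sample"
assert decoding_mode(GenerationConfig(num_beams=5)) == "beam_search"
```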
@@ -455,67 +484,44 @@ class FlaxGenerationMixin:
         else:
             raise NotImplementedError("`Beam sampling is currently not implemented.")
 
-    def _get_logits_warper(
-        self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
-    ) -> FlaxLogitsProcessorList:
+    def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
         """
         This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
         instances used for multinomial sampling.
         """
-        # init warp parameters
-        top_k = top_k if top_k is not None else self.config.top_k
-        top_p = top_p if top_p is not None else self.config.top_p
-        temperature = temperature if temperature is not None else self.config.temperature
-        # instantiate warpers list
         warpers = FlaxLogitsProcessorList()
 
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        if temperature is not None and temperature != 1.0:
-            warpers.append(FlaxTemperatureLogitsWarper(temperature))
-        if top_k is not None and top_k != 0:
-            warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
-        if top_p is not None and top_p < 1.0:
-            warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
+        if generation_config.temperature is not None and generation_config.temperature != 1.0:
+            warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature))
+        if generation_config.top_k is not None and generation_config.top_k != 0:
+            warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
+        if generation_config.top_p is not None and generation_config.top_p < 1.0:
+            warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
 
         return warpers
 
-    def _get_logits_processor(
-        self,
-        no_repeat_ngram_size: int,
-        min_length: int,
-        max_length: int,
-        eos_token_id: int,
-        forced_bos_token_id: int,
-        forced_eos_token_id: int,
-    ) -> FlaxLogitsProcessorList:
+    def _get_logits_processor(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
         """
         This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
         instances used to modify the scores of the language model head.
         """
         processors = FlaxLogitsProcessorList()
 
-        # init warp parameters
-        no_repeat_ngram_size = (
-            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
-        )
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        forced_bos_token_id = (
-            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
-        )
-        forced_eos_token_id = (
-            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
-        )
-
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        if min_length is not None and eos_token_id is not None and min_length > -1:
-            processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
-        if forced_bos_token_id is not None:
-            processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
-        if forced_eos_token_id is not None:
-            processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
+        if (
+            generation_config.min_length is not None
+            and generation_config.eos_token_id is not None
+            and generation_config.min_length > -1
+        ):
+            processors.append(
+                FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)
+            )
+        if generation_config.forced_bos_token_id is not None:
+            processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
+        if generation_config.forced_eos_token_id is not None:
+            processors.append(
+                FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
+            )
 
         return processors
 
     def _greedy_search(
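Per the docstrings above, the warper and processor lists are now derived entirely from the `GenerationConfig`. A small sketch of which warpers/processors the rewritten conditions would select for one illustrative configuration (values are assumptions, chosen only to exercise each branch):

```python
# Sketch: mirror the conditions in `_get_logits_warper` / `_get_logits_processor` above.
from transformers import GenerationConfig

cfg = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=1.0, min_length=5, eos_token_id=2)

wants_temperature = cfg.temperature is not None and cfg.temperature != 1.0  # True  -> temperature warper
wants_top_k = cfg.top_k is not None and cfg.top_k != 0                      # True  -> top-k warper
wants_top_p = cfg.top_p is not None and cfg.top_p < 1.0                     # False -> no top-p warper
wants_min_length = (
    cfg.min_length is not None and cfg.eos_token_id is not None and cfg.min_length > -1
)                                                                            # True  -> min-length processor
```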
@@ -530,9 +536,9 @@ class FlaxGenerationMixin:
         model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         # init values
-        max_length = max_length if max_length is not None else self.config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
 
         batch_size, cur_len = input_ids.shape
@@ -618,9 +624,9 @@ class FlaxGenerationMixin:
         model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         # init values
-        max_length = max_length if max_length is not None else self.config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
 
         batch_size, cur_len = input_ids.shape
@@ -716,7 +722,7 @@ class FlaxGenerationMixin:
     ):
         """
         This beam search function is heavily inspired by Flax's official example:
-        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
+        https://github.com/google/flax/blob/main/examples/wmt/decode.py
         """
 
         def flatten_beam_dim(tensor):
@@ -751,11 +757,11 @@ class FlaxGenerationMixin:
             return jax.tree_util.tree_map(gather_fn, nested)
 
         # init values
-        max_length = max_length if max_length is not None else self.config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
+        early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
 
         batch_size, num_beams, cur_len = input_ids.shape
src/transformers/modeling_flax_utils.py
@@ -33,7 +33,7 @@ from jax.random import PRNGKey
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation import FlaxGenerationMixin
+from .generation import FlaxGenerationMixin, GenerationConfig
 from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
 from .utils import (
     FLAX_WEIGHTS_INDEX_NAME,
@@ -199,6 +199,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         self.key = PRNGKey(seed)
         self.dtype = dtype
         self.input_shape = input_shape
+        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
 
         # To check if the model was intialized automatically.
         self._is_initialized = _do_init
@@ -467,6 +468,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         # the state dict is unflattened to the match the format of model.params
         return unflatten_dict(state_sharded_dict, sep="/")
 
+    def can_generate(self) -> bool:
+        """
+        Returns whether this model can generate sequences with `.generate()`. Returns:
+            `bool`: Whether this model can generate sequences with `.generate()`.
+        """
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
+        if "GenerationMixin" in str(self.prepare_inputs_for_generation):
+            return False
+        return True
+
     @classmethod
     def from_pretrained(
         cls,
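The new `can_generate` above relies on the repr of the bound method containing the name of the class that defined it: if a model never overrides `prepare_inputs_for_generation`, the method still resolves to the mixin's placeholder. A hedged, self-contained sketch of that detection idea (class names here are hypothetical, not from the library):

```python
# Sketch of the override-detection trick used by `can_generate`.
class FakeMixin:
    def prepare_inputs_for_generation(self, *args, **kwargs):
        raise NotImplementedError

class NoGenModel(FakeMixin):      # does not override the hook
    pass

class GenModel(FakeMixin):        # overrides the hook
    def prepare_inputs_for_generation(self, *args, **kwargs):
        return {}

# The bound-method repr carries the defining class's qualified name:
print("FakeMixin" in str(NoGenModel().prepare_inputs_for_generation))  # True  -> cannot generate
print("FakeMixin" in str(GenModel().prepare_inputs_for_generation))    # False -> can generate
```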
@@ -940,6 +951,29 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
             )
 
+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _from_auto=from_auto_class,
+                    _from_pipeline=from_pipeline,
+                    **kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
+
         if _do_init:
             # set correct parameters
             model.params = unflatten_dict(state)
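With the hunks above, a Flax model that can generate gets a `generation_config` attribute: `__init__` derives one from the model config, and `from_pretrained` tries to load `generation_config.json` from the checkpoint, falling back silently when the file is absent. A hedged usage sketch; the checkpoint name and output directory are illustrative:

```python
# Sketch of the end-user effect of the generation-config loading added above.
from transformers import FlaxAutoModelForCausalLM, GenerationConfig

model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
print(model.generation_config)  # loaded from generation_config.json, or derived from the model config

# Persist a custom default next to the weights so future `from_pretrained` calls pick it up:
custom = GenerationConfig(max_new_tokens=20, do_sample=True, top_k=30)
custom.save_pretrained("./my-gpt2-checkpoint")  # writes generation_config.json
```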