"docs/source/vscode:/vscode.git/clone" did not exist on "adc0ff25028d29af30386f2d7d3f85e290fbef57"
Unverified commit bc53fc62, authored by Joao Gante, committed by GitHub

Generate: FLAX uses `GenerationConfig` as the basis for `.generate()` parametrization (#21007)

parent 4f1c9d16
@@ -15,6 +15,7 @@
 # limitations under the License.
+import copy
 import inspect
 import warnings
 from functools import partial
@@ -33,6 +34,7 @@ from ..models.auto import (
     FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
 from ..utils import ModelOutput, logging
+from .configuration_utils import GenerationConfig
 from .flax_logits_process import (
     FlaxForcedBOSTokenLogitsProcessor,
     FlaxForcedEOSTokenLogitsProcessor,
@@ -136,6 +138,11 @@ class FlaxGenerationMixin:
         `do_sample=False`.
     """
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        raise NotImplementedError(
+            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
+        )
     @staticmethod
     def _run_loop_in_debug(cond_fn, body_fn, init_state):
         """
@@ -171,7 +178,7 @@
         Confirms that the model class is compatible with generation. If not, raises an exception that points to the
         right class to use.
         """
-        if not hasattr(self, "prepare_inputs_for_generation"):
+        if not self.can_generate():
             generate_compatible_mappings = [
                 FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
                 FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
@@ -211,27 +218,11 @@
     def generate(
         self,
         input_ids: jnp.ndarray,
-        max_length: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        bos_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        decoder_start_token_id: Optional[int] = None,
-        do_sample: Optional[bool] = None,
+        generation_config: Optional[GenerationConfig] = None,
         prng_key: Optional[jnp.ndarray] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        temperature: Optional[float] = None,
-        num_beams: Optional[int] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        min_length: Optional[int] = None,
-        forced_bos_token_id: Optional[int] = None,
-        forced_eos_token_id: Optional[int] = None,
-        length_penalty: Optional[float] = None,
-        early_stopping: Optional[bool] = None,
         trace: bool = True,
         params: Optional[Dict[str, jnp.ndarray]] = None,
-        **model_kwargs,
+        **kwargs,
     ):
         r"""
         Generates sequences of token ids for models with a language modeling head. The method supports the following
@@ -246,100 +237,151 @@
         <Tip warning={true}>
-        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
-        defined in the model's config (`config.json`) which in turn defaults to the
-        [`~modeling_utils.PretrainedConfig`] of the model.
-        </Tip>
-        Most of these parameters are explained in more detail in [this blog
-        post](https://huggingface.co/blog/how-to-generate).
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+        For a complete overview of generate, check the [following
+        guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
+        </Tip>
         Parameters:
             input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
                 The sequence used as a prompt for the generation.
-            max_length (`int`, *optional*, defaults to `model.config.max_length`):
-                The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
-                `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
-                the prompt.
-            max_new_tokens (`int`, *optional*):
-                The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
-            do_sample (`bool`, *optional*, defaults to `False`):
-                Whether or not to use sampling ; use greedy decoding otherwise.
-            temperature (`float`, *optional*, defaults to 1.0):
-                The value used to module the next token probabilities.
-            top_k (`int`, *optional*, defaults to 50):
-                The number of highest probability vocabulary tokens to keep for top-k-filtering.
-            top_p (`float`, *optional*, defaults to 1.0):
-                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
-                are kept for generation.
-            pad_token_id (`int`, *optional*):
-                The id of the *padding* token.
-            bos_token_id (`int`, *optional*):
-                The id of the *beginning-of-sequence* token.
-            eos_token_id (`int`, *optional*):
-                The id of the *end-of-sequence* token.
-            num_beams (`int`, *optional*, defaults to 1):
-                Number of beams for beam search. 1 means no beam search.
-            decoder_start_token_id (`int`, *optional*):
-                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which had the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
             trace (`bool`, *optional*, defaults to `True`):
                 Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
                 considerably slower runtime.
             params (`Dict[str, jnp.ndarray]`, *optional*):
                 Optionally the model parameters can be passed. Can be useful for parallelized generation.
-            model_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
-                is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
-                should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part.
+            kwargs:
+                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
         Return:
             [`~utils.ModelOutput`].
         Examples:
+        Greedy decoding, using the default generation configuration and ad hoc modifications:
         ```python
         >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
-        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-        >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
-        >>> input_context = "The dog"
-        >>> # encode input context
-        >>> inputs = tokenizer(input_context, return_tensors="np")
-        >>> # generate candidates using sampling
-        >>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="np").input_ids
+        >>> # Generate up to 30 tokens
+        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
+        >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
+        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
+        ```
+        Multinomial sampling, modifying an existing generation configuration:
+        ```python
+        >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, GenerationConfig
+        >>> import numpy as np
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="np").input_ids
+        >>> # Sample up to 30 tokens
+        >>> generation_config = GenerationConfig.from_pretrained("gpt2")
+        >>> generation_config.max_length = 30
+        >>> generation_config.do_sample = True
+        >>> outputs = model.generate(
+        ...     input_ids, generation_config=generation_config, prng_key=np.asarray([0, 0], dtype=np.uint32)
+        ... )
+        >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
+        ['Today I believe we can finally get a change in that system. The way I saw it was this: a few years ago, this company would not']
+        ```
+        Beam-search decoding, using a freshly initialized generation configuration:
+        ```python
+        >>> from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM, GenerationConfig
+        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+        >>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+        >>> sentence = "Paris is one of the densest populated areas in Europe."
+        >>> input_ids = tokenizer(sentence, return_tensors="np").input_ids
+        >>> generation_config = GenerationConfig(
+        ...     max_length=64,
+        ...     num_beams=5,
+        ...     bos_token_id=0,
+        ...     eos_token_id=0,
+        ...     decoder_start_token_id=58100,
+        ...     pad_token_id=58100,
+        ...     bad_words_ids=[[58100]],
+        ... )
+        >>> outputs = model.generate(input_ids, generation_config=generation_config)
         >>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
+        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
-        # Validate the `.generate()` call
+        # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
+        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+        if generation_config is None:
+            # legacy: users may modify the model configuration to control generation -- update the generation config
+            # model attribute accordingly, if it was created from the model config
+            if self.generation_config._from_model_config:
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    warnings.warn(
+                        "You have modified the pretrained model configuration to control generation. This is a"
+                        " deprecated strategy to control generation and will be removed soon, in a future version."
+                        " Please use a generation configuration file (see"
+                        " https://huggingface.co/docs/transformers/main_classes/text_generation)"
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
         self._validate_model_kwargs(model_kwargs.copy())
         # set init values
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
-        )
         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
-        if pad_token_id is None and eos_token_id is not None:
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
             if model_kwargs.get("attention_mask") is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
                     "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                 )
+            eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            pad_token_id = eos_token_id
+            generation_config.pad_token_id = eos_token_id
-        if decoder_start_token_id is None and self.config.is_encoder_decoder:
+        if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
             raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
         # decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
         if not self.config.is_encoder_decoder and not trace:
-            if pad_token_id is not None and jnp.sum(input_ids[:, -1] == pad_token_id) > 0:
+            if (
+                generation_config.pad_token_id is not None
+                and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
+            ):
                 logger.warning(
                     "A decoder-only architecture is being used, but right-padding was detected! For correct "
                     "generation results, please set `padding_side='left'` when initializing the tokenizer."
@@ -350,71 +392,62 @@
             if model_kwargs.get("encoder_outputs") is None:
                 model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
             # prepare decoder_input_ids for generation
-            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * generation_config.decoder_start_token_id
         # Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
-        if max_length is None and max_new_tokens is None:
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
-                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to "
-                f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
-                "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
-                "using `max_new_tokens` to control the maximum length of the generation.",
+                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
+                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
+                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
-        elif max_length is None and max_new_tokens is not None:
-            max_length = max_new_tokens + input_ids_seq_length
-        elif max_length is not None and max_new_tokens is not None:
+        elif has_default_max_length and generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+        elif not has_default_max_length and generation_config.max_new_tokens is not None:
             raise ValueError(
                 "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
                 " limit to the generated output length. Remove one of those arguments. Please refer to the"
                 " documentation for more information. "
                 "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
             )
-        # default to config if still None
-        max_length = max_length if max_length is not None else self.config.max_length
-        min_length = min_length if min_length is not None else self.config.min_length
-        if min_length is not None and min_length > max_length:
+        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
-                f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
-                f"length ({max_length})"
+                f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
+                f" the maximum length ({generation_config.max_length})"
             )
-        if input_ids_seq_length >= max_length:
+        if input_ids_seq_length >= generation_config.max_length:
             input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
             logger.warning(
                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                f" {max_length}. This can lead to unexpected behavior. You should consider increasing"
-                "`max_new_tokens`."
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing`max_new_tokens`."
             )
-        do_sample = do_sample if do_sample is not None else self.config.do_sample
-        num_beams = num_beams if num_beams is not None else self.config.num_beams
-        if not do_sample and num_beams == 1:
-            logits_processor = self._get_logits_processor(
-                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
-            )
+        logits_processor = self._get_logits_processor(generation_config=generation_config)
+        if not generation_config.do_sample and generation_config.num_beams == 1:
             return self._greedy_search(
                 input_ids,
-                max_length,
-                pad_token_id,
-                eos_token_id,
+                generation_config.max_length,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
                 logits_processor=logits_processor,
                 trace=trace,
                 params=params,
                 model_kwargs=model_kwargs,
             )
-        elif do_sample and num_beams == 1:
-            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
-            logits_processor = self._get_logits_processor(
-                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
-            )
+        elif generation_config.do_sample and generation_config.num_beams == 1:
+            logits_warper = self._get_logits_warper(generation_config=generation_config)
             return self._sample(
                 input_ids,
-                max_length,
-                pad_token_id,
-                eos_token_id,
+                generation_config.max_length,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
                 prng_key,
                 logits_warper=logits_warper,
                 logits_processor=logits_processor,
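For the length handling in this hunk, the only arithmetic is converting a relative `max_new_tokens` into an absolute `max_length` by adding the prompt length; a tiny illustration with made-up numbers:

```python
# Made-up values, mirroring the branch where only `max_new_tokens` was provided
input_ids_seq_length = 5                            # tokenized prompt length
max_new_tokens = 30                                 # requested freshly generated tokens
max_length = max_new_tokens + input_ids_seq_length  # 35 tokens total, prompt included
assert max_length == 35
```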
@@ -422,31 +455,27 @@
                 params=params,
                 model_kwargs=model_kwargs,
             )
-        elif not do_sample and num_beams > 1:
+        elif not generation_config.do_sample and generation_config.num_beams > 1:
             # broadcast input_ids & encoder_outputs
-            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
+            input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
             if "encoder_outputs" in model_kwargs:
                 model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
+                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
                 )
             if "attention_mask" in model_kwargs:
                 model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                    model_kwargs["attention_mask"], num_beams=num_beams
+                    model_kwargs["attention_mask"], num_beams=generation_config.num_beams
                 )
-            logits_processor = self._get_logits_processor(
-                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
-            )
             return self._beam_search(
                 input_ids,
-                max_length,
-                pad_token_id,
-                eos_token_id,
-                length_penalty=length_penalty,
-                early_stopping=early_stopping,
+                generation_config.max_length,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
+                length_penalty=generation_config.length_penalty,
+                early_stopping=generation_config.early_stopping,
                 logits_processor=logits_processor,
                 trace=trace,
                 params=params,
@@ -455,67 +484,44 @@
         else:
             raise NotImplementedError("`Beam sampling is currently not implemented.")
-    def _get_logits_warper(
-        self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
-    ) -> FlaxLogitsProcessorList:
+    def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
         """
         This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
         instances used for multinomial sampling.
         """
-        # init warp parameters
-        top_k = top_k if top_k is not None else self.config.top_k
-        top_p = top_p if top_p is not None else self.config.top_p
-        temperature = temperature if temperature is not None else self.config.temperature
-        # instantiate warpers list
         warpers = FlaxLogitsProcessorList()
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        if temperature is not None and temperature != 1.0:
-            warpers.append(FlaxTemperatureLogitsWarper(temperature))
-        if top_k is not None and top_k != 0:
-            warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
-        if top_p is not None and top_p < 1.0:
-            warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
+        if generation_config.temperature is not None and generation_config.temperature != 1.0:
+            warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature))
+        if generation_config.top_k is not None and generation_config.top_k != 0:
+            warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
+        if generation_config.top_p is not None and generation_config.top_p < 1.0:
+            warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
         return warpers
-    def _get_logits_processor(
-        self,
-        no_repeat_ngram_size: int,
-        min_length: int,
-        max_length: int,
-        eos_token_id: int,
-        forced_bos_token_id: int,
-        forced_eos_token_id: int,
-    ) -> FlaxLogitsProcessorList:
+    def _get_logits_processor(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
         """
         This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
         instances used to modify the scores of the language model head.
         """
         processors = FlaxLogitsProcessorList()
-        # init warp parameters
-        no_repeat_ngram_size = (
-            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
-        )
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        forced_bos_token_id = (
-            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
-        )
-        forced_eos_token_id = (
-            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
-        )
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        if min_length is not None and eos_token_id is not None and min_length > -1:
-            processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
-        if forced_bos_token_id is not None:
-            processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
-        if forced_eos_token_id is not None:
-            processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
+        if (
+            generation_config.min_length is not None
+            and generation_config.eos_token_id is not None
+            and generation_config.min_length > -1
+        ):
+            processors.append(
+                FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)
+            )
+        if generation_config.forced_bos_token_id is not None:
+            processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
+        if generation_config.forced_eos_token_id is not None:
+            processors.append(
+                FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
+            )
         return processors
     def _greedy_search(
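To see what the refactored `_get_logits_warper` now produces, here is a standalone sketch that builds the same kind of warper list directly from a `GenerationConfig` and applies it to dummy scores; the config values and tensor shapes are illustrative only:

```python
import jax.numpy as jnp
from transformers import GenerationConfig
from transformers.generation import (
    FlaxLogitsProcessorList,
    FlaxTemperatureLogitsWarper,
    FlaxTopKLogitsWarper,
    FlaxTopPLogitsWarper,
)

config = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=0.95)

# Same pattern as the method above: only add a warper when the config value is "active"
warpers = FlaxLogitsProcessorList()
if config.temperature is not None and config.temperature != 1.0:
    warpers.append(FlaxTemperatureLogitsWarper(config.temperature))
if config.top_k is not None and config.top_k != 0:
    warpers.append(FlaxTopKLogitsWarper(top_k=config.top_k, min_tokens_to_keep=1))
if config.top_p is not None and config.top_p < 1.0:
    warpers.append(FlaxTopPLogitsWarper(top_p=config.top_p, min_tokens_to_keep=1))

# Flax logits warpers are called with (input_ids, scores, cur_len) and return rescaled scores
input_ids = jnp.ones((1, 5), dtype="i4")
scores = jnp.zeros((1, 1000))
warped_scores = warpers(input_ids, scores, cur_len=5)
```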
@@ -530,9 +536,9 @@
         model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         # init values
-        max_length = max_length if max_length is not None else self.config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         batch_size, cur_len = input_ids.shape
@@ -618,9 +624,9 @@
         model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         # init values
-        max_length = max_length if max_length is not None else self.config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
         batch_size, cur_len = input_ids.shape
@@ -716,7 +722,7 @@
     ):
         """
         This beam search function is heavily inspired by Flax's official example:
-        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
+        https://github.com/google/flax/blob/main/examples/wmt/decode.py
         """
         def flatten_beam_dim(tensor):
@@ -751,11 +757,11 @@
             return jax.tree_util.tree_map(gather_fn, nested)
         # init values
-        max_length = max_length if max_length is not None else self.config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
+        early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
         batch_size, num_beams, cur_len = input_ids.shape
...
@@ -33,7 +33,7 @@ from jax.random import PRNGKey
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation import FlaxGenerationMixin
+from .generation import FlaxGenerationMixin, GenerationConfig
 from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
 from .utils import (
     FLAX_WEIGHTS_INDEX_NAME,
@@ -199,6 +199,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         self.key = PRNGKey(seed)
         self.dtype = dtype
         self.input_shape = input_shape
+        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
         # To check if the model was intialized automatically.
         self._is_initialized = _do_init
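The constructor change above derives the model's initial `generation_config` from its model config. A rough sketch of what `GenerationConfig.from_model_config` does (the checkpoint name is just an example):

```python
from transformers import AutoConfig, GenerationConfig

# Copies the generation-related fields (max_length, bos/eos/pad token ids, ...) out of a
# model config into a standalone GenerationConfig, marking it as config-derived.
model_config = AutoConfig.from_pretrained("gpt2")
generation_config = GenerationConfig.from_model_config(model_config)
print(generation_config.bos_token_id, generation_config.eos_token_id)
```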
@@ -467,6 +468,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         # the state dict is unflattened to the match the format of model.params
         return unflatten_dict(state_sharded_dict, sep="/")
+    def can_generate(self) -> bool:
+        """
+        Returns whether this model can generate sequences with `.generate()`. Returns:
+            `bool`: Whether this model can generate sequences with `.generate()`.
+        """
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
+        if "GenerationMixin" in str(self.prepare_inputs_for_generation):
+            return False
+        return True
     @classmethod
     def from_pretrained(
         cls,
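The `can_generate` check above relies on the string form of a bound method containing the name of the class that defined it. A minimal standalone sketch of that detection trick, with hypothetical classes standing in for `FlaxGenerationMixin` and a concrete model:

```python
class Mixin:
    def prepare_inputs_for_generation(self, *args, **kwargs):
        raise NotImplementedError


class ModelWithGeneration(Mixin):
    def prepare_inputs_for_generation(self, *args, **kwargs):
        return {}


class ModelWithoutGeneration(Mixin):
    pass


# str(bound_method) looks like "<bound method Mixin.prepare_inputs_for_generation of ...>",
# so the defining class name reveals whether the subclass overrode the method.
print("Mixin" in str(ModelWithoutGeneration().prepare_inputs_for_generation))  # True  -> cannot generate
print("Mixin" in str(ModelWithGeneration().prepare_inputs_for_generation))     # False -> can generate
```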
@@ -940,6 +951,29 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
             )
+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _from_auto=from_auto_class,
+                    _from_pipeline=from_pipeline,
+                    **kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
         if _do_init:
             # set correct parameters
             model.params = unflatten_dict(state)
...
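The fallback above means `from_pretrained` only replaces the config-derived defaults when a `generation_config.json` is actually found next to the weights. A sketch of the round trip, using a hypothetical local directory:

```python
from transformers import GenerationConfig

save_dir = "./my-flax-model"  # hypothetical local checkpoint directory

# Writes generation_config.json into save_dir (the directory is created if needed)
generation_config = GenerationConfig(max_new_tokens=64, do_sample=True, top_p=0.9)
generation_config.save_pretrained(save_dir)

# This is essentially what the `try` block above does for a model with generation support;
# if the file were missing, an OSError would be raised and the config-derived default kept.
reloaded = GenerationConfig.from_pretrained(save_dir)
assert reloaded.top_p == 0.9
```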