"scripts/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "7d121448b21ca05a2700e8e74f3fced179061af0"
Unverified commit bc53fc62, authored by Joao Gante, committed by GitHub

Generate: FLAX uses `GenerationConfig` as the basis for `.generate()` parametrization (#21007)

parent 4f1c9d16
@@ -15,6 +15,7 @@
# limitations under the License.
import copy
import inspect
import warnings
from functools import partial
@@ -33,6 +34,7 @@ from ..models.auto import (
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ModelOutput, logging
from .configuration_utils import GenerationConfig
from .flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
@@ -136,6 +138,11 @@ class FlaxGenerationMixin:
`do_sample=False`.
"""
def prepare_inputs_for_generation(self, *args, **kwargs):
raise NotImplementedError(
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
)
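As a point of reference, here is a sketch of what a concrete override can look like, loosely following the decoder-only Flax models in this repo (e.g. FlaxGPT2); the cache handling and returned keys vary per model, so treat this as illustrative rather than canonical:

```python
import jax.numpy as jnp
from jax import lax

def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask=None):
    # Size the cache for the full generation up front, so shapes stay static under jit.
    batch_size, seq_length = input_ids.shape
    past_key_values = self.init_cache(batch_size, max_length)

    # Extend the attention mask to max_length; position ids are derived from the mask.
    extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    if attention_mask is not None:
        position_ids = attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
    else:
        position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

    return {
        "past_key_values": past_key_values,
        "attention_mask": extended_attention_mask,
        "position_ids": position_ids,
    }
```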
@staticmethod
def _run_loop_in_debug(cond_fn, body_fn, init_state):
"""
@@ -171,7 +178,7 @@ class FlaxGenerationMixin:
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if not hasattr(self, "prepare_inputs_for_generation"):
if not self.can_generate():
generate_compatible_mappings = [
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
@@ -211,27 +218,11 @@ class FlaxGenerationMixin:
def generate(
self,
input_ids: jnp.ndarray,
max_length: Optional[int] = None,
max_new_tokens: Optional[int] = None,
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
do_sample: Optional[bool] = None,
generation_config: Optional[GenerationConfig] = None,
prng_key: Optional[jnp.ndarray] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
num_beams: Optional[int] = None,
no_repeat_ngram_size: Optional[int] = None,
min_length: Optional[int] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
early_stopping: Optional[bool] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
**model_kwargs,
**kwargs,
):
r"""
Generates sequences of token ids for models with a language modeling head. The method supports the following
@@ -246,100 +237,151 @@ class FlaxGenerationMixin:
<Tip warning={true}>
Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
defined in the model's config (`config.json`) which in turn defaults to the
[`~modeling_utils.PretrainedConfig`] of the model.
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
</Tip>
For a complete overview of generate, check the [following
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
</Tip>
Parameters:
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
max_length (`int`, *optional*, defaults to `model.config.max_length`):
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
`max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
the prompt.
max_new_tokens (`int`, *optional*):
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling; use greedy decoding otherwise.
temperature (`float`, *optional*, defaults to 1.0):
The value used to modulate the next token probabilities.
top_k (`int`, *optional*, defaults to 50):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`, *optional*, defaults to 1.0):
If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
are kept for generation.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
trace (`bool`, *optional*, defaults to `True`):
Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
considerably slower runtime.
params (`Dict[str, jnp.ndarray]`, *optional*):
Optionally the model parameters can be passed. Can be useful for parallelized generation.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part.
kwargs:
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`].
Examples:
Greedy decoding, using the default generation configuration and ad hoc modifications:
```python
>>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
>>> input_context = "The dog"
>>> # encode input context
>>> inputs = tokenizer(input_context, return_tensors="np")
>>> # generate candidates using sampling
>>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="np").input_ids
>>> # Generate up to 30 tokens
>>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
>>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
```
Multinomial sampling, modifying an existing generation configuration:
```python
>>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, GenerationConfig
>>> import numpy as np
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="np").input_ids
>>> # Sample up to 30 tokens
>>> generation_config = GenerationConfig.from_pretrained("gpt2")
>>> generation_config.max_length = 30
>>> generation_config.do_sample = True
>>> outputs = model.generate(
... input_ids, generation_config=generation_config, prng_key=np.asarray([0, 0], dtype=np.uint32)
... )
>>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
['Today I believe we can finally get a change in that system. The way I saw it was this: a few years ago, this company would not']
```
Beam-search decoding, using a freshly initialized generation configuration:
```python
>>> from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM, GenerationConfig
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> sentence = "Paris is one of the densest populated areas in Europe."
>>> input_ids = tokenizer(sentence, return_tensors="np").input_ids
>>> generation_config = GenerationConfig(
... max_length=64,
... num_beams=5,
... bos_token_id=0,
... eos_token_id=0,
... decoder_start_token_id=58100,
... pad_token_id=58100,
... bad_words_ids=[[58100]],
... )
>>> outputs = model.generate(input_ids, generation_config=generation_config)
>>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# Validate the `.generate()` call
# Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation)"
)
self.generation_config = new_generation_config
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
self._validate_model_kwargs(model_kwargs.copy())
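The contract of `generation_config.update(**kwargs)` is the crux here: keys that match `GenerationConfig` attributes are consumed as ad hoc overrides, and everything else is handed back to be treated as a model kwarg. A minimal sketch of that split (toy values):

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(max_length=20)
unused = generation_config.update(max_length=30, do_sample=True, attention_mask=None)

assert generation_config.max_length == 30  # known attribute: consumed as an override
assert generation_config.do_sample is True
assert "attention_mask" in unused  # unknown key: falls through as a model kwarg
```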
# set init values
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
if pad_token_id is None and eos_token_id is not None:
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask") is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
generation_config.pad_token_id = eos_token_id
if decoder_start_token_id is None and self.config.is_encoder_decoder:
if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
# decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
if not self.config.is_encoder_decoder and not trace:
if pad_token_id is not None and jnp.sum(input_ids[:, -1] == pad_token_id) > 0:
if (
generation_config.pad_token_id is not None
and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
@@ -350,71 +392,62 @@ class FlaxGenerationMixin:
if model_kwargs.get("encoder_outputs") is None:
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
# prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * generation_config.decoder_start_token_id
# Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
if max_length is None and max_new_tokens is None:
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
"Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to "
f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
"deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
"using `max_new_tokens` to control the maximum length of the generation.",
"Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
" config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif max_length is None and max_new_tokens is not None:
max_length = max_new_tokens + input_ids_seq_length
elif max_length is not None and max_new_tokens is not None:
elif has_default_max_length and generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
elif not has_default_max_length and generation_config.max_new_tokens is not None:
raise ValueError(
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
" limit to the generated output length. Remove one of those arguments. Please refer to the"
" documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
# default to config if still None
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
if min_length is not None and min_length > max_length:
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= max_length:
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {max_length}. This can lead to unexpected behavior. You should consider increasing"
"`max_new_tokens`."
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing`max_new_tokens`."
)
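The relationship between the two length controls is plain arithmetic: `max_length` caps prompt plus generation, while `max_new_tokens` caps only the generated part, hence the `max_new_tokens + input_ids_seq_length` conversion above. For example:

```python
input_ids_seq_length = 6  # prompt length in tokens
max_new_tokens = 30
max_length = max_new_tokens + input_ids_seq_length  # 36, the cap the decoding loop runs against
```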
do_sample = do_sample if do_sample is not None else self.config.do_sample
num_beams = num_beams if num_beams is not None else self.config.num_beams
logits_processor = self._get_logits_processor(generation_config=generation_config)
if not do_sample and num_beams == 1:
logits_processor = self._get_logits_processor(
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
)
if not generation_config.do_sample and generation_config.num_beams == 1:
return self._greedy_search(
input_ids,
max_length,
pad_token_id,
eos_token_id,
generation_config.max_length,
generation_config.pad_token_id,
generation_config.eos_token_id,
logits_processor=logits_processor,
trace=trace,
params=params,
model_kwargs=model_kwargs,
)
elif do_sample and num_beams == 1:
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
logits_processor = self._get_logits_processor(
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
)
elif generation_config.do_sample and generation_config.num_beams == 1:
logits_warper = self._get_logits_warper(generation_config=generation_config)
return self._sample(
input_ids,
max_length,
pad_token_id,
eos_token_id,
generation_config.max_length,
generation_config.pad_token_id,
generation_config.eos_token_id,
prng_key,
logits_warper=logits_warper,
logits_processor=logits_processor,
@@ -422,31 +455,27 @@ class FlaxGenerationMixin:
params=params,
model_kwargs=model_kwargs,
)
elif not do_sample and num_beams > 1:
elif not generation_config.do_sample and generation_config.num_beams > 1:
# broadcast input_ids & encoder_outputs
input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
if "encoder_outputs" in model_kwargs:
model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = self._expand_to_num_beams(
model_kwargs["attention_mask"], num_beams=num_beams
)
logits_processor = self._get_logits_processor(
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
model_kwargs["attention_mask"], num_beams=generation_config.num_beams
)
return self._beam_search(
input_ids,
max_length,
pad_token_id,
eos_token_id,
length_penalty=length_penalty,
early_stopping=early_stopping,
generation_config.max_length,
generation_config.pad_token_id,
generation_config.eos_token_id,
length_penalty=generation_config.length_penalty,
early_stopping=generation_config.early_stopping,
logits_processor=logits_processor,
trace=trace,
params=params,
@@ -455,67 +484,44 @@ class FlaxGenerationMixin:
else:
raise NotImplementedError("`Beam sampling is currently not implemented.")
def _get_logits_warper(
self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> FlaxLogitsProcessorList:
def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
"""
This method returns a [`FlaxLogitsProcessorList`] object that contains all relevant [`FlaxLogitsWarper`]
instances used for multinomial sampling.
"""
# init warp parameters
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
temperature = temperature if temperature is not None else self.config.temperature
# instantiate warpers list
warpers = FlaxLogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if temperature is not None and temperature != 1.0:
warpers.append(FlaxTemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
if top_p is not None and top_p < 1.0:
warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
return warpers
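All Flax logits processors and warpers share the `(input_ids, scores, cur_len)` call signature, so the list can be applied as a pipeline. A small sketch of one warper acting on a logits row:

```python
import jax.numpy as jnp
from transformers import FlaxTopKLogitsWarper

warper = FlaxTopKLogitsWarper(top_k=2, min_tokens_to_keep=1)
scores = jnp.array([[1.0, 3.0, 2.0, 0.5]])  # one batch row over a 4-token vocabulary
input_ids = jnp.zeros((1, 1), dtype="i4")  # unused by top-k, required by the interface
warped = warper(input_ids, scores, cur_len=1)
# everything outside the two highest logits is pushed to -inf, so sampling cannot pick it
```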
def _get_logits_processor(
self,
no_repeat_ngram_size: int,
min_length: int,
max_length: int,
eos_token_id: int,
forced_bos_token_id: int,
forced_eos_token_id: int,
) -> FlaxLogitsProcessorList:
def _get_logits_processor(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
"""
This method returns a [`FlaxLogitsProcessorList`] object that contains all relevant [`FlaxLogitsProcessor`]
instances used to modify the scores of the language model head.
"""
processors = FlaxLogitsProcessorList()
# init warp parameters
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
forced_bos_token_id = (
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
if (
generation_config.min_length is not None
and generation_config.eos_token_id is not None
and generation_config.min_length > -1
):
processors.append(
FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)
)
forced_eos_token_id = (
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
if generation_config.forced_bos_token_id is not None:
processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
if generation_config.forced_eos_token_id is not None:
processors.append(
FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
)
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if min_length is not None and eos_token_id is not None and min_length > -1:
processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
if forced_bos_token_id is not None:
processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
return processors
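For instance, the min-length processor keeps EOS unreachable until the minimum length is reached. A sketch of the effect, assuming token id 3 is EOS in a 4-token vocabulary:

```python
import jax.numpy as jnp
from transformers import FlaxMinLengthLogitsProcessor

processor = FlaxMinLengthLogitsProcessor(min_length=5, eos_token_id=3)
scores = jnp.zeros((1, 4))  # uniform logits
early = processor(jnp.zeros((1, 2), dtype="i4"), scores, cur_len=2)
late = processor(jnp.zeros((1, 6), dtype="i4"), scores, cur_len=6)
# early[0, 3] is -inf (EOS blocked before min_length); late is returned unchanged
```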
def _greedy_search(
@@ -530,9 +536,9 @@ class FlaxGenerationMixin:
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
):
# init values
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
batch_size, cur_len = input_ids.shape
@@ -618,9 +624,9 @@ class FlaxGenerationMixin:
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
):
# init values
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
batch_size, cur_len = input_ids.shape
@@ -716,7 +722,7 @@ class FlaxGenerationMixin:
):
"""
This beam search function is heavily inspired by Flax's official example:
https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
https://github.com/google/flax/blob/main/examples/wmt/decode.py
"""
def flatten_beam_dim(tensor):
@@ -751,11 +757,11 @@ class FlaxGenerationMixin:
return jax.tree_util.tree_map(gather_fn, nested)
# init values
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
batch_size, num_beams, cur_len = input_ids.shape
......
@@ -33,7 +33,7 @@ from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import FlaxGenerationMixin
from .generation import FlaxGenerationMixin, GenerationConfig
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import (
FLAX_WEIGHTS_INDEX_NAME,
@@ -199,6 +199,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
self.key = PRNGKey(seed)
self.dtype = dtype
self.input_shape = input_shape
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
# To check if the model was initialized automatically.
self._is_initialized = _do_init
@@ -467,6 +468,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# the state dict is unflattened to match the format of model.params
return unflatten_dict(state_sharded_dict, sep="/")
def can_generate(self) -> bool:
"""
Returns whether this model can generate sequences with `.generate()`.
Returns:
    `bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overridden, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation):
return False
return True
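The string check works because the repr of a bound method carries the `__qualname__` of the class that defined it, so a model that never overrides the mixin stub still shows the mixin's name. A self-contained sketch of the same idea with hypothetical classes:

```python
class Mixin:
    def prepare_inputs_for_generation(self, *args, **kwargs):
        raise NotImplementedError

class PlainModel(Mixin):  # no override -> cannot generate
    pass

class LMModel(Mixin):  # overrides the stub -> can generate
    def prepare_inputs_for_generation(self, *args, **kwargs):
        return {}

print("Mixin" in str(PlainModel().prepare_inputs_for_generation))  # True
print("Mixin" in str(LMModel().prepare_inputs_for_generation))  # False
```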
@classmethod
def from_pretrained(
cls,
@@ -940,6 +951,29 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
)
# If it is a model with generation capabilities, attempt to load the generation config
if model.can_generate():
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
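End to end, the fallback behaves as follows: a `generation_config.json` shipped with the checkpoint wins, otherwise the config built in `__init__` from the model config stays in place. A short usage sketch:

```python
from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
# Populated from gpt2's generation_config.json if the repo ships one,
# else from GenerationConfig.from_model_config(model.config).
print(model.generation_config)
```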
if _do_init:
# set correct parameters
model.params = unflatten_dict(state)
......