Unverified Commit bc53fc62 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: FLAX uses `GenerationConfig` as the basis for `.generate()` parametrization (#21007)

parent 4f1c9d16
This diff is collapsed.
......@@ -33,7 +33,7 @@ from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import FlaxGenerationMixin
from .generation import FlaxGenerationMixin, GenerationConfig
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import (
FLAX_WEIGHTS_INDEX_NAME,
......@@ -199,6 +199,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
self.key = PRNGKey(seed)
self.dtype = dtype
self.input_shape = input_shape
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
# To check if the model was initialized automatically.
self._is_initialized = _do_init
......@@ -467,6 +468,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# the state dict is unflattened to match the format of model.params
return unflatten_dict(state_sharded_dict, sep="/")
def can_generate(self) -> bool:
    """
    Report whether sequences can be generated from this model with `.generate()`.

    Returns:
        `bool`: `False` if the string form of `self.prepare_inputs_for_generation` contains
        "GenerationMixin", `True` otherwise.
    """
    # Generation requires `prepare_inputs_for_generation` to be overwritten. The check is textual:
    # if the attribute's string form still mentions "GenerationMixin", it is treated as the
    # inherited (non-overwritten) version and the model is reported as unable to generate.
    method_repr = str(self.prepare_inputs_for_generation)
    return "GenerationMixin" not in method_repr
@classmethod
def from_pretrained(
cls,
......@@ -940,6 +951,29 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
)
# If it is a model with generation capabilities, attempt to load the generation config
if model.can_generate():
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
if _do_init:
# set correct parameters
model.params = unflatten_dict(state)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment