Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bc53fc62
Unverified
Commit
bc53fc62
authored
Jan 05, 2023
by
Joao Gante
Committed by
GitHub
Jan 05, 2023
Browse files
Generate: FLAX uses `GenerationConfig` as the basis for `.generate()` parametrization (#21007)
parent
4f1c9d16
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
218 additions
and
178 deletions
+218
-178
src/transformers/generation/flax_utils.py
src/transformers/generation/flax_utils.py
+183
-177
src/transformers/modeling_flax_utils.py
src/transformers/modeling_flax_utils.py
+35
-1
No files found.
src/transformers/generation/flax_utils.py
View file @
bc53fc62
This diff is collapsed.
Click to expand it.
src/transformers/modeling_flax_utils.py
View file @
bc53fc62
...
...
@@ -33,7 +33,7 @@ from jax.random import PRNGKey
from
.configuration_utils
import
PretrainedConfig
from
.dynamic_module_utils
import
custom_object_save
from
.generation
import
FlaxGenerationMixin
from
.generation
import
FlaxGenerationMixin
,
GenerationConfig
from
.modeling_flax_pytorch_utils
import
load_pytorch_checkpoint_in_flax_state_dict
from
.utils
import
(
FLAX_WEIGHTS_INDEX_NAME
,
...
...
@@ -199,6 +199,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
self
.
key
=
PRNGKey
(
seed
)
self
.
dtype
=
dtype
self
.
input_shape
=
input_shape
self
.
generation_config
=
GenerationConfig
.
from_model_config
(
config
)
if
self
.
can_generate
()
else
None
# To check if the model was intialized automatically.
self
.
_is_initialized
=
_do_init
...
...
@@ -467,6 +468,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# the state dict is unflattened to the match the format of model.params
return
unflatten_dict
(
state_sharded_dict
,
sep
=
"/"
)
def can_generate(self) -> bool:
    """
    Returns whether this model can generate sequences with `.generate()`. Returns:
        `bool`: Whether this model can generate sequences with `.generate()`.
    """
    # Generation is only possible when the concrete model class overrides
    # `prepare_inputs_for_generation`; if the attribute still resolves to the
    # base mixin's placeholder, its repr contains "GenerationMixin".
    return "GenerationMixin" not in str(self.prepare_inputs_for_generation)
@
classmethod
def
from_pretrained
(
cls
,
...
...
@@ -940,6 +951,29 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
)
# If it is a model with generation capabilities, attempt to load the generation config
if
model
.
can_generate
():
try
:
model
.
generation_config
=
GenerationConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
resume_download
=
resume_download
,
proxies
=
proxies
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
revision
=
revision
,
subfolder
=
subfolder
,
_from_auto
=
from_auto_class
,
_from_pipeline
=
from_pipeline
,
**
kwargs
,
)
except
OSError
:
logger
.
info
(
"Generation config file not found, using a generation config created from the model config."
)
pass
if
_do_init
:
# set correct parameters
model
.
params
=
unflatten_dict
(
state
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment