Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6134b9b4
Unverified
Commit
6134b9b4
authored
Jun 15, 2023
by
Yih-Dar
Committed by
GitHub
Jun 15, 2023
Browse files
Make `can_generate` as class method (#24299)
fix Co-authored-by:
ydshieh
<
ydshieh@users.noreply.github.com
>
parent
e45bc143
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
6 deletions
+9
-6
src/transformers/modeling_flax_utils.py
src/transformers/modeling_flax_utils.py
+3
-2
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+3
-2
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+3
-2
No files found.
src/transformers/modeling_flax_utils.py
View file @
6134b9b4
...
@@ -468,13 +468,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
...
@@ -468,13 +468,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# the state dict is unflattened to the match the format of model.params
# the state dict is unflattened to the match the format of model.params
return
unflatten_dict
(
state_sharded_dict
,
sep
=
"/"
)
return
unflatten_dict
(
state_sharded_dict
,
sep
=
"/"
)
def
can_generate
(
self
)
->
bool
:
@
classmethod
def
can_generate
(
cls
)
->
bool
:
"""
"""
Returns whether this model can generate sequences with `.generate()`. Returns:
Returns whether this model can generate sequences with `.generate()`. Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
`bool`: Whether this model can generate sequences with `.generate()`.
"""
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if
"GenerationMixin"
in
str
(
self
.
prepare_inputs_for_generation
.
__func__
):
if
"GenerationMixin"
in
str
(
cls
.
prepare_inputs_for_generation
):
return
False
return
False
return
True
return
True
...
...
src/transformers/modeling_tf_utils.py
View file @
6134b9b4
...
@@ -1328,7 +1328,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
...
@@ -1328,7 +1328,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
pass
# Layers may not have the same dimensions
pass
# Layers may not have the same dimensions
return
output
return
output
def
can_generate
(
self
)
->
bool
:
@
classmethod
def
can_generate
(
cls
)
->
bool
:
"""
"""
Returns whether this model can generate sequences with `.generate()`.
Returns whether this model can generate sequences with `.generate()`.
...
@@ -1336,7 +1337,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
...
@@ -1336,7 +1337,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
`bool`: Whether this model can generate sequences with `.generate()`.
`bool`: Whether this model can generate sequences with `.generate()`.
"""
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if
"GenerationMixin"
in
str
(
self
.
prepare_inputs_for_generation
.
__func__
):
if
"GenerationMixin"
in
str
(
cls
.
prepare_inputs_for_generation
):
return
False
return
False
return
True
return
True
...
...
src/transformers/modeling_utils.py
View file @
6134b9b4
...
@@ -1174,7 +1174,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1174,7 +1174,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
"""
return
getattr
(
self
,
self
.
base_model_prefix
,
self
)
return
getattr
(
self
,
self
.
base_model_prefix
,
self
)
def
can_generate
(
self
)
->
bool
:
@
classmethod
def
can_generate
(
cls
)
->
bool
:
"""
"""
Returns whether this model can generate sequences with `.generate()`.
Returns whether this model can generate sequences with `.generate()`.
...
@@ -1182,7 +1183,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1182,7 +1183,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
`bool`: Whether this model can generate sequences with `.generate()`.
`bool`: Whether this model can generate sequences with `.generate()`.
"""
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if
"GenerationMixin"
in
str
(
self
.
prepare_inputs_for_generation
.
__func__
):
if
"GenerationMixin"
in
str
(
cls
.
prepare_inputs_for_generation
):
return
False
return
False
return
True
return
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment