Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
92dfceb1
Unverified
Commit
92dfceb1
authored
Feb 27, 2023
by
Joao Gante
Committed by
GitHub
Feb 27, 2023
Browse files
Inheritance-based framework detection (#21784)
parent
7811bf7e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
60 additions
and
35 deletions
+60
-35
src/transformers/pipelines/base.py
src/transformers/pipelines/base.py
+3
-2
src/transformers/utils/generic.py
src/transformers/utils/generic.py
+15
-11
tests/utils/test_file_utils.py
tests/utils/test_file_utils.py
+42
-22
No files found.
src/transformers/pipelines/base.py
View file @
92dfceb1
...
...
@@ -15,6 +15,7 @@
import
collections
import
csv
import
importlib
import
inspect
import
json
import
os
import
pickle
...
...
@@ -269,7 +270,7 @@ def infer_framework_load_model(
if
isinstance
(
model
,
str
):
raise
ValueError
(
f
"Could not load model
{
model
}
with any of the following classes:
{
class_tuple
}
."
)
framework
=
"tf"
if
model
.
__class__
.
__name__
.
startswith
(
"TF"
)
else
"pt"
framework
=
"tf"
if
"keras.engine.training.Model"
in
str
(
inspect
.
getmro
(
model
.
__class__
)
)
else
"pt"
return
framework
,
model
...
...
@@ -342,7 +343,7 @@ def get_framework(model, revision: Optional[str] = None):
except
OSError
:
model
=
TFAutoModel
.
from_pretrained
(
model
,
revision
=
revision
)
framework
=
"tf"
if
model
.
__class__
.
__name__
.
startswith
(
"TF"
)
else
"pt"
framework
=
"tf"
if
"keras.engine.training.Model"
in
str
(
inspect
.
getmro
(
model
.
__class__
)
)
else
"pt"
return
framework
...
...
src/transformers/utils/generic.py
View file @
92dfceb1
...
...
@@ -366,13 +366,14 @@ def can_return_loss(model_class):
Args:
model_class (`type`): The class of the model.
"""
model_name
=
model_class
.
__name__
if
model_name
.
startswith
(
"TF"
):
signature
=
inspect
.
signature
(
model_class
.
call
)
elif
model_name
.
startswith
(
"Flax"
):
signature
=
inspect
.
signature
(
model_class
.
__call__
)
base_classes
=
str
(
inspect
.
getmro
(
model_class
))
if
"keras.engine.training.Model"
in
base_classes
:
signature
=
inspect
.
signature
(
model_class
.
call
)
# TensorFlow models
elif
"torch.nn.modules.module.Module"
in
base_classes
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
# PyTorch models
else
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
signature
=
inspect
.
signature
(
model_class
.
__call__
)
# Flax models
for
p
in
signature
.
parameters
:
if
p
==
"return_loss"
and
signature
.
parameters
[
p
].
default
is
True
:
...
...
@@ -389,12 +390,15 @@ def find_labels(model_class):
model_class (`type`): The class of the model.
"""
model_name
=
model_class
.
__name__
if
model_name
.
startswith
(
"TF"
):
signature
=
inspect
.
signature
(
model_class
.
call
)
elif
model_name
.
startswith
(
"Flax"
):
signature
=
inspect
.
signature
(
model_class
.
__call__
)
base_classes
=
str
(
inspect
.
getmro
(
model_class
))
if
"keras.engine.training.Model"
in
base_classes
:
signature
=
inspect
.
signature
(
model_class
.
call
)
# TensorFlow models
elif
"torch.nn.modules.module.Module"
in
base_classes
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
# PyTorch models
else
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
signature
=
inspect
.
signature
(
model_class
.
__call__
)
# Flax models
if
"QuestionAnswering"
in
model_name
:
return
[
p
for
p
in
signature
.
parameters
if
"label"
in
p
or
p
in
(
"start_positions"
,
"end_positions"
)]
else
:
...
...
tests/utils/test_file_utils.py
View file @
92dfceb1
...
...
@@ -21,10 +21,20 @@ import transformers
# Try to import everything from transformers to ensure every object can be loaded.
from
transformers
import
*
# noqa F406
from
transformers.testing_utils
import
DUMMY_UNKNOWN_IDENTIFIER
from
transformers.testing_utils
import
DUMMY_UNKNOWN_IDENTIFIER
,
require_flax
,
require_tf
,
require_torch
from
transformers.utils
import
ContextManagers
,
find_labels
,
is_flax_available
,
is_tf_available
,
is_torch_available
if
is_torch_available
():
from
transformers
import
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
if
is_tf_available
():
from
transformers
import
TFBertForPreTraining
,
TFBertForQuestionAnswering
,
TFBertForSequenceClassification
if
is_flax_available
():
from
transformers
import
FlaxBertForPreTraining
,
FlaxBertForQuestionAnswering
,
FlaxBertForSequenceClassification
MODEL_ID
=
DUMMY_UNKNOWN_IDENTIFIER
# An actual model hosted on huggingface.co
...
...
@@ -85,29 +95,39 @@ class GenericUtilTests(unittest.TestCase):
# The output should be wrapped with an English and French welcome and goodbye
self
.
assertEqual
(
mock_stdout
.
getvalue
(),
"Bonjour!
\n
Welcome!
\n
Transformers are awesome!
\n
Bye!
\n
Au revoir!
\n
"
)
def
test_find_labels
(
self
):
if
is_torch_available
():
from
transformers
import
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
@
require_torch
def
test_find_labels_pt
(
self
):
self
.
assertEqual
(
find_labels
(
BertForSequenceClassification
),
[
"labels"
])
self
.
assertEqual
(
find_labels
(
BertForPreTraining
),
[
"labels"
,
"next_sentence_label"
])
self
.
assertEqual
(
find_labels
(
BertForQuestionAnswering
),
[
"start_positions"
,
"end_positions"
])
if
is_tf_available
():
from
transformers
import
TFBertForPreTraining
,
TFBertForQuestionAnswering
,
TFBertForSequenceClassification
# find_labels works regardless of the class name (it detects the framework through inheritance)
class
DummyModel
(
BertForSequenceClassification
):
pass
self
.
assertEqual
(
find_labels
(
DummyModel
),
[
"labels"
])
@
require_tf
def
test_find_labels_tf
(
self
):
self
.
assertEqual
(
find_labels
(
TFBertForSequenceClassification
),
[
"labels"
])
self
.
assertEqual
(
find_labels
(
TFBertForPreTraining
),
[
"labels"
,
"next_sentence_label"
])
self
.
assertEqual
(
find_labels
(
TFBertForQuestionAnswering
),
[
"start_positions"
,
"end_positions"
])
if
is_flax_available
():
# Flax models don't have labels
from
transformers
import
(
FlaxBertForPreTraining
,
FlaxBertForQuestionAnswering
,
FlaxBertForSequenceClassification
,
)
# find_labels works regardless of the class name (it detects the framework through inheritance)
class
DummyModel
(
TFBertForSequenceClassification
):
pass
self
.
assertEqual
(
find_labels
(
DummyModel
),
[
"labels"
])
@
require_flax
def
test_find_labels_flax
(
self
):
# Flax models don't have labels
self
.
assertEqual
(
find_labels
(
FlaxBertForSequenceClassification
),
[])
self
.
assertEqual
(
find_labels
(
FlaxBertForPreTraining
),
[])
self
.
assertEqual
(
find_labels
(
FlaxBertForQuestionAnswering
),
[])
# find_labels works regardless of the class name (it detects the framework through inheritance)
class
DummyModel
(
FlaxBertForSequenceClassification
):
pass
self
.
assertEqual
(
find_labels
(
DummyModel
),
[])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment