Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
92dfceb1
Unverified
Commit
92dfceb1
authored
Feb 27, 2023
by
Joao Gante
Committed by
GitHub
Feb 27, 2023
Browse files
Inheritance-based framework detection (#21784)
parent
7811bf7e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
60 additions
and
35 deletions
+60
-35
src/transformers/pipelines/base.py
src/transformers/pipelines/base.py
+3
-2
src/transformers/utils/generic.py
src/transformers/utils/generic.py
+15
-11
tests/utils/test_file_utils.py
tests/utils/test_file_utils.py
+42
-22
No files found.
src/transformers/pipelines/base.py
View file @
92dfceb1
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
import
collections
import
collections
import
csv
import
csv
import
importlib
import
importlib
import
inspect
import
json
import
json
import
os
import
os
import
pickle
import
pickle
...
@@ -269,7 +270,7 @@ def infer_framework_load_model(
...
@@ -269,7 +270,7 @@ def infer_framework_load_model(
if
isinstance
(
model
,
str
):
if
isinstance
(
model
,
str
):
raise
ValueError
(
f
"Could not load model
{
model
}
with any of the following classes:
{
class_tuple
}
."
)
raise
ValueError
(
f
"Could not load model
{
model
}
with any of the following classes:
{
class_tuple
}
."
)
framework
=
"tf"
if
model
.
__class__
.
__name__
.
startswith
(
"TF"
)
else
"pt"
framework
=
"tf"
if
"keras.engine.training.Model"
in
str
(
inspect
.
getmro
(
model
.
__class__
)
)
else
"pt"
return
framework
,
model
return
framework
,
model
...
@@ -342,7 +343,7 @@ def get_framework(model, revision: Optional[str] = None):
...
@@ -342,7 +343,7 @@ def get_framework(model, revision: Optional[str] = None):
except
OSError
:
except
OSError
:
model
=
TFAutoModel
.
from_pretrained
(
model
,
revision
=
revision
)
model
=
TFAutoModel
.
from_pretrained
(
model
,
revision
=
revision
)
framework
=
"tf"
if
model
.
__class__
.
__name__
.
startswith
(
"TF"
)
else
"pt"
framework
=
"tf"
if
"keras.engine.training.Model"
in
str
(
inspect
.
getmro
(
model
.
__class__
)
)
else
"pt"
return
framework
return
framework
...
...
src/transformers/utils/generic.py
View file @
92dfceb1
...
@@ -366,13 +366,14 @@ def can_return_loss(model_class):
...
@@ -366,13 +366,14 @@ def can_return_loss(model_class):
Args:
Args:
model_class (`type`): The class of the model.
model_class (`type`): The class of the model.
"""
"""
model_name
=
model_class
.
__name__
base_classes
=
str
(
inspect
.
getmro
(
model_class
))
if
model_name
.
startswith
(
"TF"
):
signature
=
inspect
.
signature
(
model_class
.
call
)
if
"keras.engine.training.Model"
in
base_classes
:
elif
model_name
.
startswith
(
"Flax"
):
signature
=
inspect
.
signature
(
model_class
.
call
)
# TensorFlow models
signature
=
inspect
.
signature
(
model_class
.
__call__
)
elif
"torch.nn.modules.module.Module"
in
base_classes
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
# PyTorch models
else
:
else
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
signature
=
inspect
.
signature
(
model_class
.
__call__
)
# Flax models
for
p
in
signature
.
parameters
:
for
p
in
signature
.
parameters
:
if
p
==
"return_loss"
and
signature
.
parameters
[
p
].
default
is
True
:
if
p
==
"return_loss"
and
signature
.
parameters
[
p
].
default
is
True
:
...
@@ -389,12 +390,15 @@ def find_labels(model_class):
...
@@ -389,12 +390,15 @@ def find_labels(model_class):
model_class (`type`): The class of the model.
model_class (`type`): The class of the model.
"""
"""
model_name
=
model_class
.
__name__
model_name
=
model_class
.
__name__
if
model_name
.
startswith
(
"TF"
):
base_classes
=
str
(
inspect
.
getmro
(
model_class
))
signature
=
inspect
.
signature
(
model_class
.
call
)
elif
model_name
.
startswith
(
"Flax"
):
if
"keras.engine.training.Model"
in
base_classes
:
signature
=
inspect
.
signature
(
model_class
.
__call__
)
signature
=
inspect
.
signature
(
model_class
.
call
)
# TensorFlow models
elif
"torch.nn.modules.module.Module"
in
base_classes
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
# PyTorch models
else
:
else
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
signature
=
inspect
.
signature
(
model_class
.
__call__
)
# Flax models
if
"QuestionAnswering"
in
model_name
:
if
"QuestionAnswering"
in
model_name
:
return
[
p
for
p
in
signature
.
parameters
if
"label"
in
p
or
p
in
(
"start_positions"
,
"end_positions"
)]
return
[
p
for
p
in
signature
.
parameters
if
"label"
in
p
or
p
in
(
"start_positions"
,
"end_positions"
)]
else
:
else
:
...
...
tests/utils/test_file_utils.py
View file @
92dfceb1
...
@@ -21,10 +21,20 @@ import transformers
...
@@ -21,10 +21,20 @@ import transformers
# Try to import everything from transformers to ensure every object can be loaded.
# Try to import everything from transformers to ensure every object can be loaded.
from
transformers
import
*
# noqa F406
from
transformers
import
*
# noqa F406
from
transformers.testing_utils
import
DUMMY_UNKNOWN_IDENTIFIER
from
transformers.testing_utils
import
DUMMY_UNKNOWN_IDENTIFIER
,
require_flax
,
require_tf
,
require_torch
from
transformers.utils
import
ContextManagers
,
find_labels
,
is_flax_available
,
is_tf_available
,
is_torch_available
from
transformers.utils
import
ContextManagers
,
find_labels
,
is_flax_available
,
is_tf_available
,
is_torch_available
if
is_torch_available
():
from
transformers
import
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
if
is_tf_available
():
from
transformers
import
TFBertForPreTraining
,
TFBertForQuestionAnswering
,
TFBertForSequenceClassification
if
is_flax_available
():
from
transformers
import
FlaxBertForPreTraining
,
FlaxBertForQuestionAnswering
,
FlaxBertForSequenceClassification
MODEL_ID
=
DUMMY_UNKNOWN_IDENTIFIER
MODEL_ID
=
DUMMY_UNKNOWN_IDENTIFIER
# An actual model hosted on huggingface.co
# An actual model hosted on huggingface.co
...
@@ -85,29 +95,39 @@ class GenericUtilTests(unittest.TestCase):
...
@@ -85,29 +95,39 @@ class GenericUtilTests(unittest.TestCase):
# The output should be wrapped with an English and French welcome and goodbye
# The output should be wrapped with an English and French welcome and goodbye
self
.
assertEqual
(
mock_stdout
.
getvalue
(),
"Bonjour!
\n
Welcome!
\n
Transformers are awesome!
\n
Bye!
\n
Au revoir!
\n
"
)
self
.
assertEqual
(
mock_stdout
.
getvalue
(),
"Bonjour!
\n
Welcome!
\n
Transformers are awesome!
\n
Bye!
\n
Au revoir!
\n
"
)
def
test_find_labels
(
self
):
@
require_torch
if
is_torch_available
():
def
test_find_labels_pt
(
self
):
from
transformers
import
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
self
.
assertEqual
(
find_labels
(
BertForSequenceClassification
),
[
"labels"
])
self
.
assertEqual
(
find_labels
(
BertForPreTraining
),
[
"labels"
,
"next_sentence_label"
])
self
.
assertEqual
(
find_labels
(
BertForQuestionAnswering
),
[
"start_positions"
,
"end_positions"
])
# find_labels works regardless of the class name (it detects the framework through inheritance)
class
DummyModel
(
BertForSequenceClassification
):
pass
self
.
assertEqual
(
find_labels
(
DummyModel
),
[
"labels"
])
@
require_tf
def
test_find_labels_tf
(
self
):
self
.
assertEqual
(
find_labels
(
TFBertForSequenceClassification
),
[
"labels"
])
self
.
assertEqual
(
find_labels
(
TFBertForPreTraining
),
[
"labels"
,
"next_sentence_label"
])
self
.
assertEqual
(
find_labels
(
TFBertForQuestionAnswering
),
[
"start_positions"
,
"end_positions"
])
self
.
assertEqual
(
find_labels
(
BertForSequenceClassification
),
[
"labels"
]
)
# find_labels works regardless of the class name (it detects the framework through inheritance
)
self
.
assertEqual
(
find_labels
(
BertForPreTraining
),
[
"labels"
,
"next_sentence_label"
])
class
DummyModel
(
TFBertForSequenceClassification
):
self
.
assertEqual
(
find_labels
(
BertForQuestionAnswering
),
[
"start_positions"
,
"end_positions"
])
pass
if
is_tf_available
():
self
.
assertEqual
(
find_labels
(
DummyModel
),
[
"labels"
])
from
transformers
import
TFBertForPreTraining
,
TFBertForQuestionAnswering
,
TFBertForSequenceClassification
self
.
assertEqual
(
find_labels
(
TFBertForSequenceClassification
),
[
"labels"
])
@
require_flax
self
.
assertEqual
(
find_labels
(
TFBertForPreTraining
),
[
"labels"
,
"next_sentence_label"
])
def
test_find_labels_flax
(
self
):
self
.
assertEqual
(
find_labels
(
TFBertForQuestionAnswering
),
[
"start_positions"
,
"end_positions"
])
# Flax models don't have labels
self
.
assertEqual
(
find_labels
(
FlaxBertForSequenceClassification
),
[])
self
.
assertEqual
(
find_labels
(
FlaxBertForPreTraining
),
[])
self
.
assertEqual
(
find_labels
(
FlaxBertForQuestionAnswering
),
[])
if
is_flax_available
():
# find_labels works regardless of the class name (it detects the framework through inheritance)
# Flax models don't have labels
class
DummyModel
(
FlaxBertForSequenceClassification
):
from
transformers
import
(
pass
FlaxBertForPreTraining
,
FlaxBertForQuestionAnswering
,
FlaxBertForSequenceClassification
,
)
self
.
assertEqual
(
find_labels
(
FlaxBertForSequenceClassification
),
[])
self
.
assertEqual
(
find_labels
(
DummyModel
),
[])
self
.
assertEqual
(
find_labels
(
FlaxBertForPreTraining
),
[])
self
.
assertEqual
(
find_labels
(
FlaxBertForQuestionAnswering
),
[])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment