Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
92dfceb1
"configs/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "a5b219127fc4d316229775d956eba6a62fcebecd"
Unverified
Commit
92dfceb1
authored
Feb 27, 2023
by
Joao Gante
Committed by
GitHub
Feb 27, 2023
Browse files
Inheritance-based framework detection (#21784)
parent
7811bf7e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
60 additions
and
35 deletions
+60
-35
src/transformers/pipelines/base.py
src/transformers/pipelines/base.py
+3
-2
src/transformers/utils/generic.py
src/transformers/utils/generic.py
+15
-11
tests/utils/test_file_utils.py
tests/utils/test_file_utils.py
+42
-22
No files found.
src/transformers/pipelines/base.py
View file @
92dfceb1
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
import
collections
import
collections
import
csv
import
csv
import
importlib
import
importlib
import
inspect
import
json
import
json
import
os
import
os
import
pickle
import
pickle
...
@@ -269,7 +270,7 @@ def infer_framework_load_model(
...
@@ -269,7 +270,7 @@ def infer_framework_load_model(
if
isinstance
(
model
,
str
):
if
isinstance
(
model
,
str
):
raise
ValueError
(
f
"Could not load model
{
model
}
with any of the following classes:
{
class_tuple
}
."
)
raise
ValueError
(
f
"Could not load model
{
model
}
with any of the following classes:
{
class_tuple
}
."
)
framework
=
"tf"
if
model
.
__class__
.
__name__
.
startswith
(
"TF"
)
else
"pt"
framework
=
"tf"
if
"keras.engine.training.Model"
in
str
(
inspect
.
getmro
(
model
.
__class__
)
)
else
"pt"
return
framework
,
model
return
framework
,
model
...
@@ -342,7 +343,7 @@ def get_framework(model, revision: Optional[str] = None):
...
@@ -342,7 +343,7 @@ def get_framework(model, revision: Optional[str] = None):
except
OSError
:
except
OSError
:
model
=
TFAutoModel
.
from_pretrained
(
model
,
revision
=
revision
)
model
=
TFAutoModel
.
from_pretrained
(
model
,
revision
=
revision
)
framework
=
"tf"
if
model
.
__class__
.
__name__
.
startswith
(
"TF"
)
else
"pt"
framework
=
"tf"
if
"keras.engine.training.Model"
in
str
(
inspect
.
getmro
(
model
.
__class__
)
)
else
"pt"
return
framework
return
framework
...
...
src/transformers/utils/generic.py
View file @
92dfceb1
...
@@ -366,13 +366,14 @@ def can_return_loss(model_class):
...
@@ -366,13 +366,14 @@ def can_return_loss(model_class):
Args:
Args:
model_class (`type`): The class of the model.
model_class (`type`): The class of the model.
"""
"""
model_name
=
model_class
.
__name__
base_classes
=
str
(
inspect
.
getmro
(
model_class
))
if
model_name
.
startswith
(
"TF"
):
signature
=
inspect
.
signature
(
model_class
.
call
)
if
"keras.engine.training.Model"
in
base_classes
:
elif
model_name
.
startswith
(
"Flax"
):
signature
=
inspect
.
signature
(
model_class
.
call
)
# TensorFlow models
signature
=
inspect
.
signature
(
model_class
.
__call__
)
elif
"torch.nn.modules.module.Module"
in
base_classes
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
# PyTorch models
else
:
else
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
signature
=
inspect
.
signature
(
model_class
.
__call__
)
# Flax models
for
p
in
signature
.
parameters
:
for
p
in
signature
.
parameters
:
if
p
==
"return_loss"
and
signature
.
parameters
[
p
].
default
is
True
:
if
p
==
"return_loss"
and
signature
.
parameters
[
p
].
default
is
True
:
...
@@ -389,12 +390,15 @@ def find_labels(model_class):
...
@@ -389,12 +390,15 @@ def find_labels(model_class):
model_class (`type`): The class of the model.
model_class (`type`): The class of the model.
"""
"""
model_name
=
model_class
.
__name__
model_name
=
model_class
.
__name__
if
model_name
.
startswith
(
"TF"
):
base_classes
=
str
(
inspect
.
getmro
(
model_class
))
signature
=
inspect
.
signature
(
model_class
.
call
)
elif
model_name
.
startswith
(
"Flax"
):
if
"keras.engine.training.Model"
in
base_classes
:
signature
=
inspect
.
signature
(
model_class
.
__call__
)
signature
=
inspect
.
signature
(
model_class
.
call
)
# TensorFlow models
elif
"torch.nn.modules.module.Module"
in
base_classes
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
# PyTorch models
else
:
else
:
signature
=
inspect
.
signature
(
model_class
.
forward
)
signature
=
inspect
.
signature
(
model_class
.
__call__
)
# Flax models
if
"QuestionAnswering"
in
model_name
:
if
"QuestionAnswering"
in
model_name
:
return
[
p
for
p
in
signature
.
parameters
if
"label"
in
p
or
p
in
(
"start_positions"
,
"end_positions"
)]
return
[
p
for
p
in
signature
.
parameters
if
"label"
in
p
or
p
in
(
"start_positions"
,
"end_positions"
)]
else
:
else
:
...
...
tests/utils/test_file_utils.py
View file @
92dfceb1
...
@@ -21,10 +21,20 @@ import transformers
...
@@ -21,10 +21,20 @@ import transformers
# Try to import everything from transformers to ensure every object can be loaded.
# Try to import everything from transformers to ensure every object can be loaded.
from
transformers
import
*
# noqa F406
from
transformers
import
*
# noqa F406
from
transformers.testing_utils
import
DUMMY_UNKNOWN_IDENTIFIER
from
transformers.testing_utils
import
DUMMY_UNKNOWN_IDENTIFIER
,
require_flax
,
require_tf
,
require_torch
from
transformers.utils
import
ContextManagers
,
find_labels
,
is_flax_available
,
is_tf_available
,
is_torch_available
from
transformers.utils
import
ContextManagers
,
find_labels
,
is_flax_available
,
is_tf_available
,
is_torch_available
if
is_torch_available
():
from
transformers
import
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
if
is_tf_available
():
from
transformers
import
TFBertForPreTraining
,
TFBertForQuestionAnswering
,
TFBertForSequenceClassification
if
is_flax_available
():
from
transformers
import
FlaxBertForPreTraining
,
FlaxBertForQuestionAnswering
,
FlaxBertForSequenceClassification
MODEL_ID
=
DUMMY_UNKNOWN_IDENTIFIER
MODEL_ID
=
DUMMY_UNKNOWN_IDENTIFIER
# An actual model hosted on huggingface.co
# An actual model hosted on huggingface.co
...
@@ -85,29 +95,39 @@ class GenericUtilTests(unittest.TestCase):
...
@@ -85,29 +95,39 @@ class GenericUtilTests(unittest.TestCase):
# The output should be wrapped with an English and French welcome and goodbye
# The output should be wrapped with an English and French welcome and goodbye
self
.
assertEqual
(
mock_stdout
.
getvalue
(),
"Bonjour!
\n
Welcome!
\n
Transformers are awesome!
\n
Bye!
\n
Au revoir!
\n
"
)
self
.
assertEqual
(
mock_stdout
.
getvalue
(),
"Bonjour!
\n
Welcome!
\n
Transformers are awesome!
\n
Bye!
\n
Au revoir!
\n
"
)
def
test_find_labels
(
self
):
@
require_torch
if
is_torch_available
():
def
test_find_labels_pt
(
self
):
from
transformers
import
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
self
.
assertEqual
(
find_labels
(
BertForSequenceClassification
),
[
"labels"
])
self
.
assertEqual
(
find_labels
(
BertForPreTraining
),
[
"labels"
,
"next_sentence_label"
])
self
.
assertEqual
(
find_labels
(
BertForQuestionAnswering
),
[
"start_positions"
,
"end_positions"
])
# find_labels works regardless of the class name (it detects the framework through inheritance)
class
DummyModel
(
BertForSequenceClassification
):
pass
self
.
assertEqual
(
find_labels
(
DummyModel
),
[
"labels"
])
@
require_tf
def
test_find_labels_tf
(
self
):
self
.
assertEqual
(
find_labels
(
TFBertForSequenceClassification
),
[
"labels"
])
self
.
assertEqual
(
find_labels
(
TFBertForPreTraining
),
[
"labels"
,
"next_sentence_label"
])
self
.
assertEqual
(
find_labels
(
TFBertForQuestionAnswering
),
[
"start_positions"
,
"end_positions"
])
self
.
assertEqual
(
find_labels
(
BertForSequenceClassification
),
[
"labels"
]
)
# find_labels works regardless of the class name (it detects the framework through inheritance
)
self
.
assertEqual
(
find_labels
(
BertForPreTraining
),
[
"labels"
,
"next_sentence_label"
])
class
DummyModel
(
TFBertForSequenceClassification
):
self
.
assertEqual
(
find_labels
(
BertForQuestionAnswering
),
[
"start_positions"
,
"end_positions"
])
pass
if
is_tf_available
():
self
.
assertEqual
(
find_labels
(
DummyModel
),
[
"labels"
])
from
transformers
import
TFBertForPreTraining
,
TFBertForQuestionAnswering
,
TFBertForSequenceClassification
self
.
assertEqual
(
find_labels
(
TFBertForSequenceClassification
),
[
"labels"
])
@
require_flax
self
.
assertEqual
(
find_labels
(
TFBertForPreTraining
),
[
"labels"
,
"next_sentence_label"
])
def
test_find_labels_flax
(
self
):
self
.
assertEqual
(
find_labels
(
TFBertForQuestionAnswering
),
[
"start_positions"
,
"end_positions"
])
# Flax models don't have labels
self
.
assertEqual
(
find_labels
(
FlaxBertForSequenceClassification
),
[])
self
.
assertEqual
(
find_labels
(
FlaxBertForPreTraining
),
[])
self
.
assertEqual
(
find_labels
(
FlaxBertForQuestionAnswering
),
[])
if
is_flax_available
():
# find_labels works regardless of the class name (it detects the framework through inheritance)
# Flax models don't have labels
class
DummyModel
(
FlaxBertForSequenceClassification
):
from
transformers
import
(
pass
FlaxBertForPreTraining
,
FlaxBertForQuestionAnswering
,
FlaxBertForSequenceClassification
,
)
self
.
assertEqual
(
find_labels
(
FlaxBertForSequenceClassification
),
[])
self
.
assertEqual
(
find_labels
(
DummyModel
),
[])
self
.
assertEqual
(
find_labels
(
FlaxBertForPreTraining
),
[])
self
.
assertEqual
(
find_labels
(
FlaxBertForQuestionAnswering
),
[])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment