Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a39218b7
Unverified
Commit
a39218b7
authored
Nov 09, 2020
by
Sylvain Gugger
Committed by
GitHub
Nov 09, 2020
Browse files
Check all models are in an auto class (#8425)
parent
ef032ddd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
0 deletions
+73
-0
src/transformers/modeling_tf_auto.py
src/transformers/modeling_tf_auto.py
+4
-0
utils/check_repo.py
utils/check_repo.py
+69
-0
No files found.
src/transformers/modeling_tf_auto.py
View file @
a39218b7
...
...
@@ -31,6 +31,7 @@ from .configuration_auto import (
FunnelConfig
,
GPT2Config
,
LongformerConfig
,
LxmertConfig
,
MobileBertConfig
,
OpenAIGPTConfig
,
RobertaConfig
,
...
...
@@ -113,6 +114,7 @@ from .modeling_tf_funnel import (
)
from
.modeling_tf_gpt2
import
TFGPT2LMHeadModel
,
TFGPT2Model
from
.modeling_tf_longformer
import
TFLongformerForMaskedLM
,
TFLongformerForQuestionAnswering
,
TFLongformerModel
from
.modeling_tf_lxmert
import
TFLxmertForPreTraining
,
TFLxmertModel
from
.modeling_tf_marian
import
TFMarianMTModel
from
.modeling_tf_mbart
import
TFMBartForConditionalGeneration
from
.modeling_tf_mobilebert
import
(
...
...
@@ -168,6 +170,7 @@ logger = logging.get_logger(__name__)
TF_MODEL_MAPPING
=
OrderedDict
(
[
(
LxmertConfig
,
TFLxmertModel
),
(
T5Config
,
TFT5Model
),
(
DistilBertConfig
,
TFDistilBertModel
),
(
AlbertConfig
,
TFAlbertModel
),
...
...
@@ -192,6 +195,7 @@ TF_MODEL_MAPPING = OrderedDict(
TF_MODEL_FOR_PRETRAINING_MAPPING
=
OrderedDict
(
[
(
LxmertConfig
,
TFLxmertForPreTraining
),
(
T5Config
,
TFT5ForConditionalGeneration
),
(
DistilBertConfig
,
TFDistilBertForMaskedLM
),
(
AlbertConfig
,
TFAlbertForPreTraining
),
...
...
utils/check_repo.py
View file @
a39218b7
...
...
@@ -70,6 +70,34 @@ MODEL_NAME_TO_DOC_FILE = {
"marian"
:
"marian.rst"
,
}
# Names of models that are deliberately absent from every auto MODEL_XXX_MAPPING. Add to this list sparingly:
# being listed here is an exception and should **not** become the rule.
IGNORE_NON_AUTO_CONFIGURED = [
    "DPRContextEncoder",
    "DPREncoder",
    "DPRReader",
    "DPRSpanPredictor",
    "FlaubertForQuestionAnswering",
    "FunnelBaseModel",
    "GPT2DoubleHeadsModel",
    "OpenAIGPTDoubleHeadsModel",
    "ProphetNetDecoder",
    "ProphetNetEncoder",
    "RagModel",
    "RagSequenceForGeneration",
    "RagTokenForGeneration",
    "T5Stack",
    "TFBertForNextSentencePrediction",
    "TFFunnelBaseModel",
    "TFGPT2DoubleHeadsModel",
    "TFMobileBertForNextSentencePrediction",
    "TFOpenAIGPTDoubleHeadsModel",
    "XLMForQuestionAnswering",
    "XLMProphetNetDecoder",
    "XLMProphetNetEncoder",
    "XLNetForQuestionAnswering",
]
# This is to make sure the transformers module imported is the one in the repo.
spec
=
importlib
.
util
.
spec_from_file_location
(
"transformers"
,
...
...
@@ -282,6 +310,45 @@ def check_all_models_are_documented():
raise
Exception
(
f
"There were
{
len
(
failures
)
}
failures:
\n
"
+
"
\n
"
.
join
(
failures
))
def get_all_auto_configured_models():
    """Collect the names of all model classes registered in at least one auto mapping.

    Looks at every attribute of `transformers.modeling_auto` whose name starts with `MODEL_` and ends
    with `MAPPING`, and at every attribute of `transformers.modeling_tf_auto` whose name starts with
    `TF_MODEL_` and ends with `MAPPING`, and gathers the values of those mappings.

    Returns:
        A list holding the `__name__` of each distinct class found, in no particular order.
    """
    # (module to scan, prefix an attribute name must have to count as an auto mapping)
    sources = (
        (transformers.modeling_auto, "MODEL_"),
        (transformers.modeling_tf_auto, "TF_MODEL_"),
    )
    # A set is used so that a class registered in several mappings is only kept once.
    auto_classes = set()
    for auto_module, prefix in sources:
        for name in dir(auto_module):
            if not (name.startswith(prefix) and name.endswith("MAPPING")):
                continue
            auto_classes.update(getattr(auto_module, name).values())
    return [model_class.__name__ for model_class in auto_classes]
def check_models_are_auto_configured(module, all_auto_models):
    """Report the models defined in `module` that are missing from every auto mapping.

    Args:
        module: The model module to inspect; its models are obtained through `get_models`.
        all_auto_models: Names of the models that appear in at least one auto mapping.

    Returns:
        A list with one error message per model whose name is neither in `all_auto_models` nor in
        `IGNORE_NON_AUTO_CONFIGURED`. The list is empty when there is nothing to report.
    """
    return [
        f"{model_name} is defined in {module.__name__} but is not present in any of the auto mapping. "
        "If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
        "`utils/check_repo.py`."
        for model_name, _ in get_models(module)
        if not (model_name in all_auto_models or model_name in IGNORE_NON_AUTO_CONFIGURED)
    ]
def check_all_models_are_auto_configured():
    """Verify that the models of every model module are each in an auto class.

    Runs `check_models_are_auto_configured` on each module returned by `get_model_modules` and
    accumulates the reported failures.

    Raises:
        Exception: If at least one failure was reported; the message gives the number of failures
            followed by one failure per line.
    """
    model_modules = get_model_modules()
    known_auto_models = get_all_auto_configured_models()

    problems = []
    for model_module in model_modules:
        reported = check_models_are_auto_configured(model_module, known_auto_models)
        if reported is None:
            continue
        problems.extend(reported)

    if problems:
        raise Exception(f"There were {len(problems)} failures:\n" + "\n".join(problems))
# Matches a line made only of a decorator: optional leading whitespace, `@`, a run of non-whitespace
# characters (captured as group 1), then at least one trailing whitespace character up to the end.
_re_decorator
=
re
.
compile
(
r
"^\s*@(\S+)\s+$"
)
...
...
@@ -325,6 +392,8 @@ def check_repo_quality():
check_all_models_are_tested
()
print
(
"Checking all models are properly documented."
)
check_all_models_are_documented
()
print
(
"Checking all models are in at least one auto class."
)
check_all_models_are_auto_configured
()
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment