chenpangpang/transformers · commit b0d49fd5 (unverified)

Add a script to check inits are consistent (#11024)

Authored Apr 04, 2021 by Sylvain Gugger, committed by GitHub on Apr 04, 2021
Parent: 335c0ca3
Showing 7 changed files with 237 additions and 5 deletions.
.circleci/config.yml                          +1    -0
Makefile                                      +1    -0
src/transformers/__init__.py                  +8    -0
src/transformers/models/gpt_neo/__init__.py   +1    -5
src/transformers/models/mt5/__init__.py       +6    -0
src/transformers/utils/dummy_pt_objects.py    +29   -0
utils/check_inits.py                          +191  -0
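For context (not part of the diff): each transformers __init__.py declares its public objects twice, once as strings in an _import_structure dictionary that drives lazy loading, and once as real imports under if TYPE_CHECKING: so that type checkers and IDEs see them. The new utils/check_inits.py walks every init and verifies that the two halves list the same objects for each backend. Below is a minimal schematic of the pattern with made-up module and class names; it mirrors the gpt_neo init changed in this commit and omits the lazy-module boilerplate at the end of the real files.

# Schematic layout of a transformers model __init__.py (illustrative names only).
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available

# Half 1: object names as strings, keyed by submodule, optionally guarded per backend.
_import_structure = {
    "configuration_x": ["XConfig"],
}

if is_torch_available():
    _import_structure["modeling_x"] = ["XModel"]

# Half 2: the same objects as real imports, for type checkers and IDEs.
if TYPE_CHECKING:
    from .configuration_x import XConfig

    if is_torch_available():
        from .modeling_x import XModel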
.circleci/config.yml

@@ -405,6 +405,7 @@ jobs:
            - run: python utils/check_table.py
            - run: python utils/check_dummies.py
            - run: python utils/check_repo.py
+           - run: python utils/check_inits.py

    check_repository_consistency:
        working_directory: ~/transformers
Makefile

@@ -31,6 +31,7 @@ extra_quality_checks:
	python utils/check_table.py
	python utils/check_dummies.py
	python utils/check_repo.py
+	python utils/check_inits.py

# this target runs checks on all files
quality:
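Both hooks invoke the same command, so the check can also be run locally from the repository root, either directly or through the Makefile target touched above (make extra_quality_checks). A small sketch of doing so from Python:

import subprocess

# Run the same check that CircleCI and `make extra_quality_checks` now run.
# check_inits.py raises a ValueError when an init is inconsistent, so the interpreter
# exits with a non-zero status and check=True surfaces that as a CalledProcessError.
subprocess.run(["python", "utils/check_inits.py"], check=True)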
src/transformers/__init__.py

@@ -1552,6 +1552,7 @@ if TYPE_CHECKING:
    from .training_args import TrainingArguments
    from .training_args_seq2seq import Seq2SeqTrainingArguments
    from .training_args_tf import TFTrainingArguments
+   from .utils import logging

    if is_sentencepiece_available():
        from .models.albert import AlbertTokenizer

@@ -1662,6 +1663,12 @@ if TYPE_CHECKING:
            TopKLogitsWarper,
            TopPLogitsWarper,
        )
+       from .generation_stopping_criteria import (
+           MaxLengthCriteria,
+           MaxTimeCriteria,
+           StoppingCriteria,
+           StoppingCriteriaList,
+       )
        from .generation_utils import top_k_top_p_filtering
        from .modeling_utils import Conv1D, PreTrainedModel, apply_chunking_to_forward, prune_layer
        from .models.albert import (

@@ -1887,6 +1894,7 @@ if TYPE_CHECKING:
            IBertForSequenceClassification,
            IBertForTokenClassification,
            IBertModel,
+           IBertPreTrainedModel,
        )
        from .models.layoutlm import (
            LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
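These additions are exactly the kind of mismatch the new check reports: the stopping-criteria classes and IBertPreTrainedModel were already exposed on the _import_structure side of this init (not shown in the hunks above), so the TYPE_CHECKING side had to gain the matching imports. The corresponding entries are presumably of roughly this form (reconstructed for illustration, not taken from this diff):

# Elsewhere in src/transformers/__init__.py, under the torch backend of _import_structure:
_import_structure["generation_stopping_criteria"] = [
    "MaxLengthCriteria",
    "MaxTimeCriteria",
    "StoppingCriteria",
    "StoppingCriteriaList",
]
# ...and "IBertPreTrainedModel" is listed in the models.ibert entry.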
src/transformers/models/gpt_neo/__init__.py

@@ -17,17 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

-from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
+from ...file_utils import _BaseLazyModule, is_torch_available


_import_structure = {
    "configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
-    "tokenization_gpt_neo": ["GPTNeoTokenizer"],
}

-if is_tokenizers_available():
-    _import_structure["tokenization_gpt_neo_fast"] = ["GPTNeoTokenizerFast"]
-
if is_torch_available():
    _import_structure["modeling_gpt_neo"] = [
        "GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
src/transformers/models/mt5/__init__.py

@@ -41,6 +41,12 @@ _import_structure = {
    "configuration_mt5": ["MT5Config"],
}

+if is_sentencepiece_available():
+    _import_structure["."] = ["T5Tokenizer"]  # Fake to get the same objects in both side.
+
+if is_tokenizers_available():
+    _import_structure["."] = ["T5TokenizerFast"]  # Fake to get the same objects in both side.
+
if is_torch_available():
    _import_structure["modeling_mt5"] = ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"]
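The "." key here is not a real submodule: mT5 reuses the T5 tokenizers, so these fake entries exist only to make T5Tokenizer and T5TokenizerFast appear on the _import_structure side and keep the two halves of this init in sync. The checker picks such lines up through the _re_import_struct_add_many pattern defined in utils/check_inits.py below; a small sketch (not part of the commit):

import re

# Same pattern as in utils/check_inits.py: matches _import_structure["x"].extend([...])
# and _import_structure["x"] = [...] lines and captures the quoted object names.
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")

line = '    _import_structure["."] = ["T5Tokenizer"]  # Fake to get the same objects in both side.'
match = _re_import_struct_add_many.search(line)
print([obj[1:-1] for obj in match.groups()[0].split(", ")])  # ['T5Tokenizer']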
src/transformers/utils/dummy_pt_objects.py

@@ -198,6 +198,26 @@ class TopPLogitsWarper:
        requires_pytorch(self)


+class MaxLengthCriteria:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
+class MaxTimeCriteria:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
+class StoppingCriteria:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
+class StoppingCriteriaList:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
def top_k_top_p_filtering(*args, **kwargs):
    requires_pytorch(top_k_top_p_filtering)

@@ -1539,6 +1559,15 @@ class IBertModel:
        requires_pytorch(self)


+class IBertPreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
utils/check_inits.py (new file, mode 100644)

# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re


PATH_TO_TRANSFORMERS = "src/transformers"

BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"]


# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if is_foo_available
_re_test_backend = re.compile(r"^\s*if\s+is\_([a-z]*)\_available\(\):\s*$")
# Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")
# Catches a line with an object between quotes and a comma: "MyModel",
_re_quote_object = re.compile('^\s+"([^"]+)",')
# Catches a line with objects between brackets only: ["foo", "bar"],
_re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
# Catches a line with from foo import bar, bla, boo
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")


def parse_init(init_file):
    """
    Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
    defined
    """
    with open(init_file, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    line_index = 0
    while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
        line_index += 1

    # If this is a traditional init, just return.
    if line_index >= len(lines):
        return None

    # First grab the objects without a specific backend in _import_structure
    objects = []
    while not lines[line_index].startswith("if TYPE_CHECKING") and _re_test_backend.search(lines[line_index]) is None:
        line = lines[line_index]
        single_line_import_search = _re_import_struct_key_value.search(line)
        if single_line_import_search is not None:
            imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
            objects.extend(imports)
        elif line.startswith(" " * 8 + '"'):
            objects.append(line[9:-3])
        line_index += 1

    import_dict_objects = {"none": objects}
    # Let's continue with backend-specific objects in _import_structure
    while not lines[line_index].startswith("if TYPE_CHECKING"):
        # If the line is an if is_backend_available, we grab all objects associated.
        if _re_test_backend.search(lines[line_index]) is not None:
            backend = _re_test_backend.search(lines[line_index]).groups()[0]
            line_index += 1

            # Ignore if backend isn't tracked for dummies.
            if backend not in BACKENDS:
                continue

            objects = []
            # Until we unindent, add backend objects to the list
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
                line = lines[line_index]
                if _re_import_struct_add_one.search(line) is not None:
                    objects.append(_re_import_struct_add_one.search(line).groups()[0])
                elif _re_import_struct_add_many.search(line) is not None:
                    imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                elif _re_between_brackets.search(line) is not None:
                    imports = _re_between_brackets.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                elif _re_quote_object.search(line) is not None:
                    objects.append(_re_quote_object.search(line).groups()[0])
                elif line.startswith(" " * 8 + '"'):
                    objects.append(line[9:-3])
                elif line.startswith(" " * 12 + '"'):
                    objects.append(line[13:-3])
                line_index += 1

            import_dict_objects[backend] = objects
        else:
            line_index += 1

    # At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend
    objects = []
    while (
        line_index < len(lines)
        and _re_test_backend.search(lines[line_index]) is None
        and not lines[line_index].startswith("else")
    ):
        line = lines[line_index]
        single_line_import_search = _re_import.search(line)
        if single_line_import_search is not None:
            objects.extend(single_line_import_search.groups()[0].split(", "))
        elif line.startswith(" " * 8):
            objects.append(line[8:-2])
        line_index += 1

    type_hint_objects = {"none": objects}
    # Let's continue with backend-specific objects
    while line_index < len(lines):
        # If the line is an if is_backemd_available, we grab all objects associated.
        if _re_test_backend.search(lines[line_index]) is not None:
            backend = _re_test_backend.search(lines[line_index]).groups()[0]
            line_index += 1

            # Ignore if backend isn't tracked for dummies.
            if backend not in BACKENDS:
                continue

            objects = []
            # Until we unindent, add backend objects to the list
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
                line = lines[line_index]
                single_line_import_search = _re_import.search(line)
                if single_line_import_search is not None:
                    objects.extend(single_line_import_search.groups()[0].split(", "))
                elif line.startswith(" " * 12):
                    objects.append(line[12:-2])
                line_index += 1

            type_hint_objects[backend] = objects
        else:
            line_index += 1

    return import_dict_objects, type_hint_objects


def analyze_results(import_dict_objects, type_hint_objects):
    """
    Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.
    """
    if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
        return ["Both sides of the init do not have the same backends!"]

    errors = []
    for key in import_dict_objects.keys():
        if sorted(import_dict_objects[key]) != sorted(type_hint_objects[key]):
            name = "base imports" if key == "none" else f"{key} backend"
            errors.append(f"Differences for {name}:")
            for a in type_hint_objects[key]:
                if a not in import_dict_objects[key]:
                    errors.append(f"  {a} in TYPE_HINT but not in _import_structure.")
            for a in import_dict_objects[key]:
                if a not in type_hint_objects[key]:
                    errors.append(f"  {a} in _import_structure but not in TYPE_HINT.")

    return errors


def check_all_inits():
    """
    Check all inits in the transformers repo and raise an error if at least one does not define the same objects in
    both halves.
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" in files:
            fname = os.path.join(root, "__init__.py")
            objects = parse_init(fname)
            if objects is not None:
                errors = analyze_results(*objects)
                if len(errors) > 0:
                    errors[0] = f"Problem in {fname}, both halves do not define the same objects.\n{errors[0]}"
                    failures.append("\n".join(errors))
    if len(failures) > 0:
        raise ValueError("\n\n".join(failures))


if __name__ == "__main__":
    check_all_inits()
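As a quick illustration (not part of the commit), the sketch below feeds parse_init and analyze_results a deliberately inconsistent init in which XPreTrainedModel is declared for the torch backend in _import_structure but never imported under TYPE_CHECKING. It assumes it is run from the repository root so that utils/check_inits.py is importable:

import os
import sys
import tempfile

sys.path.append("utils")
from check_inits import analyze_results, parse_init

EXAMPLE_INIT = """_import_structure = {
    "configuration_x": ["XConfig"],
}

if is_torch_available():
    _import_structure["modeling_x"] = ["XModel", "XPreTrainedModel"]

if TYPE_CHECKING:
    from .configuration_x import XConfig

    if is_torch_available():
        from .modeling_x import XModel
else:
    import sys
"""

with tempfile.TemporaryDirectory() as tmp:
    fname = os.path.join(tmp, "__init__.py")
    with open(fname, "w", encoding="utf-8") as f:
        f.write(EXAMPLE_INIT)
    import_dict_objects, type_hint_objects = parse_init(fname)
    # analyze_results flags the missing import, printing something like:
    #   Differences for torch backend:
    #   XPreTrainedModel in _import_structure but not in TYPE_HINT.
    for error in analyze_results(import_dict_objects, type_hint_objects):
        print(error)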