Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
57461ac0
Unverified
Commit
57461ac0
authored
Jun 28, 2021
by
Sylvain Gugger
Committed by
GitHub
Jun 28, 2021
Browse files
Add possibility to maintain full copies of files (#12312)
parent
9490d668
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
0 deletions
+30
-0
examples/tensorflow/question-answering/utils_qa.py
examples/tensorflow/question-answering/utils_qa.py
+2
-0
utils/check_copies.py
utils/check_copies.py
+28
-0
No files found.
examples/tensorflow/question-answering/utils_qa.py
View file @
57461ac0
...
@@ -38,6 +38,7 @@ def postprocess_qa_predictions(
...
@@ -38,6 +38,7 @@ def postprocess_qa_predictions(
null_score_diff_threshold
:
float
=
0.0
,
null_score_diff_threshold
:
float
=
0.0
,
output_dir
:
Optional
[
str
]
=
None
,
output_dir
:
Optional
[
str
]
=
None
,
prefix
:
Optional
[
str
]
=
None
,
prefix
:
Optional
[
str
]
=
None
,
is_world_process_zero
:
bool
=
True
,
):
):
"""
"""
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
...
@@ -90,6 +91,7 @@ def postprocess_qa_predictions(
...
@@ -90,6 +91,7 @@ def postprocess_qa_predictions(
scores_diff_json
=
collections
.
OrderedDict
()
scores_diff_json
=
collections
.
OrderedDict
()
# Logging.
# Logging.
logger
.
setLevel
(
logging
.
INFO
if
is_world_process_zero
else
logging
.
WARN
)
logger
.
info
(
f
"Post-processing
{
len
(
examples
)
}
example predictions split into
{
len
(
features
)
}
features."
)
logger
.
info
(
f
"Post-processing
{
len
(
examples
)
}
example predictions split into
{
len
(
features
)
}
features."
)
# Let's loop over all the examples!
# Let's loop over all the examples!
...
...
utils/check_copies.py
View file @
57461ac0
...
@@ -27,6 +27,9 @@ TRANSFORMERS_PATH = "src/transformers"
...
@@ -27,6 +27,9 @@ TRANSFORMERS_PATH = "src/transformers"
PATH_TO_DOCS
=
"docs/source"
PATH_TO_DOCS
=
"docs/source"
REPO_PATH
=
"."
REPO_PATH
=
"."
# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to data with)
FULL_COPIES
=
{
"examples/tensorflow/question-answering/utils_qa.py"
:
"examples/pytorch/question-answering/utils_qa.py"
}
def
_should_continue
(
line
,
indent
):
def
_should_continue
(
line
,
indent
):
return
line
.
startswith
(
indent
)
or
len
(
line
)
<=
1
or
re
.
search
(
r
"^\s*\):\s*$"
,
line
)
is
not
None
return
line
.
startswith
(
indent
)
or
len
(
line
)
<=
1
or
re
.
search
(
r
"^\s*\):\s*$"
,
line
)
is
not
None
...
@@ -192,6 +195,30 @@ def check_copies(overwrite: bool = False):
...
@@ -192,6 +195,30 @@ def check_copies(overwrite: bool = False):
check_model_list_copy
(
overwrite
=
overwrite
)
check_model_list_copy
(
overwrite
=
overwrite
)
def
check_full_copies
(
overwrite
:
bool
=
False
):
diffs
=
[]
for
target
,
source
in
FULL_COPIES
.
items
():
with
open
(
source
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
source_code
=
f
.
read
()
with
open
(
target
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
target_code
=
f
.
read
()
if
source_code
!=
target_code
:
if
overwrite
:
with
open
(
target
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
print
(
f
"Replacing the content of
{
target
}
by the one of
{
source
}
."
)
f
.
write
(
source_code
)
else
:
diffs
.
append
(
f
"-
{
target
}
: copy does not match
{
source
}
."
)
if
not
overwrite
and
len
(
diffs
)
>
0
:
diff
=
"
\n
"
.
join
(
diffs
)
raise
Exception
(
"Found the following copy inconsistencies:
\n
"
+
diff
+
"
\n
Run `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
)
def
get_model_list
():
def
get_model_list
():
"""Extracts the model list from the README."""
"""Extracts the model list from the README."""
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
...
@@ -324,3 +351,4 @@ if __name__ == "__main__":
...
@@ -324,3 +351,4 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
check_copies
(
args
.
fix_and_overwrite
)
check_copies
(
args
.
fix_and_overwrite
)
check_full_copies
(
args
.
fix_and_overwrite
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment