Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
276bc149
Commit
276bc149
authored
Jun 28, 2021
by
Sylvain Gugger
Browse files
Fix copies
parent
27b6ac46
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
11 deletions
+11
-11
examples/tensorflow/question-answering/utils_qa.py
examples/tensorflow/question-answering/utils_qa.py
+11
-11
No files found.
examples/tensorflow/question-answering/utils_qa.py
View file @
276bc149
...
...
@@ -38,7 +38,7 @@ def postprocess_qa_predictions(
null_score_diff_threshold
:
float
=
0.0
,
output_dir
:
Optional
[
str
]
=
None
,
prefix
:
Optional
[
str
]
=
None
,
is_world_process_zero
:
bool
=
True
,
log_level
:
Optional
[
int
]
=
logging
.
WARNING
,
):
"""
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
...
...
@@ -70,8 +70,8 @@ def postprocess_qa_predictions(
answers, are saved in `output_dir`.
prefix (:obj:`str`, `optional`):
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
is_world_process_zero
(:obj:`
bool
`, `optional`, defaults to
:obj:`True
`):
Whether this process is the main process or not (used to determine if logging/saves should be done).
log_level
(:obj:`
int
`, `optional`, defaults to
``logging.WARNING`
`):
``logging`` log level (e.g., ``logging.WARNING``)
"""
assert
len
(
predictions
)
==
2
,
"`predictions` should be a tuple with two elements (start_logits, end_logits)."
all_start_logits
,
all_end_logits
=
predictions
...
...
@@ -91,7 +91,7 @@ def postprocess_qa_predictions(
scores_diff_json
=
collections
.
OrderedDict
()
# Logging.
logger
.
setLevel
(
log
ging
.
INFO
if
is_world_process_zero
else
logging
.
WARN
)
logger
.
setLevel
(
log
_level
)
logger
.
info
(
f
"Post-processing
{
len
(
examples
)
}
example predictions split into
{
len
(
features
)
}
features."
)
# Let's loop over all the examples!
...
...
@@ -250,7 +250,7 @@ def postprocess_qa_predictions_with_beam_search(
end_n_top
:
int
=
5
,
output_dir
:
Optional
[
str
]
=
None
,
prefix
:
Optional
[
str
]
=
None
,
is_world_process_zero
:
bool
=
True
,
log_level
:
Optional
[
int
]
=
logging
.
WARNING
,
):
"""
Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
...
...
@@ -280,8 +280,8 @@ def postprocess_qa_predictions_with_beam_search(
answers, are saved in `output_dir`.
prefix (:obj:`str`, `optional`):
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
is_world_process_zero
(:obj:`
bool
`, `optional`, defaults to
:obj:`True
`):
Whether this process is the main process or not (used to determine if logging/saves should be done).
log_level
(:obj:`
int
`, `optional`, defaults to
``logging.WARNING`
`):
``logging`` log level (e.g., ``logging.WARNING``)
"""
assert
len
(
predictions
)
==
5
,
"`predictions` should be a tuple with five elements."
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
=
predictions
...
...
@@ -302,7 +302,7 @@ def postprocess_qa_predictions_with_beam_search(
scores_diff_json
=
collections
.
OrderedDict
()
if
version_2_with_negative
else
None
# Logging.
logger
.
setLevel
(
log
ging
.
INFO
if
is_world_process_zero
else
logging
.
WARN
)
logger
.
setLevel
(
log
_level
)
logger
.
info
(
f
"Post-processing
{
len
(
examples
)
}
example predictions split into
{
len
(
features
)
}
features."
)
# Let's loop over all the examples!
...
...
@@ -413,14 +413,14 @@ def postprocess_qa_predictions_with_beam_search(
output_dir
,
"null_odds.json"
if
prefix
is
None
else
f
"
{
prefix
}
_null_odds.json"
)
print
(
f
"Saving predictions to
{
prediction_file
}
."
)
logger
.
info
(
f
"Saving predictions to
{
prediction_file
}
."
)
with
open
(
prediction_file
,
"w"
)
as
writer
:
writer
.
write
(
json
.
dumps
(
all_predictions
,
indent
=
4
)
+
"
\n
"
)
print
(
f
"Saving nbest_preds to
{
nbest_file
}
."
)
logger
.
info
(
f
"Saving nbest_preds to
{
nbest_file
}
."
)
with
open
(
nbest_file
,
"w"
)
as
writer
:
writer
.
write
(
json
.
dumps
(
all_nbest_json
,
indent
=
4
)
+
"
\n
"
)
if
version_2_with_negative
:
print
(
f
"Saving null_odds to
{
null_odds_file
}
."
)
logger
.
info
(
f
"Saving null_odds to
{
null_odds_file
}
."
)
with
open
(
null_odds_file
,
"w"
)
as
writer
:
writer
.
write
(
json
.
dumps
(
scores_diff_json
,
indent
=
4
)
+
"
\n
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment