Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
26bdef43
Commit
26bdef43
authored
Nov 04, 2018
by
thomwolf
Browse files
fixing verbose_argument
parent
d6418c5e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
8 deletions
+8
-8
run_squad.py
run_squad.py
+8
-8
No files found.
run_squad.py
View file @
26bdef43
...
@@ -406,7 +406,7 @@ RawResult = collections.namedtuple("RawResult",
...
@@ -406,7 +406,7 @@ RawResult = collections.namedtuple("RawResult",
def
write_predictions
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
def
write_predictions
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
max_answer_length
,
do_lower_case
,
output_prediction_file
,
max_answer_length
,
do_lower_case
,
output_prediction_file
,
output_nbest_file
):
output_nbest_file
,
verbose_logging
):
"""Write final predictions to the json file."""
"""Write final predictions to the json file."""
logger
.
info
(
"Writing predictions to: %s"
%
(
output_prediction_file
))
logger
.
info
(
"Writing predictions to: %s"
%
(
output_prediction_file
))
logger
.
info
(
"Writing nbest to: %s"
%
(
output_nbest_file
))
logger
.
info
(
"Writing nbest to: %s"
%
(
output_nbest_file
))
...
@@ -492,7 +492,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
...
@@ -492,7 +492,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
tok_text
=
" "
.
join
(
tok_text
.
split
())
tok_text
=
" "
.
join
(
tok_text
.
split
())
orig_text
=
" "
.
join
(
orig_tokens
)
orig_text
=
" "
.
join
(
orig_tokens
)
final_text
=
get_final_text
(
tok_text
,
orig_text
,
do_lower_case
)
final_text
=
get_final_text
(
tok_text
,
orig_text
,
do_lower_case
,
verbose_logging
)
if
final_text
in
seen_predictions
:
if
final_text
in
seen_predictions
:
continue
continue
...
@@ -538,7 +538,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
...
@@ -538,7 +538,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
writer
.
write
(
json
.
dumps
(
all_nbest_json
,
indent
=
4
)
+
"
\n
"
)
writer
.
write
(
json
.
dumps
(
all_nbest_json
,
indent
=
4
)
+
"
\n
"
)
def
get_final_text
(
pred_text
,
orig_text
,
do_lower_case
):
def
get_final_text
(
pred_text
,
orig_text
,
do_lower_case
,
verbose_logging
=
False
):
"""Project the tokenized prediction back to the original text."""
"""Project the tokenized prediction back to the original text."""
# When we created the data, we kept track of the alignment between original
# When we created the data, we kept track of the alignment between original
...
@@ -587,7 +587,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
...
@@ -587,7 +587,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
start_position
=
tok_text
.
find
(
pred_text
)
start_position
=
tok_text
.
find
(
pred_text
)
if
start_position
==
-
1
:
if
start_position
==
-
1
:
if
args
.
verbose_logging
:
if
verbose_logging
:
logger
.
info
(
logger
.
info
(
"Unable to find text: '%s' in '%s'"
%
(
pred_text
,
orig_text
))
"Unable to find text: '%s' in '%s'"
%
(
pred_text
,
orig_text
))
return
orig_text
return
orig_text
...
@@ -597,7 +597,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
...
@@ -597,7 +597,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
(
tok_ns_text
,
tok_ns_to_s_map
)
=
_strip_spaces
(
tok_text
)
(
tok_ns_text
,
tok_ns_to_s_map
)
=
_strip_spaces
(
tok_text
)
if
len
(
orig_ns_text
)
!=
len
(
tok_ns_text
):
if
len
(
orig_ns_text
)
!=
len
(
tok_ns_text
):
if
args
.
verbose_logging
:
if
verbose_logging
:
logger
.
info
(
"Length not equal after stripping spaces: '%s' vs '%s'"
,
logger
.
info
(
"Length not equal after stripping spaces: '%s' vs '%s'"
,
orig_ns_text
,
tok_ns_text
)
orig_ns_text
,
tok_ns_text
)
return
orig_text
return
orig_text
...
@@ -615,7 +615,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
...
@@ -615,7 +615,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
orig_start_position
=
orig_ns_to_s_map
[
ns_start_position
]
orig_start_position
=
orig_ns_to_s_map
[
ns_start_position
]
if
orig_start_position
is
None
:
if
orig_start_position
is
None
:
if
args
.
verbose_logging
:
if
verbose_logging
:
logger
.
info
(
"Couldn't map start position"
)
logger
.
info
(
"Couldn't map start position"
)
return
orig_text
return
orig_text
...
@@ -626,7 +626,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
...
@@ -626,7 +626,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
orig_end_position
=
orig_ns_to_s_map
[
ns_end_position
]
orig_end_position
=
orig_ns_to_s_map
[
ns_end_position
]
if
orig_end_position
is
None
:
if
orig_end_position
is
None
:
if
args
.
verbose_logging
:
if
verbose_logging
:
logger
.
info
(
"Couldn't map end position"
)
logger
.
info
(
"Couldn't map end position"
)
return
orig_text
return
orig_text
...
@@ -949,7 +949,7 @@ def main():
...
@@ -949,7 +949,7 @@ def main():
write_predictions
(
eval_examples
,
eval_features
,
all_results
,
write_predictions
(
eval_examples
,
eval_features
,
all_results
,
args
.
n_best_size
,
args
.
max_answer_length
,
args
.
n_best_size
,
args
.
max_answer_length
,
args
.
do_lower_case
,
output_prediction_file
,
args
.
do_lower_case
,
output_prediction_file
,
output_nbest_file
)
output_nbest_file
,
args
.
verbose_logging
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment