Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d609ba24
Commit
d609ba24
authored
Feb 05, 2019
by
thomwolf
Browse files
resolving merge conflicts
parent
64ce9009
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
159 additions
and
1145 deletions
+159
-1145
examples/run_squad.py
examples/run_squad.py
+159
-74
examples/run_squad2.py
examples/run_squad2.py
+0
-1071
No files found.
examples/run_squad.py
View file @
d609ba24
...
@@ -46,7 +46,10 @@ logger = logging.getLogger(__name__)
...
@@ -46,7 +46,10 @@ logger = logging.getLogger(__name__)
class
SquadExample
(
object
):
class
SquadExample
(
object
):
"""A single training/test example for the Squad dataset."""
"""
A single training/test example for the Squad dataset.
For examples without an answer, the start and end position are -1.
"""
def
__init__
(
self
,
def
__init__
(
self
,
qas_id
,
qas_id
,
...
@@ -54,13 +57,15 @@ class SquadExample(object):
...
@@ -54,13 +57,15 @@ class SquadExample(object):
doc_tokens
,
doc_tokens
,
orig_answer_text
=
None
,
orig_answer_text
=
None
,
start_position
=
None
,
start_position
=
None
,
end_position
=
None
):
end_position
=
None
,
is_impossible
=
None
):
self
.
qas_id
=
qas_id
self
.
qas_id
=
qas_id
self
.
question_text
=
question_text
self
.
question_text
=
question_text
self
.
doc_tokens
=
doc_tokens
self
.
doc_tokens
=
doc_tokens
self
.
orig_answer_text
=
orig_answer_text
self
.
orig_answer_text
=
orig_answer_text
self
.
start_position
=
start_position
self
.
start_position
=
start_position
self
.
end_position
=
end_position
self
.
end_position
=
end_position
self
.
is_impossible
=
is_impossible
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
__repr__
()
return
self
.
__repr__
()
...
@@ -75,6 +80,8 @@ class SquadExample(object):
...
@@ -75,6 +80,8 @@ class SquadExample(object):
s
+=
", start_position: %d"
%
(
self
.
start_position
)
s
+=
", start_position: %d"
%
(
self
.
start_position
)
if
self
.
start_position
:
if
self
.
start_position
:
s
+=
", end_position: %d"
%
(
self
.
end_position
)
s
+=
", end_position: %d"
%
(
self
.
end_position
)
if
self
.
start_position
:
s
+=
", is_impossible: %r"
%
(
self
.
is_impossible
)
return
s
return
s
...
@@ -92,7 +99,8 @@ class InputFeatures(object):
...
@@ -92,7 +99,8 @@ class InputFeatures(object):
input_mask
,
input_mask
,
segment_ids
,
segment_ids
,
start_position
=
None
,
start_position
=
None
,
end_position
=
None
):
end_position
=
None
,
is_impossible
=
None
):
self
.
unique_id
=
unique_id
self
.
unique_id
=
unique_id
self
.
example_index
=
example_index
self
.
example_index
=
example_index
self
.
doc_span_index
=
doc_span_index
self
.
doc_span_index
=
doc_span_index
...
@@ -104,9 +112,10 @@ class InputFeatures(object):
...
@@ -104,9 +112,10 @@ class InputFeatures(object):
self
.
segment_ids
=
segment_ids
self
.
segment_ids
=
segment_ids
self
.
start_position
=
start_position
self
.
start_position
=
start_position
self
.
end_position
=
end_position
self
.
end_position
=
end_position
self
.
is_impossible
=
is_impossible
def
read_squad_examples
(
input_file
,
is_training
):
def
read_squad_examples
(
input_file
,
is_training
,
version_2_with_negative
):
"""Read a SQuAD json file into a list of SquadExample."""
"""Read a SQuAD json file into a list of SquadExample."""
with
open
(
input_file
,
"r"
,
encoding
=
'utf-8'
)
as
reader
:
with
open
(
input_file
,
"r"
,
encoding
=
'utf-8'
)
as
reader
:
input_data
=
json
.
load
(
reader
)[
"data"
]
input_data
=
json
.
load
(
reader
)[
"data"
]
...
@@ -140,10 +149,14 @@ def read_squad_examples(input_file, is_training):
...
@@ -140,10 +149,14 @@ def read_squad_examples(input_file, is_training):
start_position
=
None
start_position
=
None
end_position
=
None
end_position
=
None
orig_answer_text
=
None
orig_answer_text
=
None
is_impossible
=
False
if
is_training
:
if
is_training
:
if
len
(
qa
[
"answers"
])
!=
1
:
if
version_2_with_negative
:
is_impossible
=
qa
[
"is_impossible"
]
if
(
len
(
qa
[
"answers"
])
!=
1
)
and
(
not
is_impossible
):
raise
ValueError
(
raise
ValueError
(
"For training, each question should have exactly 1 answer."
)
"For training, each question should have exactly 1 answer."
)
if
not
is_impossible
:
answer
=
qa
[
"answers"
][
0
]
answer
=
qa
[
"answers"
][
0
]
orig_answer_text
=
answer
[
"text"
]
orig_answer_text
=
answer
[
"text"
]
answer_offset
=
answer
[
"answer_start"
]
answer_offset
=
answer
[
"answer_start"
]
...
@@ -163,6 +176,10 @@ def read_squad_examples(input_file, is_training):
...
@@ -163,6 +176,10 @@ def read_squad_examples(input_file, is_training):
logger
.
warning
(
"Could not find answer: '%s' vs. '%s'"
,
logger
.
warning
(
"Could not find answer: '%s' vs. '%s'"
,
actual_text
,
cleaned_answer_text
)
actual_text
,
cleaned_answer_text
)
continue
continue
else
:
start_position
=
-
1
end_position
=
-
1
orig_answer_text
=
""
example
=
SquadExample
(
example
=
SquadExample
(
qas_id
=
qas_id
,
qas_id
=
qas_id
,
...
@@ -170,7 +187,8 @@ def read_squad_examples(input_file, is_training):
...
@@ -170,7 +187,8 @@ def read_squad_examples(input_file, is_training):
doc_tokens
=
doc_tokens
,
doc_tokens
=
doc_tokens
,
orig_answer_text
=
orig_answer_text
,
orig_answer_text
=
orig_answer_text
,
start_position
=
start_position
,
start_position
=
start_position
,
end_position
=
end_position
)
end_position
=
end_position
,
is_impossible
=
is_impossible
)
examples
.
append
(
example
)
examples
.
append
(
example
)
return
examples
return
examples
...
@@ -200,7 +218,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
...
@@ -200,7 +218,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
tok_start_position
=
None
tok_start_position
=
None
tok_end_position
=
None
tok_end_position
=
None
if
is_training
:
if
is_training
and
example
.
is_impossible
:
tok_start_position
=
-
1
tok_end_position
=
-
1
if
is_training
and
not
example
.
is_impossible
:
tok_start_position
=
orig_to_tok_index
[
example
.
start_position
]
tok_start_position
=
orig_to_tok_index
[
example
.
start_position
]
if
example
.
end_position
<
len
(
example
.
doc_tokens
)
-
1
:
if
example
.
end_position
<
len
(
example
.
doc_tokens
)
-
1
:
tok_end_position
=
orig_to_tok_index
[
example
.
end_position
+
1
]
-
1
tok_end_position
=
orig_to_tok_index
[
example
.
end_position
+
1
]
-
1
...
@@ -272,20 +293,25 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
...
@@ -272,20 +293,25 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
start_position
=
None
start_position
=
None
end_position
=
None
end_position
=
None
if
is_training
:
if
is_training
and
not
example
.
is_impossible
:
# For training, if our document chunk does not contain an annotation
# For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict.
# we throw it out, since there is nothing to predict.
doc_start
=
doc_span
.
start
doc_start
=
doc_span
.
start
doc_end
=
doc_span
.
start
+
doc_span
.
length
-
1
doc_end
=
doc_span
.
start
+
doc_span
.
length
-
1
if
(
example
.
start_position
<
doc_start
or
out_of_span
=
False
example
.
end_position
<
doc_start
or
if
not
(
tok_start_position
>=
doc_start
and
example
.
start_position
>
doc_end
or
example
.
end_position
>
doc_end
):
tok_end_position
<=
doc_end
):
continue
out_of_span
=
True
if
out_of_span
:
start_position
=
0
end_position
=
0
else
:
doc_offset
=
len
(
query_tokens
)
+
2
doc_offset
=
len
(
query_tokens
)
+
2
start_position
=
tok_start_position
-
doc_start
+
doc_offset
start_position
=
tok_start_position
-
doc_start
+
doc_offset
end_position
=
tok_end_position
-
doc_start
+
doc_offset
end_position
=
tok_end_position
-
doc_start
+
doc_offset
if
is_training
and
example
.
is_impossible
:
start_position
=
0
end_position
=
0
if
example_index
<
20
:
if
example_index
<
20
:
logger
.
info
(
"*** Example ***"
)
logger
.
info
(
"*** Example ***"
)
logger
.
info
(
"unique_id: %s"
%
(
unique_id
))
logger
.
info
(
"unique_id: %s"
%
(
unique_id
))
...
@@ -302,7 +328,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
...
@@ -302,7 +328,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
"input_mask: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_mask
]))
"input_mask: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_mask
]))
logger
.
info
(
logger
.
info
(
"segment_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
segment_ids
]))
"segment_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
segment_ids
]))
if
is_training
:
if
is_training
and
example
.
is_impossible
:
logger
.
info
(
"impossible example"
)
if
is_training
and
not
example
.
is_impossible
:
answer_text
=
" "
.
join
(
tokens
[
start_position
:(
end_position
+
1
)])
answer_text
=
" "
.
join
(
tokens
[
start_position
:(
end_position
+
1
)])
logger
.
info
(
"start_position: %d"
%
(
start_position
))
logger
.
info
(
"start_position: %d"
%
(
start_position
))
logger
.
info
(
"end_position: %d"
%
(
end_position
))
logger
.
info
(
"end_position: %d"
%
(
end_position
))
...
@@ -321,7 +349,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
...
@@ -321,7 +349,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
input_mask
=
input_mask
,
input_mask
=
input_mask
,
segment_ids
=
segment_ids
,
segment_ids
=
segment_ids
,
start_position
=
start_position
,
start_position
=
start_position
,
end_position
=
end_position
))
end_position
=
end_position
,
is_impossible
=
example
.
is_impossible
))
unique_id
+=
1
unique_id
+=
1
return
features
return
features
...
@@ -401,15 +430,15 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
...
@@ -401,15 +430,15 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
return
cur_span_index
==
best_span_index
return
cur_span_index
==
best_span_index
RawResult
=
collections
.
namedtuple
(
"RawResult"
,
RawResult
=
collections
.
namedtuple
(
"RawResult"
,
[
"unique_id"
,
"start_logits"
,
"end_logits"
])
[
"unique_id"
,
"start_logits"
,
"end_logits"
])
def
write_predictions
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
def
write_predictions
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
max_answer_length
,
do_lower_case
,
output_prediction_file
,
max_answer_length
,
do_lower_case
,
output_prediction_file
,
output_nbest_file
,
verbose_logging
):
output_nbest_file
,
output_null_log_odds_file
,
verbose_logging
,
"""Write final predictions to the json file."""
version_2_with_negative
,
null_score_diff_threshold
):
"""Write final predictions to the json file and log-odds of null if needed."""
logger
.
info
(
"Writing predictions to: %s"
%
(
output_prediction_file
))
logger
.
info
(
"Writing predictions to: %s"
%
(
output_prediction_file
))
logger
.
info
(
"Writing nbest to: %s"
%
(
output_nbest_file
))
logger
.
info
(
"Writing nbest to: %s"
%
(
output_nbest_file
))
...
@@ -427,15 +456,29 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
...
@@ -427,15 +456,29 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
all_predictions
=
collections
.
OrderedDict
()
all_predictions
=
collections
.
OrderedDict
()
all_nbest_json
=
collections
.
OrderedDict
()
all_nbest_json
=
collections
.
OrderedDict
()
scores_diff_json
=
collections
.
OrderedDict
()
for
(
example_index
,
example
)
in
enumerate
(
all_examples
):
for
(
example_index
,
example
)
in
enumerate
(
all_examples
):
features
=
example_index_to_features
[
example_index
]
features
=
example_index_to_features
[
example_index
]
prelim_predictions
=
[]
prelim_predictions
=
[]
# keep track of the minimum score of null start+end of position 0
score_null
=
1000000
# large and positive
min_null_feature_index
=
0
# the paragraph slice with min mull score
null_start_logit
=
0
# the start logit at the slice with min null score
null_end_logit
=
0
# the end logit at the slice with min null score
for
(
feature_index
,
feature
)
in
enumerate
(
features
):
for
(
feature_index
,
feature
)
in
enumerate
(
features
):
result
=
unique_id_to_result
[
feature
.
unique_id
]
result
=
unique_id_to_result
[
feature
.
unique_id
]
start_indexes
=
_get_best_indexes
(
result
.
start_logits
,
n_best_size
)
start_indexes
=
_get_best_indexes
(
result
.
start_logits
,
n_best_size
)
end_indexes
=
_get_best_indexes
(
result
.
end_logits
,
n_best_size
)
end_indexes
=
_get_best_indexes
(
result
.
end_logits
,
n_best_size
)
# if we could have irrelevant answers, get the min score of irrelevant
if
version_2_with_negative
:
feature_null_score
=
result
.
start_logits
[
0
]
+
result
.
end_logits
[
0
]
if
feature_null_score
<
score_null
:
score_null
=
feature_null_score
min_null_feature_index
=
feature_index
null_start_logit
=
result
.
start_logits
[
0
]
null_end_logit
=
result
.
end_logits
[
0
]
for
start_index
in
start_indexes
:
for
start_index
in
start_indexes
:
for
end_index
in
end_indexes
:
for
end_index
in
end_indexes
:
# We could hypothetically create invalid predictions, e.g., predict
# We could hypothetically create invalid predictions, e.g., predict
...
@@ -463,7 +506,14 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
...
@@ -463,7 +506,14 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
end_index
=
end_index
,
end_index
=
end_index
,
start_logit
=
result
.
start_logits
[
start_index
],
start_logit
=
result
.
start_logits
[
start_index
],
end_logit
=
result
.
end_logits
[
end_index
]))
end_logit
=
result
.
end_logits
[
end_index
]))
if
version_2_with_negative
:
prelim_predictions
.
append
(
_PrelimPrediction
(
feature_index
=
min_null_feature_index
,
start_index
=
0
,
end_index
=
0
,
start_logit
=
null_start_logit
,
end_logit
=
null_end_logit
))
prelim_predictions
=
sorted
(
prelim_predictions
=
sorted
(
prelim_predictions
,
prelim_predictions
,
key
=
lambda
x
:
(
x
.
start_logit
+
x
.
end_logit
),
key
=
lambda
x
:
(
x
.
start_logit
+
x
.
end_logit
),
...
@@ -478,7 +528,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
...
@@ -478,7 +528,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
if
len
(
nbest
)
>=
n_best_size
:
if
len
(
nbest
)
>=
n_best_size
:
break
break
feature
=
features
[
pred
.
feature_index
]
feature
=
features
[
pred
.
feature_index
]
if
pred
.
start_index
>
0
:
# this is a non-null prediction
tok_tokens
=
feature
.
tokens
[
pred
.
start_index
:(
pred
.
end_index
+
1
)]
tok_tokens
=
feature
.
tokens
[
pred
.
start_index
:(
pred
.
end_index
+
1
)]
orig_doc_start
=
feature
.
token_to_orig_map
[
pred
.
start_index
]
orig_doc_start
=
feature
.
token_to_orig_map
[
pred
.
start_index
]
orig_doc_end
=
feature
.
token_to_orig_map
[
pred
.
end_index
]
orig_doc_end
=
feature
.
token_to_orig_map
[
pred
.
end_index
]
...
@@ -499,12 +549,23 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
...
@@ -499,12 +549,23 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
continue
continue
seen_predictions
[
final_text
]
=
True
seen_predictions
[
final_text
]
=
True
else
:
final_text
=
""
seen_predictions
[
final_text
]
=
True
nbest
.
append
(
nbest
.
append
(
_NbestPrediction
(
_NbestPrediction
(
text
=
final_text
,
text
=
final_text
,
start_logit
=
pred
.
start_logit
,
start_logit
=
pred
.
start_logit
,
end_logit
=
pred
.
end_logit
))
end_logit
=
pred
.
end_logit
))
# if we didn't include the empty option in the n-best, include it
if
version_2_with_negative
:
if
""
not
in
seen_predictions
:
nbest
.
append
(
_NbestPrediction
(
text
=
""
,
start_logit
=
null_start_logit
,
end_logit
=
null_end_logit
))
# In very rare edge cases we could have no valid predictions. So we
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
# just create a nonce prediction in this case to avoid failure.
if
not
nbest
:
if
not
nbest
:
...
@@ -514,8 +575,12 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
...
@@ -514,8 +575,12 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
assert
len
(
nbest
)
>=
1
assert
len
(
nbest
)
>=
1
total_scores
=
[]
total_scores
=
[]
best_non_null_entry
=
None
for
entry
in
nbest
:
for
entry
in
nbest
:
total_scores
.
append
(
entry
.
start_logit
+
entry
.
end_logit
)
total_scores
.
append
(
entry
.
start_logit
+
entry
.
end_logit
)
if
not
best_non_null_entry
:
if
entry
.
text
:
best_non_null_entry
=
entry
probs
=
_compute_softmax
(
total_scores
)
probs
=
_compute_softmax
(
total_scores
)
...
@@ -530,7 +595,17 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
...
@@ -530,7 +595,17 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
assert
len
(
nbest_json
)
>=
1
assert
len
(
nbest_json
)
>=
1
if
not
version_2_with_negative
:
all_predictions
[
example
.
qas_id
]
=
nbest_json
[
0
][
"text"
]
all_predictions
[
example
.
qas_id
]
=
nbest_json
[
0
][
"text"
]
else
:
# predict "" iff the null score - the score of best non-null > threshold
score_diff
=
score_null
-
best_non_null_entry
.
start_logit
-
(
best_non_null_entry
.
end_logit
)
scores_diff_json
[
example
.
qas_id
]
=
score_diff
if
score_diff
>
null_score_diff_threshold
:
all_predictions
[
example
.
qas_id
]
=
""
else
:
all_predictions
[
example
.
qas_id
]
=
best_non_null_entry
.
text
all_nbest_json
[
example
.
qas_id
]
=
nbest_json
all_nbest_json
[
example
.
qas_id
]
=
nbest_json
with
open
(
output_prediction_file
,
"w"
)
as
writer
:
with
open
(
output_prediction_file
,
"w"
)
as
writer
:
...
@@ -539,6 +614,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
...
@@ -539,6 +614,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
with
open
(
output_nbest_file
,
"w"
)
as
writer
:
with
open
(
output_nbest_file
,
"w"
)
as
writer
:
writer
.
write
(
json
.
dumps
(
all_nbest_json
,
indent
=
4
)
+
"
\n
"
)
writer
.
write
(
json
.
dumps
(
all_nbest_json
,
indent
=
4
)
+
"
\n
"
)
if
version_2_with_negative
:
with
open
(
output_null_log_odds_file
,
"w"
)
as
writer
:
writer
.
write
(
json
.
dumps
(
scores_diff_json
,
indent
=
4
)
+
"
\n
"
)
def
get_final_text
(
pred_text
,
orig_text
,
do_lower_case
,
verbose_logging
=
False
):
def
get_final_text
(
pred_text
,
orig_text
,
do_lower_case
,
verbose_logging
=
False
):
"""Project the tokenized prediction back to the original text."""
"""Project the tokenized prediction back to the original text."""
...
@@ -701,7 +780,7 @@ def main():
...
@@ -701,7 +780,7 @@ def main():
parser
.
add_argument
(
"--num_train_epochs"
,
default
=
3.0
,
type
=
float
,
parser
.
add_argument
(
"--num_train_epochs"
,
default
=
3.0
,
type
=
float
,
help
=
"Total number of training epochs to perform."
)
help
=
"Total number of training epochs to perform."
)
parser
.
add_argument
(
"--warmup_proportion"
,
default
=
0.1
,
type
=
float
,
parser
.
add_argument
(
"--warmup_proportion"
,
default
=
0.1
,
type
=
float
,
help
=
"Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
help
=
"Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%
%
"
"of training."
)
"of training."
)
parser
.
add_argument
(
"--n_best_size"
,
default
=
20
,
type
=
int
,
parser
.
add_argument
(
"--n_best_size"
,
default
=
20
,
type
=
int
,
help
=
"The total number of n-best predictions to generate in the nbest_predictions.json "
help
=
"The total number of n-best predictions to generate in the nbest_predictions.json "
...
@@ -738,7 +817,12 @@ def main():
...
@@ -738,7 +817,12 @@ def main():
help
=
"Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.
\n
"
help
=
"Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.
\n
"
"0 (default value): dynamic loss scaling.
\n
"
"0 (default value): dynamic loss scaling.
\n
"
"Positive power of 2: static loss scaling value.
\n
"
)
"Positive power of 2: static loss scaling value.
\n
"
)
parser
.
add_argument
(
'--version_2_with_negative'
,
action
=
'store_true'
,
help
=
'If true, the SQuAD examples contain some that do not have an answer.'
)
parser
.
add_argument
(
'--null_score_diff_threshold'
,
type
=
float
,
default
=
0.0
,
help
=
"If null_score - best_non_null is greater than the threshold predict null."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
local_rank
==
-
1
or
args
.
no_cuda
:
if
args
.
local_rank
==
-
1
or
args
.
no_cuda
:
...
@@ -787,9 +871,9 @@ def main():
...
@@ -787,9 +871,9 @@ def main():
num_train_optimization_steps
=
None
num_train_optimization_steps
=
None
if
args
.
do_train
:
if
args
.
do_train
:
train_examples
=
read_squad_examples
(
train_examples
=
read_squad_examples
(
input_file
=
args
.
train_file
,
is_training
=
True
)
input_file
=
args
.
train_file
,
is_training
=
True
,
version_2_with_negative
=
args
.
version_2_with_negative
)
num_train_optimization_steps
=
int
(
num_train_optimization_steps
=
int
(
len
(
train_
examples
)
/
args
.
train_batch_size
/
args
.
gradient_accumulation_steps
)
*
args
.
num_train_epochs
len
(
train_
dataset
)
/
args
.
train_batch_size
/
args
.
gradient_accumulation_steps
)
*
args
.
num_train_epochs
if
args
.
local_rank
!=
-
1
:
if
args
.
local_rank
!=
-
1
:
num_train_optimization_steps
=
num_train_optimization_steps
//
torch
.
distributed
.
get_world_size
()
num_train_optimization_steps
=
num_train_optimization_steps
//
torch
.
distributed
.
get_world_size
()
...
@@ -825,7 +909,7 @@ def main():
...
@@ -825,7 +909,7 @@ def main():
if
args
.
fp16
:
if
args
.
fp16
:
try
:
try
:
from
apex.optimizer
s
import
FP16_Optimizer
from
apex.optimizer
import
FP16_Optimizer
from
apex.optimizers
import
FusedAdam
from
apex.optimizers
import
FusedAdam
except
ImportError
:
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
)
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
)
...
@@ -901,7 +985,7 @@ def main():
...
@@ -901,7 +985,7 @@ def main():
if
(
step
+
1
)
%
args
.
gradient_accumulation_steps
==
0
:
if
(
step
+
1
)
%
args
.
gradient_accumulation_steps
==
0
:
if
args
.
fp16
:
if
args
.
fp16
:
# modify learning rate with special warm up BERT uses
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used
that
handles this automatically
# if args.fp16 is False, BertAdam is used
and
handles this automatically
lr_this_step
=
args
.
learning_rate
*
warmup_linear
(
global_step
/
num_train_optimization_steps
,
args
.
warmup_proportion
)
lr_this_step
=
args
.
learning_rate
*
warmup_linear
(
global_step
/
num_train_optimization_steps
,
args
.
warmup_proportion
)
for
param_group
in
optimizer
.
param_groups
:
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'lr'
]
=
lr_this_step
param_group
[
'lr'
]
=
lr_this_step
...
@@ -914,7 +998,6 @@ def main():
...
@@ -914,7 +998,6 @@ def main():
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
"pytorch_model.bin"
)
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
"pytorch_model.bin"
)
if
args
.
do_train
:
if
args
.
do_train
:
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
# Load a trained model that you have fine-tuned
# Load a trained model that you have fine-tuned
model_state_dict
=
torch
.
load
(
output_model_file
)
model_state_dict
=
torch
.
load
(
output_model_file
)
model
=
BertForQuestionAnswering
.
from_pretrained
(
args
.
bert_model
,
state_dict
=
model_state_dict
)
model
=
BertForQuestionAnswering
.
from_pretrained
(
args
.
bert_model
,
state_dict
=
model_state_dict
)
...
@@ -925,7 +1008,7 @@ def main():
...
@@ -925,7 +1008,7 @@ def main():
if
args
.
do_predict
and
(
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
):
if
args
.
do_predict
and
(
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
):
eval_examples
=
read_squad_examples
(
eval_examples
=
read_squad_examples
(
input_file
=
args
.
predict_file
,
is_training
=
False
)
input_file
=
args
.
predict_file
,
is_training
=
False
,
version_2_with_negative
=
args
.
version_2_with_negative
)
eval_features
=
convert_examples_to_features
(
eval_features
=
convert_examples_to_features
(
examples
=
eval_examples
,
examples
=
eval_examples
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
...
@@ -969,10 +1052,12 @@ def main():
...
@@ -969,10 +1052,12 @@ def main():
end_logits
=
end_logits
))
end_logits
=
end_logits
))
output_prediction_file
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions.json"
)
output_prediction_file
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions.json"
)
output_nbest_file
=
os
.
path
.
join
(
args
.
output_dir
,
"nbest_predictions.json"
)
output_nbest_file
=
os
.
path
.
join
(
args
.
output_dir
,
"nbest_predictions.json"
)
output_null_log_odds_file
=
os
.
path
.
join
(
args
.
output_dir
,
"null_odds.json"
)
write_predictions
(
eval_examples
,
eval_features
,
all_results
,
write_predictions
(
eval_examples
,
eval_features
,
all_results
,
args
.
n_best_size
,
args
.
max_answer_length
,
args
.
n_best_size
,
args
.
max_answer_length
,
args
.
do_lower_case
,
output_prediction_file
,
args
.
do_lower_case
,
output_prediction_file
,
output_nbest_file
,
args
.
verbose_logging
)
output_nbest_file
,
output_null_log_odds_file
,
args
.
verbose_logging
,
args
.
version_2_with_negative
,
args
.
null_score_diff_threshold
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/run_squad2.py
deleted
100644 → 0
View file @
64ce9009
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment