ModelZoo / ResNet50_tensorflow · Commits · 31e4a64d
"src/graph/transform/to_block.h" did not exist on "bcd37684268a919f25aa5b9eb88f4e59aca1e7b4"
Commit 31e4a64d, authored Nov 12, 2020 by Allen Wang, committed by A. Unique TensorFlower on Nov 12, 2020

Internal change

PiperOrigin-RevId: 342129726

Parent: 82f46dc7
Showing 6 changed files with 309 additions and 106 deletions (+309 −106).
official/nlp/data/create_finetuning_data.py    +9   −3
official/nlp/data/squad_lib.py                 +145 −73
official/nlp/data/squad_lib_sp.py              +3   −8
official/nlp/modeling/models/xlnet.py          +24  −7
official/nlp/tasks/question_answering.py       +40  −15
official/nlp/tasks/question_answering_test.py  +88  −0
official/nlp/data/create_finetuning_data.py

@@ -262,9 +262,15 @@ def generate_squad_dataset():
   assert FLAGS.squad_data_file
   if FLAGS.tokenization == "WordPiece":
     return squad_lib_wp.generate_tf_record_from_json_file(
-        FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
-        FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
-        FLAGS.doc_stride, FLAGS.version_2_with_negative)
+        input_file_path=FLAGS.squad_data_file,
+        vocab_file_path=FLAGS.vocab_file,
+        output_path=FLAGS.train_data_output_path,
+        max_seq_length=FLAGS.max_seq_length,
+        do_lower_case=FLAGS.do_lower_case,
+        max_query_length=FLAGS.max_query_length,
+        doc_stride=FLAGS.doc_stride,
+        version_2_with_negative=FLAGS.version_2_with_negative,
+        xlnet_format=FLAGS.xlnet_format)
   else:
     assert FLAGS.tokenization == "SentencePiece"
     return squad_lib_sp.generate_tf_record_from_json_file(

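Editor's note: a minimal sketch (not part of the commit) of how the reworked WordPiece path could be called directly once this change is in place. The file paths are placeholders, and the `squad_lib_wp` alias simply mirrors the import used by create_finetuning_data.py; the keyword names come from the hunk above.

# Hedged sketch: keyword-argument call to generate_tf_record_from_json_file.
from official.nlp.data import squad_lib as squad_lib_wp

squad_lib_wp.generate_tf_record_from_json_file(
    input_file_path="/tmp/train-v2.0.json",    # hypothetical SQuAD JSON
    vocab_file_path="/tmp/vocab.txt",          # hypothetical WordPiece vocab
    output_path="/tmp/train.tf_record",        # hypothetical output location
    max_seq_length=384,
    do_lower_case=True,
    max_query_length=64,
    doc_stride=128,
    version_2_with_negative=True,
    xlnet_format=True)   # new flag introduced by this commit
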
official/nlp/data/squad_lib.py

@@ -92,6 +92,8 @@ class InputFeatures(object):
                input_ids,
                input_mask,
                segment_ids,
+               paragraph_mask=None,
+               class_index=None,
                start_position=None,
                end_position=None,
                is_impossible=None):

@@ -107,6 +109,8 @@ class InputFeatures(object):
     self.start_position = start_position
     self.end_position = end_position
     self.is_impossible = is_impossible
+    self.paragraph_mask = paragraph_mask
+    self.class_index = class_index


 class FeatureWriter(object):

@@ -134,6 +138,11 @@ class FeatureWriter(object):
     features["input_mask"] = create_int_feature(feature.input_mask)
     features["segment_ids"] = create_int_feature(feature.segment_ids)
+    if feature.paragraph_mask is not None:
+      features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
+    if feature.class_index is not None:
+      features["class_index"] = create_int_feature([feature.class_index])

     if self.is_training:
       features["start_positions"] = create_int_feature([feature.start_position])
       features["end_positions"] = create_int_feature([feature.end_position])

@@ -238,6 +247,7 @@ def convert_examples_to_features(examples,
                                  max_query_length,
                                  is_training,
                                  output_fn,
+                                 xlnet_format=False,
                                  batch_size=None):
   """Loads a data file into a list of `InputBatch`s."""

@@ -299,25 +309,54 @@ def convert_examples_to_features(examples,
     token_to_orig_map = {}
     token_is_max_context = {}
     segment_ids = []
-    tokens.append("[CLS]")
-    segment_ids.append(0)
-    for token in query_tokens:
-      tokens.append(token)
-      segment_ids.append(0)
-    tokens.append("[SEP]")
-    segment_ids.append(0)
-
-    for i in range(doc_span.length):
-      split_token_index = doc_span.start + i
-      token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
-
-      is_max_context = _check_is_max_context(doc_spans, doc_span_index,
-                                             split_token_index)
-      token_is_max_context[len(tokens)] = is_max_context
-      tokens.append(all_doc_tokens[split_token_index])
-      segment_ids.append(1)
-    tokens.append("[SEP]")
-    segment_ids.append(1)
+
+    # Paragraph mask used in XLNet.
+    # 1 represents paragraph and class tokens.
+    # 0 represents query and other special tokens.
+    paragraph_mask = []
+
+    # pylint: disable=cell-var-from-loop
+    def process_query(seg_q):
+      for token in query_tokens:
+        tokens.append(token)
+        segment_ids.append(seg_q)
+        paragraph_mask.append(0)
+      tokens.append("[SEP]")
+      segment_ids.append(seg_q)
+      paragraph_mask.append(0)
+
+    def process_paragraph(seg_p):
+      for i in range(doc_span.length):
+        split_token_index = doc_span.start + i
+        token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
+
+        is_max_context = _check_is_max_context(doc_spans, doc_span_index,
+                                               split_token_index)
+        token_is_max_context[len(tokens)] = is_max_context
+        tokens.append(all_doc_tokens[split_token_index])
+        segment_ids.append(seg_p)
+        paragraph_mask.append(1)
+      tokens.append("[SEP]")
+      segment_ids.append(seg_p)
+      paragraph_mask.append(0)
+
+    def process_class(seg_class):
+      class_index = len(segment_ids)
+      tokens.append("[CLS]")
+      segment_ids.append(seg_class)
+      paragraph_mask.append(1)
+      return class_index
+
+    if xlnet_format:
+      seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
+      process_paragraph(seg_p)
+      process_query(seg_q)
+      class_index = process_class(seg_class)
+    else:
+      seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
+      class_index = process_class(seg_class)
+      process_query(seg_q)
+      process_paragraph(seg_p)

     input_ids = tokenizer.convert_tokens_to_ids(tokens)

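Editor's note: to make the reordering concrete, here is a small self-contained toy (not from the commit) that mimics the process_* helpers above on made-up tokens. It prints the token / segment-id / paragraph-mask layout for both orderings: paragraph, query, then a trailing [CLS] when xlnet_format is set, versus the BERT-style [CLS], query, paragraph otherwise.

# Toy sketch of the two layouts produced by the hunk above; tokens are made up.
def layout(query, paragraph, xlnet_format):
  tokens, segment_ids, paragraph_mask = [], [], []

  def add(tok, seg, par):
    tokens.append(tok)
    segment_ids.append(seg)
    paragraph_mask.append(par)

  def process_query(seg_q):
    for tok in query:
      add(tok, seg_q, 0)
    add("[SEP]", seg_q, 0)

  def process_paragraph(seg_p):
    for tok in paragraph:
      add(tok, seg_p, 1)
    add("[SEP]", seg_p, 0)

  def process_class(seg_class):
    add("[CLS]", seg_class, 1)

  if xlnet_format:
    seg_p, seg_q, seg_class = 0, 1, 2
    process_paragraph(seg_p)   # paragraph first, [CLS] last
    process_query(seg_q)
    process_class(seg_class)
  else:
    seg_p, seg_q, seg_class = 1, 0, 0
    process_class(seg_class)   # [CLS] first, paragraph last
    process_query(seg_q)
    process_paragraph(seg_p)
  return tokens, segment_ids, paragraph_mask

print(layout(["what", "is", "blue"], ["sky", "is", "blue"], xlnet_format=True))
print(layout(["what", "is", "blue"], ["sky", "is", "blue"], xlnet_format=False))
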
@@ -329,11 +368,13 @@ def convert_examples_to_features(examples,
     while len(input_ids) < max_seq_length:
       input_ids.append(0)
       input_mask.append(0)
-      segment_ids.append(0)
+      segment_ids.append(seg_pad)
+      paragraph_mask.append(0)

     assert len(input_ids) == max_seq_length
     assert len(input_mask) == max_seq_length
     assert len(segment_ids) == max_seq_length
+    assert len(paragraph_mask) == max_seq_length

     start_position = None
     end_position = None

@@ -350,7 +391,7 @@ def convert_examples_to_features(examples,
         start_position = 0
         end_position = 0
       else:
-        doc_offset = len(query_tokens) + 2
+        doc_offset = 0 if xlnet_format else len(query_tokens) + 2
         start_position = tok_start_position - doc_start + doc_offset
         end_position = tok_end_position - doc_start + doc_offset

@@ -377,6 +418,9 @@ def convert_examples_to_features(examples,
       logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
       logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
       logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+      logging.info("paragraph_mask: %s", " ".join(
+          [str(x) for x in paragraph_mask]))
+      logging.info("class_index: %d", class_index)
       if is_training and example.is_impossible:
         logging.info("impossible example")
       if is_training and not example.is_impossible:

@@ -390,6 +434,8 @@ def convert_examples_to_features(examples,
           example_index=example_index,
           doc_span_index=doc_span_index,
           tokens=tokens,
+          paragraph_mask=paragraph_mask,
+          class_index=class_index,
           token_to_orig_map=token_to_orig_map,
           token_is_max_context=token_is_max_context,
           input_ids=input_ids,

@@ -541,6 +587,7 @@ def postprocess_output(all_examples,
                        do_lower_case,
                        version_2_with_negative=False,
                        null_score_diff_threshold=0.0,
+                       xlnet_format=False,
                        verbose=False):
   """Postprocess model output, to form predicton results."""

@@ -570,45 +617,50 @@ def postprocess_output(all_examples,
       null_end_logit = 0  # the end logit at the slice with min null score
     for (feature_index, feature) in enumerate(features):
       result = unique_id_to_result[feature.unique_id]
-      start_indexes = _get_best_indexes(result.start_logits, n_best_size)
-      end_indexes = _get_best_indexes(result.end_logits, n_best_size)
       # if we could have irrelevant answers, get the min score of irrelevant
       if version_2_with_negative:
-        feature_null_score = result.start_logits[0] + result.end_logits[0]
+        if xlnet_format:
+          feature_null_score = result.class_logits
+        else:
+          feature_null_score = result.start_logits[0] + result.end_logits[0]
         if feature_null_score < score_null:
           score_null = feature_null_score
           min_null_feature_index = feature_index
           null_start_logit = result.start_logits[0]
           null_end_logit = result.end_logits[0]
-      for start_index in start_indexes:
-        for end_index in end_indexes:
-          # We could hypothetically create invalid predictions, e.g., predict
-          # that the start of the span is in the question. We throw out all
-          # invalid predictions.
-          if start_index >= len(feature.tokens):
-            continue
-          if end_index >= len(feature.tokens):
-            continue
-          if start_index not in feature.token_to_orig_map:
-            continue
-          if end_index not in feature.token_to_orig_map:
-            continue
-          if not feature.token_is_max_context.get(start_index, False):
-            continue
-          if end_index < start_index:
-            continue
-          length = end_index - start_index + 1
-          if length > max_answer_length:
-            continue
-          prelim_predictions.append(
-              _PrelimPrediction(
-                  feature_index=feature_index,
-                  start_index=start_index,
-                  end_index=end_index,
-                  start_logit=result.start_logits[start_index],
-                  end_logit=result.end_logits[end_index]))
-
-    if version_2_with_negative:
+      for (start_index, start_logit,
+           end_index, end_logit) in _get_best_indexes_and_logits(
+               result=result,
+               n_best_size=n_best_size,
+               xlnet_format=xlnet_format):
+        # We could hypothetically create invalid predictions, e.g., predict
+        # that the start of the span is in the question. We throw out all
+        # invalid predictions.
+        if start_index >= len(feature.tokens):
+          continue
+        if end_index >= len(feature.tokens):
+          continue
+        if start_index not in feature.token_to_orig_map:
+          continue
+        if end_index not in feature.token_to_orig_map:
+          continue
+        if not feature.token_is_max_context.get(start_index, False):
+          continue
+        if end_index < start_index:
+          continue
+        length = end_index - start_index + 1
+        if length > max_answer_length:
+          continue
+        prelim_predictions.append(
+            _PrelimPrediction(
+                feature_index=feature_index,
+                start_index=start_index,
+                end_index=end_index,
+                start_logit=start_logit,
+                end_logit=end_logit))
+
+    if version_2_with_negative and not xlnet_format:
       prelim_predictions.append(
           _PrelimPrediction(
               feature_index=min_null_feature_index,

@@ -630,7 +682,7 @@ def postprocess_output(all_examples,
       if len(nbest) >= n_best_size:
         break
       feature = features[pred.feature_index]
-      if pred.start_index > 0:  # this is a non-null prediction
+      if pred.start_index > 0 or xlnet_format:  # this is a non-null prediction
         tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
         orig_doc_start = feature.token_to_orig_map[pred.start_index]
         orig_doc_end = feature.token_to_orig_map[pred.end_index]

@@ -663,7 +715,7 @@ def postprocess_output(all_examples,
               end_logit=pred.end_logit))

     # if we didn't inlude the empty option in the n-best, inlcude it
-    if version_2_with_negative:
+    if version_2_with_negative and not xlnet_format:
       if "" not in seen_predictions:
         nbest.append(
             _NbestPrediction(

@@ -704,13 +756,18 @@ def postprocess_output(all_examples,
       # pytype: disable=attribute-error
       # predict "" iff the null score - the score of best non-null > threshold
       if best_non_null_entry is not None:
-        score_diff = score_null - best_non_null_entry.start_logit - (
-            best_non_null_entry.end_logit)
-        scores_diff_json[example.qas_id] = score_diff
-        if score_diff > null_score_diff_threshold:
-          all_predictions[example.qas_id] = ""
-        else:
-          all_predictions[example.qas_id] = best_non_null_entry.text
+        if xlnet_format:
+          score_diff = score_null
+          scores_diff_json[example.qas_id] = score_diff
+          all_predictions[example.qas_id] = best_non_null_entry.text
+        else:
+          score_diff = score_null - best_non_null_entry.start_logit - (
+              best_non_null_entry.end_logit)
+          scores_diff_json[example.qas_id] = score_diff
+          if score_diff > null_score_diff_threshold:
+            all_predictions[example.qas_id] = ""
+          else:
+            all_predictions[example.qas_id] = best_non_null_entry.text
       else:
         logging.warning("best_non_null_entry is None")
         scores_diff_json[example.qas_id] = score_null

@@ -822,16 +879,29 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
   return output_text


-def _get_best_indexes(logits, n_best_size):
-  """Get the n-best logits from a list."""
-  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
-
-  best_indexes = []
-  for i in range(len(index_and_score)):
-    if i >= n_best_size:
-      break
-    best_indexes.append(index_and_score[i][0])
-  return best_indexes
+def _get_best_indexes_and_logits(result,
+                                 n_best_size,
+                                 xlnet_format=False):
+  """Generates the n-best indexes and logits from a list."""
+  if xlnet_format:
+    # pylint: disable=consider-using-enumerate
+    for i in range(n_best_size):
+      for j in range(n_best_size):
+        j_index = i * n_best_size + j
+        yield (result.start_indexes[i], result.start_logits[i],
+               result.end_indexes[j_index], result.end_logits[j_index])
+  else:
+    start_index_and_score = sorted(
+        enumerate(result.start_logits), key=lambda x: x[1], reverse=True)
+    end_index_and_score = sorted(
+        enumerate(result.end_logits), key=lambda x: x[1], reverse=True)
+    for i in range(len(start_index_and_score)):
+      if i >= n_best_size:
+        break
+      for j in range(len(end_index_and_score)):
+        if j >= n_best_size:
+          break
+        yield (start_index_and_score[i][0], start_index_and_score[i][1],
+               end_index_and_score[j][0], end_index_and_score[j][1])


 def _compute_softmax(scores):
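Editor's note: a tiny standalone illustration (not from the commit) of what the non-XLNet branch of the new generator yields, using a namedtuple in place of the real result object and made-up logits.

# Toy sketch: the BERT-style branch of _get_best_indexes_and_logits, reduced
# to a standalone function over a stand-in result object.
import collections

ToyResult = collections.namedtuple("ToyResult", ["start_logits", "end_logits"])


def best_pairs(result, n_best_size):
  starts = sorted(enumerate(result.start_logits), key=lambda x: x[1], reverse=True)
  ends = sorted(enumerate(result.end_logits), key=lambda x: x[1], reverse=True)
  for i in range(min(n_best_size, len(starts))):
    for j in range(min(n_best_size, len(ends))):
      # (start_index, start_logit, end_index, end_logit), best-first.
      yield (starts[i][0], starts[i][1], ends[j][0], ends[j][1])


result = ToyResult(start_logits=[0.1, 2.0, 0.5], end_logits=[0.2, 0.3, 1.5])
for pair in best_pairs(result, n_best_size=2):
  print(pair)  # (1, 2.0, 2, 1.5) is emitted first
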
@@ -864,7 +934,8 @@ def generate_tf_record_from_json_file(input_file_path,
                                       do_lower_case=True,
                                       max_query_length=64,
                                       doc_stride=128,
-                                      version_2_with_negative=False):
+                                      version_2_with_negative=False,
+                                      xlnet_format=False):
   """Generates and saves training data into a tf record file."""
   train_examples = read_squad_examples(
       input_file=input_file_path,

@@ -880,7 +951,8 @@ def generate_tf_record_from_json_file(input_file_path,
       doc_stride=doc_stride,
       max_query_length=max_query_length,
       is_training=True,
-      output_fn=train_writer.process_feature)
+      output_fn=train_writer.process_feature,
+      xlnet_format=xlnet_format)
   train_writer.close()

   meta_data = {

official/nlp/data/squad_lib_sp.py

@@ -645,16 +645,11 @@ def postprocess_output(all_examples,
                        do_lower_case,
                        version_2_with_negative=False,
                        null_score_diff_threshold=0.0,
+                       xlnet_format=False,
                        verbose=False):
   """Postprocess model output, to form predicton results."""
   del do_lower_case, verbose
-  # XLNet emits further predictions for start, end indexes and impossibility
-  # classifications.
-  xlnet_format = (hasattr(all_results[0], "start_indexes") and
-                  all_results[0].start_indexes is not None)
-
   example_index_to_features = collections.defaultdict(list)
   for feature in all_features:
     example_index_to_features[feature.example_index].append(feature)

@@ -904,9 +899,9 @@ class FeatureWriter(object):
     features["input_ids"] = create_int_feature(feature.input_ids)
     features["input_mask"] = create_int_feature(feature.input_mask)
     features["segment_ids"] = create_int_feature(feature.segment_ids)
-    if feature.paragraph_mask:
+    if feature.paragraph_mask is not None:
       features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
-    if feature.class_index:
+    if feature.class_index is not None:
       features["class_index"] = create_int_feature([feature.class_index])
     if self.is_training:

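Editor's note: the switch to `is not None` above matters because valid values can be falsy; a quick standalone illustration (not from the commit):

# Toy illustration: why truthiness tests on class_index/paragraph_mask are
# unsafe. A class_index of 0 is a perfectly valid position but is falsy, so
# `if feature.class_index:` would silently skip writing it.
class_index = 0
print(bool(class_index))        # False, even though 0 is a real index
print(class_index is not None)  # True: the value is present and should be kept
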
official/nlp/modeling/models/xlnet.py

@@ -150,6 +150,15 @@ class XLNetSpanLabeler(tf.keras.Model):
         'span_labeling_activation': span_labeling_activation,
         'initializer': initializer,
     }
+    network_config = network.get_config()
+    try:
+      input_width = network_config['inner_size']
+      self._xlnet_base = True
+    except KeyError:
+      # BertEncoder uses 'intermediate_size' due to legacy naming.
+      input_width = network_config['intermediate_size']
+      self._xlnet_base = False
+
     self._network = network
     self._initializer = initializer
     self._start_n_top = start_n_top

@@ -157,7 +166,7 @@ class XLNetSpanLabeler(tf.keras.Model):
     self._dropout_rate = dropout_rate
     self._activation = span_labeling_activation
     self.span_labeling = networks.XLNetSpanLabeling(
-        input_width=network.get_config()['inner_size'],
+        input_width=input_width,
         start_n_top=self._start_n_top,
         end_n_top=self._end_n_top,
         activation=self._activation,

@@ -165,17 +174,25 @@ class XLNetSpanLabeler(tf.keras.Model):
         initializer=self._initializer)

   def call(self, inputs: Mapping[str, Any]):
-    input_ids = inputs['input_word_ids']
-    segment_ids = inputs['input_type_ids']
+    input_word_ids = inputs['input_word_ids']
+    input_type_ids = inputs['input_type_ids']
     input_mask = inputs['input_mask']
     class_index = inputs['class_index']
     paragraph_mask = inputs['paragraph_mask']
     start_positions = inputs.get('start_positions', None)

-    attention_output, _ = self._network(
-        input_ids=input_ids,
-        segment_ids=segment_ids,
-        input_mask=input_mask)
+    if self._xlnet_base:
+      attention_output, _ = self._network(
+          input_ids=input_word_ids,
+          segment_ids=input_type_ids,
+          input_mask=input_mask)
+    else:
+      network_output_dict = self._network(dict(
+          input_word_ids=input_word_ids,
+          input_type_ids=input_type_ids,
+          input_mask=input_mask))
+      attention_output = network_output_dict['sequence_output']

     outputs = self.span_labeling(
         sequence_data=attention_output,
         class_index=class_index,

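Editor's note: the constructor change probes the encoder config to find the projection width; here is a minimal standalone sketch (not from the commit) of that try/except pattern on plain dicts, with illustrative names.

# Hedged sketch: the width-probing pattern from XLNetSpanLabeler.__init__,
# applied to plain dicts standing in for network.get_config().
def probe_input_width(network_config):
  try:
    # XLNet-style encoders expose 'inner_size'.
    return network_config['inner_size'], True
  except KeyError:
    # BertEncoder uses 'intermediate_size' due to legacy naming.
    return network_config['intermediate_size'], False


print(probe_input_width({'inner_size': 3072}))         # (3072, True)
print(probe_input_width({'intermediate_size': 3072}))  # (3072, False)
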
official/nlp/tasks/question_answering.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 """Question answering task."""
+import functools
 import json
 import os
 from typing import List, Optional

@@ -143,6 +144,9 @@ class QuestionAnsweringTask(base_task.Task):
         eval_features.append(feature)
       eval_writer.process_feature(feature)

+    # XLNet preprocesses SQuAD examples in a P, Q, class order whereas
+    # BERT preprocesses in a class, Q, P order.
+    xlnet_ordering = self.task_config.model.encoder.type == 'xlnet'
     kwargs = dict(
         examples=eval_examples,
         max_seq_length=params.seq_length,

@@ -150,14 +154,14 @@ class QuestionAnsweringTask(base_task.Task):
         max_query_length=params.query_length,
         is_training=False,
         output_fn=_append_feature,
-        batch_size=params.global_batch_size)
+        batch_size=params.global_batch_size,
+        xlnet_format=xlnet_ordering)
     if params.tokenization == 'SentencePiece':
       # squad_lib_sp requires one more argument 'do_lower_case'.
       kwargs['do_lower_case'] = params.do_lower_case
       kwargs['tokenizer'] = tokenization.FullSentencePieceTokenizer(
           sp_model_file=params.vocab_file)
-      kwargs['xlnet_format'] = self.task_config.model.encoder.type == 'xlnet'
     elif params.tokenization == 'WordPiece':
       kwargs['tokenizer'] = tokenization.FullTokenizer(
           vocab_file=params.vocab_file, do_lower_case=params.do_lower_case)

@@ -175,24 +179,25 @@ class QuestionAnsweringTask(base_task.Task):
     return eval_writer.filename, eval_examples, eval_features

+  def _dummy_data(self, params, _):
+    """Returns dummy data."""
+    dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
+    x = dict(
+        input_word_ids=dummy_ids,
+        input_mask=dummy_ids,
+        input_type_ids=dummy_ids)
+    y = dict(
+        start_positions=tf.constant(0, dtype=tf.int32),
+        end_positions=tf.constant(1, dtype=tf.int32),
+        is_impossible=tf.constant(0, dtype=tf.int32))
+    return x, y
+
   def build_inputs(self, params, input_context=None):
     """Returns tf.data.Dataset for sentence_prediction task."""
     if params.input_path == 'dummy':
-      # Dummy training data for unit test.
-      def dummy_data(_):
-        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
-        x = dict(
-            input_word_ids=dummy_ids,
-            input_mask=dummy_ids,
-            input_type_ids=dummy_ids)
-        y = dict(
-            start_positions=tf.constant(0, dtype=tf.int32),
-            end_positions=tf.constant(1, dtype=tf.int32),
-            is_impossible=tf.constant(0, dtype=tf.int32))
-        return (x, y)
-
       dataset = tf.data.Dataset.range(1)
       dataset = dataset.repeat()
+      dummy_data = functools.partial(self._dummy_data, params)
       dataset = dataset.map(
           dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
       return dataset
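Editor's note: the refactor moves the dummy-data closure into a method and binds `params` with functools.partial; below is a minimal standalone sketch (not from the commit) of that pattern as a tf.data pipeline. Names and shapes are illustrative, assuming TensorFlow 2.x is available.

# Hedged sketch: functools.partial + Dataset.map, mirroring the pattern above.
import functools
import tensorflow as tf


def make_dummy_example(seq_length, _):
  ids = tf.zeros((1, seq_length), dtype=tf.int32)
  return dict(input_word_ids=ids, input_mask=ids, input_type_ids=ids)


dataset = tf.data.Dataset.range(1).repeat()
dataset = dataset.map(
    functools.partial(make_dummy_example, 128),
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
print(next(iter(dataset))['input_word_ids'].shape)  # (1, 128)
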
@@ -278,6 +283,7 @@ class QuestionAnsweringTask(base_task.Task):
             self.task_config.validation_data.version_2_with_negative),
         null_score_diff_threshold=(
             self.task_config.null_score_diff_threshold),
+        xlnet_format=self.task_config.validation_data.xlnet_format,
         verbose=False))
     with tf.io.gfile.GFile(self.task_config.validation_data.input_path,

@@ -382,6 +388,24 @@ class XLNetQuestionAnsweringTask(QuestionAnsweringTask):
         'end_positions': end_logits,
     })

+  def _dummy_data(self, params, _):
+    """Returns dummy data."""
+    dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
+    zero = tf.constant(0, dtype=tf.int32)
+    x = dict(
+        input_word_ids=dummy_ids,
+        input_mask=dummy_ids,
+        input_type_ids=dummy_ids,
+        class_index=zero,
+        is_impossible=zero,
+        paragraph_mask=dummy_ids,
+        start_positions=tf.zeros((1), dtype=tf.int32))
+    y = dict(
+        start_positions=tf.zeros((1), dtype=tf.int32),
+        end_positions=tf.ones((1), dtype=tf.int32),
+        is_impossible=zero)
+    return x, y
+
   def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
     features, _ = inputs
     unique_ids = features.pop('unique_ids')

@@ -468,5 +492,6 @@ def predict(task: QuestionAnsweringTask, params: cfg.DataConfig,
           task.task_config.validation_data.do_lower_case,
           version_2_with_negative=(params.version_2_with_negative),
           null_score_diff_threshold=task.task_config.null_score_diff_threshold,
+          xlnet_format=task.task_config.validation_data.xlnet_format,
           verbose=False))
   return all_predictions, all_nbest, scores_diff

official/nlp/tasks/question_answering_test.py

@@ -186,5 +186,93 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEmpty(scores_diff)


+class XLNetQuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super(XLNetQuestionAnsweringTaskTest, self).setUp()
+    self._encoder_config = encoders.EncoderConfig(
+        type="xlnet",
+        xlnet=encoders.XLNetEncoderConfig(vocab_size=30522, num_layers=1))
+    self._train_data_config = question_answering_dataloader.QADataConfig(
+        input_path="dummy",
+        seq_length=128,
+        global_batch_size=2,
+        xlnet_format=True)
+
+    val_data = {
+        "version": "2.0",
+        "data": [{
+            "paragraphs": [{
+                "context": "Sky is blue.",
+                "qas": [{
+                    "question": "What is blue?",
+                    "id": "1234",
+                    "answers": [{
+                        "text": "Sky",
+                        "answer_start": 0
+                    }, {
+                        "text": "Sky",
+                        "answer_start": 0
+                    }, {
+                        "text": "Sky",
+                        "answer_start": 0
+                    }]
+                }]
+            }]
+        }]
+    }
+    self._val_input_path = os.path.join(self.get_temp_dir(), "val_data.json")
+    with tf.io.gfile.GFile(self._val_input_path, "w") as writer:
+      writer.write(json.dumps(val_data, indent=4) + "\n")
+
+    self._test_vocab = os.path.join(self.get_temp_dir(), "vocab.txt")
+    with tf.io.gfile.GFile(self._test_vocab, "w") as writer:
+      writer.write("[PAD]\n[UNK]\n[CLS]\n[SEP]\n[MASK]\nsky\nis\nblue\n")
+
+  def _get_validation_data_config(self):
+    return question_answering_dataloader.QADataConfig(
+        is_training=False,
+        input_path=self._val_input_path,
+        input_preprocessed_data_path=self.get_temp_dir(),
+        seq_length=128,
+        global_batch_size=2,
+        version_2_with_negative=True,
+        vocab_file=self._test_vocab,
+        tokenization="WordPiece",
+        do_lower_case=True,
+        xlnet_format=True)
+
+  def _run_task(self, config):
+    task = question_answering.XLNetQuestionAnsweringTask(config)
+    model = task.build_model()
+    metrics = task.build_metrics()
+    task.initialize(model)
+
+    train_dataset = task.build_inputs(config.train_data)
+    train_iterator = iter(train_dataset)
+    optimizer = tf.keras.optimizers.SGD(lr=0.1)
+    task.train_step(next(train_iterator), model, optimizer, metrics=metrics)
+
+    val_dataset = task.build_inputs(config.validation_data)
+    val_iterator = iter(val_dataset)
+    logs = task.validation_step(next(val_iterator), model, metrics=metrics)
+    # Mock that `logs` is from one replica.
+    logs = {x: (logs[x],) for x in logs}
+    logs = task.aggregate_logs(step_outputs=logs)
+    metrics = task.reduce_aggregated_logs(logs)
+    self.assertIn("final_f1", metrics)
+
+  def test_task(self):
+    config = question_answering.XLNetQuestionAnsweringConfig(
+        init_checkpoint="",
+        n_best_size=5,
+        model=question_answering.ModelConfig(encoder=self._encoder_config),
+        train_data=self._train_data_config,
+        validation_data=self._get_validation_data_config())
+    self._run_task(config)
+
+
 if __name__ == "__main__":
   tf.test.main()