ModelZoo / ResNet50_tensorflow · Commits

Commit f728ffc6
Authored Nov 05, 2020 by Allen Wang
Committed by A. Unique TensorFlower, Nov 05, 2020

Add in XLNet style SQuAD preprocessing to TF-NLP.

PiperOrigin-RevId: 340897846

Parent: b5c6170e
Showing 3 changed files with 157 additions and 59 deletions (+157 -59)
official/nlp/data/create_finetuning_data.py         +14   -3
official/nlp/data/question_answering_dataloader.py  +12   -1
official/nlp/data/squad_lib_sp.py                   +131  -55
official/nlp/data/create_finetuning_data.py  (view file @ f728ffc6)

@@ -100,6 +100,11 @@ flags.DEFINE_bool(
     "version_2_with_negative", False,
     "If true, the SQuAD examples contain some that do not have an answer.")
+flags.DEFINE_bool(
+    "xlnet_format", False,
+    "If true, then data will be preprocessed in a paragraph, query, class order"
+    " instead of the BERT-style class, paragraph, query order.")
+
 # Shared flags across BERT fine-tuning tasks.
 flags.DEFINE_string(
     "vocab_file", None,
     "The vocabulary file that the BERT model was trained on.")
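For orientation, a minimal sketch (not part of the commit) of the two packing orders the new flag selects; toy strings stand in for SentencePiece ids, and the layouts mirror what the packing code in squad_lib_sp.py below produces:

    # Illustrative only: the real preprocessing works on SentencePiece ids.
    query = ["what", "is", "squad"]
    paragraph = ["squad", "is", "a", "dataset"]

    # Default (BERT-style) packing: [CLS], query, [SEP], paragraph, [SEP].
    bert_order = ["[CLS]"] + query + ["[SEP]"] + paragraph + ["[SEP]"]

    # --xlnet_format packing: paragraph, [SEP], query, [SEP], with [CLS] last.
    xlnet_order = paragraph + ["[SEP]"] + query + ["[SEP]", "[CLS]"]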
@@ -263,9 +268,15 @@ def generate_squad_dataset():
   else:
     assert FLAGS.tokenization == "SentencePiece"
     return squad_lib_sp.generate_tf_record_from_json_file(
-        FLAGS.squad_data_file, FLAGS.sp_model_file,
-        FLAGS.train_data_output_path, FLAGS.max_seq_length,
-        FLAGS.do_lower_case, FLAGS.max_query_length, FLAGS.doc_stride,
-        FLAGS.version_2_with_negative)
+        input_file_path=FLAGS.squad_data_file,
+        sp_model_file=FLAGS.sp_model_file,
+        output_path=FLAGS.train_data_output_path,
+        max_seq_length=FLAGS.max_seq_length,
+        do_lower_case=FLAGS.do_lower_case,
+        max_query_length=FLAGS.max_query_length,
+        doc_stride=FLAGS.doc_stride,
+        xlnet_format=FLAGS.xlnet_format,
+        version_2_with_negative=FLAGS.version_2_with_negative)
 
 
 def generate_retrieval_dataset():
   ...
official/nlp/data/question_answering_dataloader.py  (view file @ f728ffc6)

@@ -42,6 +42,7 @@ class QADataConfig(cfg.DataConfig):
   vocab_file: str = ''
   tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
   do_lower_case: bool = True
+  xlnet_format: bool = False
 
 
 @data_loader_factory.register_data_loader_cls(QADataConfig)
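A hedged sketch of opting in from a config; `xlnet_format`, `tokenization`, and `do_lower_case` are the fields shown above, while `input_path`, `is_training`, and `seq_length` are assumed to come from the base `cfg.DataConfig` and the rest of the `QADataConfig` definition, which this diff does not show:

    # Hypothetical usage; fields not visible in the diff are assumptions.
    from official.nlp.data.question_answering_dataloader import (
        QADataConfig, QuestionAnsweringDataLoader)

    config = QADataConfig(
        input_path='/tmp/squad_train.tf_record',  # assumed inherited field
        is_training=True,                         # assumed inherited field
        seq_length=384,                           # assumed existing field
        tokenization='SentencePiece',
        do_lower_case=False,
        xlnet_format=True)
    dataset = QuestionAnsweringDataLoader(config).load()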
@@ -52,6 +53,7 @@ class QuestionAnsweringDataLoader(data_loader.DataLoader):
     self._params = params
     self._seq_length = params.seq_length
     self._is_training = params.is_training
+    self._xlnet_format = params.xlnet_format
 
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
@@ -60,6 +62,13 @@ class QuestionAnsweringDataLoader(data_loader.DataLoader):
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
     }
+    if self._xlnet_format:
+      name_to_features['class_index'] = tf.io.FixedLenFeature([], tf.int64)
+      name_to_features['paragraph_mask'] = tf.io.FixedLenFeature(
+          [self._seq_length], tf.int64)
+      if self._is_training:
+        name_to_features['is_impossible'] = tf.io.FixedLenFeature(
+            [], tf.int64)
     if self._is_training:
       name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
       name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
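Spelled out, the parse spec `_decode` assembles for an XLNet-format training record looks roughly like the following (a sketch with an assumed `seq_length` of 384; the actual method builds it incrementally as above):

    import tensorflow as tf

    seq_length = 384  # assumed example value
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
        'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
        'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
        # Present only when xlnet_format is set:
        'class_index': tf.io.FixedLenFeature([], tf.int64),
        'paragraph_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
        # Present only when training; is_impossible also needs xlnet_format:
        'is_impossible': tf.io.FixedLenFeature([], tf.int64),
        'start_positions': tf.io.FixedLenFeature([], tf.int64),
        'end_positions': tf.io.FixedLenFeature([], tf.int64),
    }
    # A serialized tf.train.Example is then parsed along the lines of:
    # parsed = tf.io.parse_single_example(record, name_to_features)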
@@ -81,7 +90,7 @@ class QuestionAnsweringDataLoader(data_loader.DataLoader):
     """Parses raw tensors into a dict of tensors to be consumed by the model."""
     x, y = {}, {}
     for name, tensor in record.items():
-      if name in ('start_positions', 'end_positions'):
+      if name in ('start_positions', 'end_positions', 'is_impossible'):
         y[name] = tensor
       elif name == 'input_ids':
         x['input_word_ids'] = tensor
@@ -89,6 +98,8 @@ class QuestionAnsweringDataLoader(data_loader.DataLoader):
         x['input_type_ids'] = tensor
       else:
         x[name] = tensor
+      if name == 'start_positions' and self._xlnet_format:
+        x[name] = tensor
     return (x, y)
 
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     ...
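The net effect of the two `_parse` changes, sketched for an XLNet-format training record: `is_impossible` now joins the labels, and `start_positions` is additionally copied into the model inputs, presumably because the XLNet-style span head conditions its end prediction on the start position (an inference from the code, not stated in the commit):

    # Sketch of the resulting (x, y) split; Ellipsis stands for the tensors.
    x = {
        'input_word_ids': ...,   # renamed from 'input_ids'
        'input_mask': ...,
        'input_type_ids': ...,   # renamed from 'segment_ids'
        'paragraph_mask': ...,
        'class_index': ...,
        'start_positions': ...,  # duplicated into x only under xlnet_format
    }
    y = {
        'start_positions': ...,
        'end_positions': ...,
        'is_impossible': ...,
    }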
official/nlp/data/squad_lib_sp.py  (view file @ f728ffc6)

@@ -86,6 +86,8 @@ class InputFeatures(object):
                input_mask,
                segment_ids,
                paragraph_len,
+               class_index=None,
+               paragraph_mask=None,
                start_position=None,
                end_position=None,
                is_impossible=None):

@@ -98,8 +100,10 @@ class InputFeatures(object):
     self.tokens = tokens
     self.input_ids = input_ids
     self.input_mask = input_mask
+    self.paragraph_mask = paragraph_mask
     self.segment_ids = segment_ids
     self.paragraph_len = paragraph_len
+    self.class_index = class_index
     self.start_position = start_position
     self.end_position = end_position
     self.is_impossible = is_impossible
@@ -194,6 +198,7 @@ def convert_examples_to_features(examples,
                                  is_training,
                                  output_fn,
                                  do_lower_case,
+                                 xlnet_format=False,
                                  batch_size=None):
   """Loads a data file into a list of `InputBatch`s."""
   cnt_pos, cnt_neg = 0, 0

@@ -353,6 +358,7 @@ def convert_examples_to_features(examples,
       "DocSpan", ["start", "length"])
   doc_spans = []
   start_offset = 0
   while start_offset < len(all_doc_tokens):
     length = len(all_doc_tokens) - start_offset
     if length > max_tokens_for_doc:
       ...
@@ -367,17 +373,25 @@ def convert_examples_to_features(examples,
     token_is_max_context = {}
     segment_ids = []
+    # Paragraph mask used in XLNet.
+    # 1 represents paragraph and class tokens.
+    # 0 represents query and other special tokens.
+    paragraph_mask = []
     cur_tok_start_to_orig_index = []
     cur_tok_end_to_orig_index = []
-    tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
-    segment_ids.append(0)
-    for token in query_tokens:
-      tokens.append(token)
-      segment_ids.append(0)
-    tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
-    segment_ids.append(0)
-    for i in range(doc_span.length):
-      split_token_index = doc_span.start + i
+
+    # pylint: disable=cell-var-from-loop
+    def process_query(seg_q):
+      for token in query_tokens:
+        tokens.append(token)
+        segment_ids.append(seg_q)
+        paragraph_mask.append(0)
+      tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
+      segment_ids.append(seg_q)
+      paragraph_mask.append(0)
+
+    def process_paragraph(seg_p):
+      for i in range(doc_span.length):
+        split_token_index = doc_span.start + i
@@ -390,11 +404,31 @@ def convert_examples_to_features(examples,
           split_token_index)
-      token_is_max_context[len(tokens)] = is_max_context
-      tokens.append(all_doc_tokens[split_token_index])
-      segment_ids.append(1)
-    tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
-    segment_ids.append(1)
-    paragraph_len = len(tokens)
+        token_is_max_context[len(tokens)] = is_max_context
+        tokens.append(all_doc_tokens[split_token_index])
+        segment_ids.append(seg_p)
+        paragraph_mask.append(1)
+      tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
+      segment_ids.append(seg_p)
+      paragraph_mask.append(0)
+      return len(tokens)
+
+    def process_class(seg_class):
+      class_index = len(segment_ids)
+      tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
+      segment_ids.append(seg_class)
+      paragraph_mask.append(1)
+      return class_index
+
+    if xlnet_format:
+      seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
+      paragraph_len = process_paragraph(seg_p)
+      process_query(seg_q)
+      class_index = process_class(seg_class)
+    else:
+      seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
+      class_index = process_class(seg_class)
+      process_query(seg_q)
+      paragraph_len = process_paragraph(seg_p)
+
     input_ids = tokens
     # The mask has 1 for real tokens and 0 for padding tokens. Only real
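To make the two branches concrete, here is a self-contained toy reconstruction of the packing logic above (strings instead of SentencePiece ids; seg_pad and the paragraph_len bookkeeping are omitted). The printed layouts show where segment_ids, paragraph_mask, and class_index land:

    # Toy reconstruction; illustrative only, not part of the commit.
    query_tokens = ["q1", "q2"]
    doc_tokens = ["p1", "p2", "p3"]

    def pack(xlnet_format):
      tokens, segment_ids, paragraph_mask = [], [], []

      def process_query(seg_q):
        for token in query_tokens:
          tokens.append(token)
          segment_ids.append(seg_q)
          paragraph_mask.append(0)
        tokens.append("[SEP]")
        segment_ids.append(seg_q)
        paragraph_mask.append(0)

      def process_paragraph(seg_p):
        for token in doc_tokens:
          tokens.append(token)
          segment_ids.append(seg_p)
          paragraph_mask.append(1)
        tokens.append("[SEP]")
        segment_ids.append(seg_p)
        paragraph_mask.append(0)

      def process_class(seg_class):
        class_index = len(segment_ids)
        tokens.append("[CLS]")
        segment_ids.append(seg_class)
        paragraph_mask.append(1)
        return class_index

      if xlnet_format:
        seg_p, seg_q, seg_class = 0, 1, 2
        process_paragraph(seg_p)
        process_query(seg_q)
        class_index = process_class(seg_class)
      else:
        seg_p, seg_q, seg_class = 1, 0, 0
        class_index = process_class(seg_class)
        process_query(seg_q)
        process_paragraph(seg_p)
      return tokens, segment_ids, paragraph_mask, class_index

    print(pack(xlnet_format=False))
    # (['[CLS]', 'q1', 'q2', '[SEP]', 'p1', 'p2', 'p3', '[SEP]'],
    #  [0, 0, 0, 0, 1, 1, 1, 1], [1, 0, 0, 0, 1, 1, 1, 0], 0)
    print(pack(xlnet_format=True))
    # (['p1', 'p2', 'p3', '[SEP]', 'q1', 'q2', '[SEP]', '[CLS]'],
    #  [0, 0, 0, 0, 1, 1, 1, 2], [1, 1, 1, 0, 0, 0, 0, 1], 7)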
@@ -405,11 +439,13 @@ def convert_examples_to_features(examples,
     while len(input_ids) < max_seq_length:
       input_ids.append(0)
       input_mask.append(0)
-      segment_ids.append(0)
+      segment_ids.append(seg_pad)
+      paragraph_mask.append(0)
 
     assert len(input_ids) == max_seq_length
     assert len(input_mask) == max_seq_length
     assert len(segment_ids) == max_seq_length
+    assert len(paragraph_mask) == max_seq_length
 
     span_is_impossible = example.is_impossible
     start_position = None
@@ -429,13 +465,13 @@ def convert_examples_to_features(examples,
         end_position = 0
         span_is_impossible = True
       else:
-        doc_offset = len(query_tokens) + 2
+        doc_offset = 0 if xlnet_format else len(query_tokens) + 2
         start_position = tok_start_position - doc_start + doc_offset
         end_position = tok_end_position - doc_start + doc_offset
 
     if is_training and span_is_impossible:
-      start_position = 0
-      end_position = 0
+      start_position = class_index
+      end_position = class_index
 
     if example_index < 20:
       logging.info("*** Example ***")
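A quick worked check of the two offset changes (illustrative numbers, not from the commit): with a 4-token query, BERT-style packing puts the first paragraph token after [CLS], the query, and [SEP], i.e. at len(query_tokens) + 2 = 6, while XLNet-style packing opens with the paragraph, so doc_offset is 0. An unanswerable span now targets class_index rather than hard-coded 0, since [CLS] is no longer pinned to position 0:

    query_len = 4
    tok_start_position, doc_start = 13, 10  # answer begins 3 tokens into span

    bert_offset = query_len + 2             # [CLS] + 4 query tokens + [SEP]
    xlnet_offset = 0                        # paragraph opens the sequence
    assert tok_start_position - doc_start + bert_offset == 9
    assert tok_start_position - doc_start + xlnet_offset == 3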
@@ -455,6 +491,9 @@ def convert_examples_to_features(examples,
       logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
       logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
       logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+      logging.info("paragraph_mask: %s", " ".join(
+          [str(x) for x in paragraph_mask]))
+      logging.info("class_index: %d", class_index)
       if is_training and span_is_impossible:
         logging.info("impossible example span")
@@ -488,8 +527,10 @@ def convert_examples_to_features(examples,
         tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
         input_ids=input_ids,
         input_mask=input_mask,
+        paragraph_mask=paragraph_mask,
         segment_ids=segment_ids,
         paragraph_len=paragraph_len,
+        class_index=class_index,
         start_position=start_position,
         end_position=end_position,
         is_impossible=span_is_impossible)
@@ -609,6 +650,11 @@ def postprocess_output(all_examples,
   del do_lower_case, verbose
+  # XLNet emits further predictions for start, end indexes and impossibility
+  # classifications.
+  xlnet_format = (hasattr(all_results[0], "start_indexes") and
+                  all_results[0].start_indexes is not None)
+
   example_index_to_features = collections.defaultdict(list)
   for feature in all_features:
     example_index_to_features[feature.example_index].append(feature)
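Postprocessing thus sniffs the result objects instead of taking a flag. A hypothetical container with the attributes the XLNet path reads (the real result type is defined elsewhere and not shown in this diff):

    import collections

    # Hypothetical; field names mirror the attributes referenced in the diff.
    XLNetRawResult = collections.namedtuple(
        "XLNetRawResult",
        ["unique_id", "start_indexes", "start_logits",
         "end_indexes", "end_logits", "class_logits"])
    # A BERT-style result lacking start_indexes (or carrying None) keeps the
    # legacy postprocessing path.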
@@ -636,19 +682,32 @@ def postprocess_output(all_examples,
       null_end_logit = 0  # the end logit at the slice with min null score
     for (feature_index, feature) in enumerate(features):
       result = unique_id_to_result[feature.unique_id]
-      start_indexes = _get_best_indexes(result.start_logits, n_best_size)
-      end_indexes = _get_best_indexes(result.end_logits, n_best_size)
       # if we could have irrelevant answers, get the min score of irrelevant
       if version_2_with_negative:
-        feature_null_score = result.start_logits[0] + result.end_logits[0]
+        if xlnet_format:
+          feature_null_score = result.class_logits
+        else:
+          feature_null_score = result.start_logits[0] + result.end_logits[0]
         if feature_null_score < score_null:
           score_null = feature_null_score
           min_null_feature_index = feature_index
           null_start_logit = result.start_logits[0]
           null_end_logit = result.end_logits[0]
-      for start_index in start_indexes:
-        for end_index in end_indexes:
-          doc_offset = feature.tokens.index("[SEP]") + 1
+      start_indexes_and_logits = _get_best_indexes_and_logits(
+          result=result,
+          n_best_size=n_best_size,
+          start=True,
+          xlnet_format=xlnet_format)
+      end_indexes_and_logits = _get_best_indexes_and_logits(
+          result=result,
+          n_best_size=n_best_size,
+          start=False,
+          xlnet_format=xlnet_format)
+      doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1
+      for start_index, start_logit in start_indexes_and_logits:
+        for end_index, end_logit in end_indexes_and_logits:
           # We could hypothetically create invalid predictions, e.g., predict
           # that the start of the span is in the question. We throw out all
           # invalid predictions.
@@ -656,10 +715,6 @@ def postprocess_output(all_examples,
             continue
           if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
             continue
-          # if start_index not in feature.tok_start_to_orig_index:
-          #   continue
-          # if end_index not in feature.tok_end_to_orig_index:
-          #   continue
           if not feature.token_is_max_context.get(start_index, False):
             continue
           if end_index < start_index:
@@ -672,10 +727,10 @@ def postprocess_output(all_examples,
                   feature_index=feature_index,
                   start_index=start_index - doc_offset,
                   end_index=end_index - doc_offset,
-                  start_logit=result.start_logits[start_index],
-                  end_logit=result.end_logits[end_index]))
+                  start_logit=start_logit,
+                  end_logit=end_logit))
-    if version_2_with_negative:
+    if version_2_with_negative and not xlnet_format:
       prelim_predictions.append(
           _PrelimPrediction(
               feature_index=min_null_feature_index,
               ...

@@ -720,7 +775,7 @@ def postprocess_output(all_examples,
               end_logit=pred.end_logit))
     # if we didn't inlude the empty option in the n-best, inlcude it
-    if version_2_with_negative:
+    if version_2_with_negative and not xlnet_format:
       if "" not in seen_predictions:
         nbest.append(
             _NbestPrediction(
@@ -778,16 +833,30 @@ def write_to_json_files(json_records, json_file):
     writer.write(json.dumps(json_records, indent=4) + "\n")
 
 
-def _get_best_indexes(logits, n_best_size):
-  """Get the n-best logits from a list."""
-  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
-  best_indexes = []
-  for i in range(len(index_and_score)):
-    if i >= n_best_size:
-      break
-    best_indexes.append(index_and_score[i][0])
-  return best_indexes
+def _get_best_indexes_and_logits(result,
+                                 n_best_size,
+                                 start=False,
+                                 xlnet_format=False):
+  """Generates the n-best indexes and logits from a list."""
+  if xlnet_format:
+    for i in range(n_best_size):
+      for j in range(n_best_size):
+        j_index = i * n_best_size + j
+        if start:
+          yield result.start_indexes[i], result.start_logits[i]
+        else:
+          yield result.end_indexes[j_index], result.end_logits[j_index]
+  else:
+    if start:
+      logits = result.start_logits
+    else:
+      logits = result.end_logits
+    index_and_score = sorted(
+        enumerate(logits), key=lambda x: x[1], reverse=True)
+    for i in range(len(index_and_score)):
+      if i >= n_best_size:
+        break
+      yield index_and_score[i]
 
 
 def _compute_softmax(scores):
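A self-contained sketch of the helper's two modes, using a stand-in result object (the function is private, so this is for illustration rather than API guidance):

    import collections
    from official.nlp.data import squad_lib_sp

    Result = collections.namedtuple(
        "Result",
        ["start_indexes", "start_logits", "end_indexes", "end_logits"])

    # Legacy mode: rank positions by logit, yield the top (index, logit) pairs.
    r = Result(start_indexes=None, start_logits=[0.1, 2.0, 1.5],
               end_indexes=None, end_logits=[0.3, 0.2, 4.0])
    print(list(squad_lib_sp._get_best_indexes_and_logits(
        result=r, n_best_size=2, start=True)))
    # [(1, 2.0), (2, 1.5)]

    # XLNet mode: the model already emitted beam indexes; for end positions,
    # the helper walks the flattened n_best_size x n_best_size hypothesis grid.
    r = Result(start_indexes=[7, 3], start_logits=[5.0, 4.0],
               end_indexes=[8, 9, 4, 5], end_logits=[3.0, 2.5, 2.0, 1.5])
    print(list(squad_lib_sp._get_best_indexes_and_logits(
        result=r, n_best_size=2, start=False, xlnet_format=True)))
    # [(8, 3.0), (9, 2.5), (4, 2.0), (5, 1.5)]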
@@ -816,12 +885,13 @@ def _compute_softmax(scores):
 class FeatureWriter(object):
   """Writes InputFeature to TF example file."""
 
-  def __init__(self, filename, is_training):
+  def __init__(self, filename, is_training, xlnet_format=False):
     self.filename = filename
     self.is_training = is_training
     self.num_features = 0
     tf.io.gfile.makedirs(os.path.dirname(filename))
     self._writer = tf.io.TFRecordWriter(filename)
+    self._xlnet_format = xlnet_format
 
   def process_feature(self, feature):
     """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
@@ -837,6 +907,9 @@ class FeatureWriter(object):
     features["input_ids"] = create_int_feature(feature.input_ids)
     features["input_mask"] = create_int_feature(feature.input_mask)
     features["segment_ids"] = create_int_feature(feature.segment_ids)
+    if self._xlnet_format:
+      features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
+      features["class_index"] = create_int_feature([feature.class_index])
     if self.is_training:
       features["start_positions"] = create_int_feature([feature.start_position])
@@ -860,6 +933,7 @@ def generate_tf_record_from_json_file(input_file_path,
                                       do_lower_case=True,
                                       max_query_length=64,
                                       doc_stride=128,
+                                      xlnet_format=False,
                                       version_2_with_negative=False):
   """Generates and saves training data into a tf record file."""
   train_examples = read_squad_examples(

@@ -868,7 +942,8 @@ def generate_tf_record_from_json_file(input_file_path,
       version_2_with_negative=version_2_with_negative)
   tokenizer = tokenization.FullSentencePieceTokenizer(
       sp_model_file=sp_model_file)
-  train_writer = FeatureWriter(filename=output_path, is_training=True)
+  train_writer = FeatureWriter(
+      filename=output_path, is_training=True, xlnet_format=xlnet_format)
   number_of_examples = convert_examples_to_features(
       examples=train_examples,
       tokenizer=tokenizer,

@@ -877,6 +952,7 @@ def generate_tf_record_from_json_file(input_file_path,
       max_query_length=max_query_length,
       is_training=True,
       output_fn=train_writer.process_feature,
+      xlnet_format=xlnet_format,
       do_lower_case=do_lower_case)
   train_writer.close()
   ...
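Putting the pieces together, a hedged end-to-end sketch of writing XLNet-style training records with the updated entry point; the paths are placeholders, and the keyword surface is exactly what the hunks above add:

    from official.nlp.data import squad_lib_sp

    # Return value ignored here; this diff does not show what the function
    # returns.
    squad_lib_sp.generate_tf_record_from_json_file(
        input_file_path="/tmp/train-v2.0.json",    # placeholder path
        sp_model_file="/tmp/spiece.model",         # placeholder path
        output_path="/tmp/squad_train.tf_record",  # placeholder path
        max_seq_length=384,
        do_lower_case=False,
        max_query_length=64,
        doc_stride=128,
        xlnet_format=True,
        version_2_with_negative=True)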