"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "992b8922fce587472b2edc49944aeff5b64dc2b1"
Commit ef558ee5 authored by Allen Wang, committed by A. Unique TensorFlower

Add XLNet-style SQuAD preprocessing to TF-NLP.

PiperOrigin-RevId: 340897846
parent 51e2004c
@@ -100,6 +100,11 @@ flags.DEFINE_bool(
     "version_2_with_negative", False,
     "If true, the SQuAD examples contain some that do not have an answer.")
 
+flags.DEFINE_bool(
+    "xlnet_format", False,
+    "If true, then data will be preprocessed in a paragraph, query, class order"
+    " instead of the BERT-style class, paragraph, query order.")
+
 # Shared flags across BERT fine-tuning tasks.
 flags.DEFINE_string("vocab_file", None,
                     "The vocabulary file that the BERT model was trained on.")
@@ -263,9 +268,15 @@ def generate_squad_dataset():
   else:
     assert FLAGS.tokenization == "SentencePiece"
     return squad_lib_sp.generate_tf_record_from_json_file(
-        FLAGS.squad_data_file, FLAGS.sp_model_file,
-        FLAGS.train_data_output_path, FLAGS.max_seq_length, FLAGS.do_lower_case,
-        FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)
+        input_file_path=FLAGS.squad_data_file,
+        sp_model_file=FLAGS.sp_model_file,
+        output_path=FLAGS.train_data_output_path,
+        max_seq_length=FLAGS.max_seq_length,
+        do_lower_case=FLAGS.do_lower_case,
+        max_query_length=FLAGS.max_query_length,
+        doc_stride=FLAGS.doc_stride,
+        xlnet_format=FLAGS.xlnet_format,
+        version_2_with_negative=FLAGS.version_2_with_negative)
 
 
 def generate_retrieval_dataset():
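Usage note: with the call above now fully keyword-based, the SentencePiece pipeline can also be driven directly. A minimal sketch, assuming the module lives at official.nlp.data.squad_lib_sp and using placeholder paths (the return value is not shown in this diff):

```python
from official.nlp.data import squad_lib_sp  # module path assumed

# Placeholder paths; substitute real SQuAD and SentencePiece files.
result = squad_lib_sp.generate_tf_record_from_json_file(
    input_file_path="/tmp/train-v2.0.json",
    sp_model_file="/tmp/spiece.model",
    output_path="/tmp/squad_train.tf_record",
    max_seq_length=384,
    do_lower_case=True,
    max_query_length=64,
    doc_stride=128,
    xlnet_format=True,  # new flag added in this commit
    version_2_with_negative=True)
```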
@@ -42,6 +42,7 @@ class QADataConfig(cfg.DataConfig):
   vocab_file: str = ''
   tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
   do_lower_case: bool = True
+  xlnet_format: bool = False
 
 
 @data_loader_factory.register_data_loader_cls(QADataConfig)
@@ -52,6 +53,7 @@ class QuestionAnsweringDataLoader(data_loader.DataLoader):
     self._params = params
     self._seq_length = params.seq_length
     self._is_training = params.is_training
+    self._xlnet_format = params.xlnet_format
 
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
@@ -60,6 +62,13 @@ class QuestionAnsweringDataLoader(data_loader.DataLoader):
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
     }
+    if self._xlnet_format:
+      name_to_features['class_index'] = tf.io.FixedLenFeature([], tf.int64)
+      name_to_features['paragraph_mask'] = tf.io.FixedLenFeature(
+          [self._seq_length], tf.int64)
+      if self._is_training:
+        name_to_features['is_impossible'] = tf.io.FixedLenFeature([], tf.int64)
+
     if self._is_training:
       name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
      name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
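For reference, when both self._xlnet_format and self._is_training are set, the branches above assemble a feature spec equivalent to the following sketch (input_ids comes from lines above the visible hunk; seq_length is illustrative):

```python
import tensorflow as tf

seq_length = 384  # illustrative value
name_to_features = {
    'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
    'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
    'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
    # XLNet-only features added by this commit.
    'class_index': tf.io.FixedLenFeature([], tf.int64),
    'paragraph_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
    'is_impossible': tf.io.FixedLenFeature([], tf.int64),
    # Training-only labels.
    'start_positions': tf.io.FixedLenFeature([], tf.int64),
    'end_positions': tf.io.FixedLenFeature([], tf.int64),
}
```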
@@ -81,7 +90,7 @@ class QuestionAnsweringDataLoader(data_loader.DataLoader):
     """Parses raw tensors into a dict of tensors to be consumed by the model."""
     x, y = {}, {}
     for name, tensor in record.items():
-      if name in ('start_positions', 'end_positions'):
+      if name in ('start_positions', 'end_positions', 'is_impossible'):
         y[name] = tensor
       elif name == 'input_ids':
         x['input_word_ids'] = tensor
@@ -89,6 +98,8 @@ class QuestionAnsweringDataLoader(data_loader.DataLoader):
         x['input_type_ids'] = tensor
       else:
         x[name] = tensor
+      if name == 'start_positions' and self._xlnet_format:
+        x[name] = tensor
 
     return (x, y)
 
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
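With those changes, an XLNet-format training record is split roughly as sketched below. Note that start_positions is copied into the model inputs as well, presumably because XLNet-style answer decoding conditions the end-position prediction on the start position (an assumption, not stated in the diff):

```python
# Keys produced by _parse for an xlnet_format training example (sketch).
model_inputs = ['input_word_ids', 'input_mask', 'input_type_ids',
                'paragraph_mask', 'class_index', 'start_positions']
labels = ['start_positions', 'end_positions', 'is_impossible']
```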
@@ -86,6 +86,8 @@ class InputFeatures(object):
                input_mask,
                segment_ids,
                paragraph_len,
+               class_index=None,
+               paragraph_mask=None,
                start_position=None,
                end_position=None,
                is_impossible=None):
@@ -98,8 +100,10 @@ class InputFeatures(object):
     self.tokens = tokens
     self.input_ids = input_ids
     self.input_mask = input_mask
+    self.paragraph_mask = paragraph_mask
     self.segment_ids = segment_ids
     self.paragraph_len = paragraph_len
+    self.class_index = class_index
     self.start_position = start_position
     self.end_position = end_position
     self.is_impossible = is_impossible
@@ -194,6 +198,7 @@ def convert_examples_to_features(examples,
                                  is_training,
                                  output_fn,
                                  do_lower_case,
+                                 xlnet_format=False,
                                  batch_size=None):
   """Loads a data file into a list of `InputBatch`s."""
   cnt_pos, cnt_neg = 0, 0
@@ -353,6 +358,7 @@ def convert_examples_to_features(examples,
         "DocSpan", ["start", "length"])
     doc_spans = []
     start_offset = 0
+
     while start_offset < len(all_doc_tokens):
       length = len(all_doc_tokens) - start_offset
       if length > max_tokens_for_doc:
@@ -367,34 +373,62 @@ def convert_examples_to_features(examples,
       token_is_max_context = {}
       segment_ids = []
+
+      # Paragraph mask used in XLNet.
+      # 1 represents paragraph and class tokens.
+      # 0 represents query and other special tokens.
+      paragraph_mask = []
+
       cur_tok_start_to_orig_index = []
       cur_tok_end_to_orig_index = []
-      tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
-      segment_ids.append(0)
-      for token in query_tokens:
-        tokens.append(token)
-        segment_ids.append(0)
-      tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
-      segment_ids.append(0)
-
-      for i in range(doc_span.length):
-        split_token_index = doc_span.start + i
-
-        cur_tok_start_to_orig_index.append(
-            tok_start_to_orig_index[split_token_index])
-        cur_tok_end_to_orig_index.append(
-            tok_end_to_orig_index[split_token_index])
-
-        is_max_context = _check_is_max_context(doc_spans, doc_span_index,
-                                               split_token_index)
-        token_is_max_context[len(tokens)] = is_max_context
-        tokens.append(all_doc_tokens[split_token_index])
-        segment_ids.append(1)
-      tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
-      segment_ids.append(1)
-
-      paragraph_len = len(tokens)
+
+      # pylint: disable=cell-var-from-loop
+      def process_query(seg_q):
+        for token in query_tokens:
+          tokens.append(token)
+          segment_ids.append(seg_q)
+          paragraph_mask.append(0)
+        tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
+        segment_ids.append(seg_q)
+        paragraph_mask.append(0)
+
+      def process_paragraph(seg_p):
+        for i in range(doc_span.length):
+          split_token_index = doc_span.start + i
+
+          cur_tok_start_to_orig_index.append(
+              tok_start_to_orig_index[split_token_index])
+          cur_tok_end_to_orig_index.append(
+              tok_end_to_orig_index[split_token_index])
+
+          is_max_context = _check_is_max_context(doc_spans, doc_span_index,
+                                                 split_token_index)
+          token_is_max_context[len(tokens)] = is_max_context
+          tokens.append(all_doc_tokens[split_token_index])
+          segment_ids.append(seg_p)
+          paragraph_mask.append(1)
+        tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
+        segment_ids.append(seg_p)
+        paragraph_mask.append(0)
+        return len(tokens)
+
+      def process_class(seg_class):
+        class_index = len(segment_ids)
+        tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
+        segment_ids.append(seg_class)
+        paragraph_mask.append(1)
+        return class_index
+
+      if xlnet_format:
+        seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
+        paragraph_len = process_paragraph(seg_p)
+        process_query(seg_q)
+        class_index = process_class(seg_class)
+      else:
+        seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
+        class_index = process_class(seg_class)
+        process_query(seg_q)
+        paragraph_len = process_paragraph(seg_p)
 
       input_ids = tokens
 
       # The mask has 1 for real tokens and 0 for padding tokens. Only real
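To make the two orderings concrete, here is a toy sketch of what the helpers above produce (token strings stand in for SentencePiece ids; real features are then padded to max_seq_length with seg_pad, which is 3 in XLNet format and 0 otherwise):

```python
# Toy inputs: query = ["what", "color"], paragraph = ["the", "sky", "is", "blue"].

# BERT-style order (xlnet_format=False): [CLS] query [SEP] paragraph [SEP]
bert_tokens = ["[CLS]", "what", "color", "[SEP]", "the", "sky", "is", "blue", "[SEP]"]
bert_segment_ids = [0, 0, 0, 0, 1, 1, 1, 1, 1]
bert_paragraph_mask = [1, 0, 0, 0, 1, 1, 1, 1, 0]
bert_class_index = 0  # [CLS] sits at the front

# XLNet-style order (xlnet_format=True): paragraph [SEP] query [SEP] [CLS]
xlnet_tokens = ["the", "sky", "is", "blue", "[SEP]", "what", "color", "[SEP]", "[CLS]"]
xlnet_segment_ids = [0, 0, 0, 0, 0, 1, 1, 1, 2]
xlnet_paragraph_mask = [1, 1, 1, 1, 0, 0, 0, 0, 1]
xlnet_class_index = 8  # [CLS] sits at the end
```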
@@ -405,11 +439,13 @@ def convert_examples_to_features(examples,
       while len(input_ids) < max_seq_length:
         input_ids.append(0)
         input_mask.append(0)
-        segment_ids.append(0)
+        segment_ids.append(seg_pad)
+        paragraph_mask.append(0)
 
       assert len(input_ids) == max_seq_length
       assert len(input_mask) == max_seq_length
       assert len(segment_ids) == max_seq_length
+      assert len(paragraph_mask) == max_seq_length
 
       span_is_impossible = example.is_impossible
       start_position = None
@@ -429,13 +465,13 @@ def convert_examples_to_features(examples,
           end_position = 0
           span_is_impossible = True
         else:
-          doc_offset = len(query_tokens) + 2
+          doc_offset = 0 if xlnet_format else len(query_tokens) + 2
           start_position = tok_start_position - doc_start + doc_offset
           end_position = tok_end_position - doc_start + doc_offset
 
       if is_training and span_is_impossible:
-        start_position = 0
-        end_position = 0
+        start_position = class_index
+        end_position = class_index
 
       if example_index < 20:
         logging.info("*** Example ***")
@@ -455,6 +491,9 @@ def convert_examples_to_features(examples,
         logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
         logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
         logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+        logging.info("paragraph_mask: %s", " ".join(
+            [str(x) for x in paragraph_mask]))
+        logging.info("class_index: %d", class_index)
 
         if is_training and span_is_impossible:
           logging.info("impossible example span")
@@ -488,8 +527,10 @@ def convert_examples_to_features(examples,
           tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
           input_ids=input_ids,
           input_mask=input_mask,
+          paragraph_mask=paragraph_mask,
           segment_ids=segment_ids,
           paragraph_len=paragraph_len,
+          class_index=class_index,
           start_position=start_position,
           end_position=end_position,
           is_impossible=span_is_impossible)
@@ -609,6 +650,11 @@ def postprocess_output(all_examples,
   del do_lower_case, verbose
 
+  # XLNet emits further predictions for start, end indexes and impossibility
+  # classifications.
+  xlnet_format = (hasattr(all_results[0], "start_indexes")
+                  and all_results[0].start_indexes is not None)
+
   example_index_to_features = collections.defaultdict(list)
   for feature in all_features:
     example_index_to_features[feature.example_index].append(feature)
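The XLNet path is detected by duck-typing the per-feature results. A hypothetical sketch of the fields such a result is assumed to carry, inferred from their uses later in this function (the container type and exact array lengths are assumptions, not shown in this diff):

```python
import collections

# Hypothetical container; the real prediction loop supplies an equivalent object.
XLNetRawResult = collections.namedtuple(
    "XLNetRawResult",
    ["unique_id",
     "start_indexes", "start_logits",  # top candidate start positions / scores
     "end_indexes", "end_logits",      # candidate end positions / scores per start
     "class_logits"])                  # answerability score used as the null score
```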
@@ -636,19 +682,32 @@ def postprocess_output(all_examples,
     null_end_logit = 0  # the end logit at the slice with min null score
     for (feature_index, feature) in enumerate(features):
       result = unique_id_to_result[feature.unique_id]
-      start_indexes = _get_best_indexes(result.start_logits, n_best_size)
-      end_indexes = _get_best_indexes(result.end_logits, n_best_size)
       # if we could have irrelevant answers, get the min score of irrelevant
       if version_2_with_negative:
-        feature_null_score = result.start_logits[0] + result.end_logits[0]
+        if xlnet_format:
+          feature_null_score = result.class_logits
+        else:
+          feature_null_score = result.start_logits[0] + result.end_logits[0]
         if feature_null_score < score_null:
           score_null = feature_null_score
           min_null_feature_index = feature_index
           null_start_logit = result.start_logits[0]
           null_end_logit = result.end_logits[0]
-      for start_index in start_indexes:
-        for end_index in end_indexes:
-          doc_offset = feature.tokens.index("[SEP]") + 1
+
+      start_indexes_and_logits = _get_best_indexes_and_logits(
+          result=result,
+          n_best_size=n_best_size,
+          start=True,
+          xlnet_format=xlnet_format)
+      end_indexes_and_logits = _get_best_indexes_and_logits(
+          result=result,
+          n_best_size=n_best_size,
+          start=False,
+          xlnet_format=xlnet_format)
+      doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1
+      for start_index, start_logit in start_indexes_and_logits:
+        for end_index, end_logit in end_indexes_and_logits:
           # We could hypothetically create invalid predictions, e.g., predict
           # that the start of the span is in the question. We throw out all
           # invalid predictions.
@@ -656,10 +715,6 @@ def postprocess_output(all_examples,
             continue
           if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
             continue
-          # if start_index not in feature.tok_start_to_orig_index:
-          #   continue
-          # if end_index not in feature.tok_end_to_orig_index:
-          #   continue
           if not feature.token_is_max_context.get(start_index, False):
             continue
           if end_index < start_index:
@@ -672,10 +727,10 @@ def postprocess_output(all_examples,
                   feature_index=feature_index,
                   start_index=start_index - doc_offset,
                   end_index=end_index - doc_offset,
-                  start_logit=result.start_logits[start_index],
-                  end_logit=result.end_logits[end_index]))
+                  start_logit=start_logit,
+                  end_logit=end_logit))
 
-    if version_2_with_negative:
+    if version_2_with_negative and not xlnet_format:
       prelim_predictions.append(
           _PrelimPrediction(
               feature_index=min_null_feature_index,
@@ -720,7 +775,7 @@ def postprocess_output(all_examples,
               end_logit=pred.end_logit))
 
     # if we didn't inlude the empty option in the n-best, inlcude it
-    if version_2_with_negative:
+    if version_2_with_negative and not xlnet_format:
       if "" not in seen_predictions:
         nbest.append(
             _NbestPrediction(
@@ -778,16 +833,30 @@ def write_to_json_files(json_records, json_file):
     writer.write(json.dumps(json_records, indent=4) + "\n")
 
 
-def _get_best_indexes(logits, n_best_size):
-  """Get the n-best logits from a list."""
-  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
-
-  best_indexes = []
-  for i in range(len(index_and_score)):
-    if i >= n_best_size:
-      break
-    best_indexes.append(index_and_score[i][0])
-  return best_indexes
+def _get_best_indexes_and_logits(result,
+                                 n_best_size,
+                                 start=False,
+                                 xlnet_format=False):
+  """Generates the n-best indexes and logits from a list."""
+  if xlnet_format:
+    for i in range(n_best_size):
+      for j in range(n_best_size):
+        j_index = i * n_best_size + j
+        if start:
+          yield result.start_indexes[i], result.start_logits[i]
+        else:
+          yield result.end_indexes[j_index], result.end_logits[j_index]
+  else:
+    if start:
+      logits = result.start_logits
+    else:
+      logits = result.end_logits
+    index_and_score = sorted(enumerate(logits),
+                             key=lambda x: x[1], reverse=True)
+    for i in range(len(index_and_score)):
+      if i >= n_best_size:
+        break
+      yield index_and_score[i]
 
 
 def _compute_softmax(scores):
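A small sketch of the index arithmetic in the XLNet branch above, assuming the end arrays are flattened row-major with n_best_size end candidates per start candidate (consistent with the j_index computation):

```python
n_best_size = 3
pairs = [(i, i * n_best_size + j)
         for i in range(n_best_size)
         for j in range(n_best_size)]
# pairs == [(0, 0), (0, 1), (0, 2),
#           (1, 3), (1, 4), (1, 5),
#           (2, 6), (2, 7), (2, 8)]
# i.e. start candidate i is paired with its own block of n_best_size end slots.
```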
@@ -816,12 +885,13 @@ def _compute_softmax(scores):
 class FeatureWriter(object):
   """Writes InputFeature to TF example file."""
 
-  def __init__(self, filename, is_training):
+  def __init__(self, filename, is_training, xlnet_format=False):
     self.filename = filename
     self.is_training = is_training
     self.num_features = 0
     tf.io.gfile.makedirs(os.path.dirname(filename))
     self._writer = tf.io.TFRecordWriter(filename)
+    self._xlnet_format = xlnet_format
 
   def process_feature(self, feature):
     """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
@@ -837,6 +907,9 @@ class FeatureWriter(object):
     features["input_ids"] = create_int_feature(feature.input_ids)
     features["input_mask"] = create_int_feature(feature.input_mask)
     features["segment_ids"] = create_int_feature(feature.segment_ids)
+    if self._xlnet_format:
+      features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
+      features["class_index"] = create_int_feature([feature.class_index])
 
     if self.is_training:
       features["start_positions"] = create_int_feature([feature.start_position])
@@ -860,6 +933,7 @@ def generate_tf_record_from_json_file(input_file_path,
                                       do_lower_case=True,
                                       max_query_length=64,
                                       doc_stride=128,
+                                      xlnet_format=False,
                                       version_2_with_negative=False):
   """Generates and saves training data into a tf record file."""
   train_examples = read_squad_examples(
@@ -868,7 +942,8 @@ def generate_tf_record_from_json_file(input_file_path,
       version_2_with_negative=version_2_with_negative)
   tokenizer = tokenization.FullSentencePieceTokenizer(
       sp_model_file=sp_model_file)
-  train_writer = FeatureWriter(filename=output_path, is_training=True)
+  train_writer = FeatureWriter(
+      filename=output_path, is_training=True, xlnet_format=xlnet_format)
   number_of_examples = convert_examples_to_features(
       examples=train_examples,
       tokenizer=tokenizer,
@@ -877,6 +952,7 @@ def generate_tf_record_from_json_file(input_file_path,
       max_query_length=max_query_length,
       is_training=True,
       output_fn=train_writer.process_feature,
+      xlnet_format=xlnet_format,
       do_lower_case=do_lower_case)
   train_writer.close()