Commit 31e4a64d authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 342129726
parent 82f46dc7
......@@ -262,9 +262,15 @@ def generate_squad_dataset():
assert FLAGS.squad_data_file
if FLAGS.tokenization == "WordPiece":
return squad_lib_wp.generate_tf_record_from_json_file(
FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
FLAGS.doc_stride, FLAGS.version_2_with_negative)
input_file_path=FLAGS.squad_data_file,
vocab_file_path=FLAGS.vocab_file,
output_path=FLAGS.train_data_output_path,
max_seq_length=FLAGS.max_seq_length,
do_lower_case=FLAGS.do_lower_case,
max_query_length=FLAGS.max_query_length,
doc_stride=FLAGS.doc_stride,
version_2_with_negative=FLAGS.version_2_with_negative,
xlnet_format=FLAGS.xlnet_format)
else:
assert FLAGS.tokenization == "SentencePiece"
return squad_lib_sp.generate_tf_record_from_json_file(
......
......@@ -92,6 +92,8 @@ class InputFeatures(object):
input_ids,
input_mask,
segment_ids,
paragraph_mask=None,
class_index=None,
start_position=None,
end_position=None,
is_impossible=None):
......@@ -107,6 +109,8 @@ class InputFeatures(object):
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
self.paragraph_mask = paragraph_mask
self.class_index = class_index
class FeatureWriter(object):
......@@ -134,6 +138,11 @@ class FeatureWriter(object):
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
if feature.paragraph_mask is not None:
features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
if feature.class_index is not None:
features["class_index"] = create_int_feature([feature.class_index])
if self.is_training:
features["start_positions"] = create_int_feature([feature.start_position])
features["end_positions"] = create_int_feature([feature.end_position])
......@@ -238,6 +247,7 @@ def convert_examples_to_features(examples,
max_query_length,
is_training,
output_fn,
xlnet_format=False,
batch_size=None):
"""Loads a data file into a list of `InputBatch`s."""
......@@ -299,25 +309,54 @@ def convert_examples_to_features(examples,
token_to_orig_map = {}
token_is_max_context = {}
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in query_tokens:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for i in range(doc_span.length):
split_token_index = doc_span.start + i
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
# Paragraph mask used in XLNet.
# 1 represents paragraph and class tokens.
# 0 represents query and other special tokens.
paragraph_mask = []
# pylint: disable=cell-var-from-loop
def process_query(seg_q):
for token in query_tokens:
tokens.append(token)
segment_ids.append(seg_q)
paragraph_mask.append(0)
tokens.append("[SEP]")
segment_ids.append(seg_q)
paragraph_mask.append(0)
def process_paragraph(seg_p):
for i in range(doc_span.length):
split_token_index = doc_span.start + i
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(seg_p)
paragraph_mask.append(1)
tokens.append("[SEP]")
segment_ids.append(seg_p)
paragraph_mask.append(0)
def process_class(seg_class):
class_index = len(segment_ids)
tokens.append("[CLS]")
segment_ids.append(seg_class)
paragraph_mask.append(1)
return class_index
if xlnet_format:
seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
process_paragraph(seg_p)
process_query(seg_q)
class_index = process_class(seg_class)
else:
seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
class_index = process_class(seg_class)
process_query(seg_q)
process_paragraph(seg_p)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
......@@ -329,11 +368,13 @@ def convert_examples_to_features(examples,
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
segment_ids.append(seg_pad)
paragraph_mask.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(paragraph_mask) == max_seq_length
start_position = None
end_position = None
......@@ -350,7 +391,7 @@ def convert_examples_to_features(examples,
start_position = 0
end_position = 0
else:
doc_offset = len(query_tokens) + 2
doc_offset = 0 if xlnet_format else len(query_tokens) + 2
start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset
......@@ -377,6 +418,9 @@ def convert_examples_to_features(examples,
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("paragraph_mask: %s", " ".join(
[str(x) for x in paragraph_mask]))
logging.info("class_index: %d", class_index)
if is_training and example.is_impossible:
logging.info("impossible example")
if is_training and not example.is_impossible:
......@@ -390,6 +434,8 @@ def convert_examples_to_features(examples,
example_index=example_index,
doc_span_index=doc_span_index,
tokens=tokens,
paragraph_mask=paragraph_mask,
class_index=class_index,
token_to_orig_map=token_to_orig_map,
token_is_max_context=token_is_max_context,
input_ids=input_ids,
......@@ -541,6 +587,7 @@ def postprocess_output(all_examples,
do_lower_case,
version_2_with_negative=False,
null_score_diff_threshold=0.0,
xlnet_format=False,
verbose=False):
"""Postprocess model output, to form prediction results."""
......@@ -570,45 +617,50 @@ def postprocess_output(all_examples,
null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features):
result = unique_id_to_result[feature.unique_id]
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant
if version_2_with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if xlnet_format:
feature_null_score = result.class_logits
else:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if feature_null_score < score_null:
score_null = feature_null_score
min_null_feature_index = feature_index
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indexes:
for end_index in end_indexes:
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index >= len(feature.tokens):
continue
if end_index >= len(feature.tokens):
continue
if start_index not in feature.token_to_orig_map:
continue
if end_index not in feature.token_to_orig_map:
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
if version_2_with_negative:
for (start_index, start_logit,
end_index, end_logit) in _get_best_indexes_and_logits(
result=result,
n_best_size=n_best_size,
xlnet_format=xlnet_format):
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index >= len(feature.tokens):
continue
if end_index >= len(feature.tokens):
continue
if start_index not in feature.token_to_orig_map:
continue
if end_index not in feature.token_to_orig_map:
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index,
end_index=end_index,
start_logit=start_logit,
end_logit=end_logit))
if version_2_with_negative and not xlnet_format:
prelim_predictions.append(
_PrelimPrediction(
feature_index=min_null_feature_index,
......@@ -630,7 +682,7 @@ def postprocess_output(all_examples,
if len(nbest) >= n_best_size:
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
if pred.start_index > 0 or xlnet_format: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
......@@ -663,7 +715,7 @@ def postprocess_output(all_examples,
end_logit=pred.end_logit))
# if we didn't include the empty option in the n-best, include it
if version_2_with_negative:
if version_2_with_negative and not xlnet_format:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
......@@ -704,13 +756,18 @@ def postprocess_output(all_examples,
# pytype: disable=attribute-error
# predict "" iff the null score - the score of best non-null > threshold
if best_non_null_entry is not None:
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
if xlnet_format:
score_diff = score_null
scores_diff_json[example.qas_id] = score_diff
all_predictions[example.qas_id] = best_non_null_entry.text
else:
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
else:
logging.warning("best_non_null_entry is None")
scores_diff_json[example.qas_id] = score_null
......@@ -822,16 +879,29 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
return output_text
def _get_best_indexes(logits, n_best_size):
"""Get the n-best logits from a list."""
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
best_indexes = []
for i in range(len(index_and_score)): # pylint: disable=consider-using-enumerate
if i >= n_best_size:
break
best_indexes.append(index_and_score[i][0])
return best_indexes
def _get_best_indexes_and_logits(result,
n_best_size,
xlnet_format=False):
"""Generates the n-best indexes and logits from a list."""
if xlnet_format:
for i in range(n_best_size):
for j in range(n_best_size):
j_index = i * n_best_size + j
yield (result.start_indexes[i], result.start_logits[i],
result.end_indexes[j_index], result.end_logits[j_index])
else:
start_index_and_score = sorted(enumerate(result.start_logits),
key=lambda x: x[1], reverse=True)
end_index_and_score = sorted(enumerate(result.end_logits),
key=lambda x: x[1], reverse=True)
for i in range(len(start_index_and_score)):
if i >= n_best_size:
break
for j in range(len(end_index_and_score)):
if j >= n_best_size:
break
yield (start_index_and_score[i][0], start_index_and_score[i][1],
end_index_and_score[j][0], end_index_and_score[j][1])
def _compute_softmax(scores):
......@@ -864,7 +934,8 @@ def generate_tf_record_from_json_file(input_file_path,
do_lower_case=True,
max_query_length=64,
doc_stride=128,
version_2_with_negative=False):
version_2_with_negative=False,
xlnet_format=False):
"""Generates and saves training data into a tf record file."""
train_examples = read_squad_examples(
input_file=input_file_path,
......@@ -880,7 +951,8 @@ def generate_tf_record_from_json_file(input_file_path,
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=True,
output_fn=train_writer.process_feature)
output_fn=train_writer.process_feature,
xlnet_format=xlnet_format)
train_writer.close()
meta_data = {
......
......@@ -645,16 +645,11 @@ def postprocess_output(all_examples,
do_lower_case,
version_2_with_negative=False,
null_score_diff_threshold=0.0,
xlnet_format=False,
verbose=False):
"""Postprocess model output, to form prediction results."""
del do_lower_case, verbose
# XLNet emits further predictions for start, end indexes and impossibility
# classifications.
xlnet_format = (hasattr(all_results[0], "start_indexes")
and all_results[0].start_indexes is not None)
example_index_to_features = collections.defaultdict(list)
for feature in all_features:
example_index_to_features[feature.example_index].append(feature)
......@@ -904,9 +899,9 @@ class FeatureWriter(object):
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
if feature.paragraph_mask:
if feature.paragraph_mask is not None:
features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
if feature.class_index:
if feature.class_index is not None:
features["class_index"] = create_int_feature([feature.class_index])
if self.is_training:
......
......@@ -150,6 +150,15 @@ class XLNetSpanLabeler(tf.keras.Model):
'span_labeling_activation': span_labeling_activation,
'initializer': initializer,
}
network_config = network.get_config()
try:
input_width = network_config['inner_size']
self._xlnet_base = True
except KeyError:
# BertEncoder uses 'intermediate_size' due to legacy naming.
input_width = network_config['intermediate_size']
self._xlnet_base = False
self._network = network
self._initializer = initializer
self._start_n_top = start_n_top
......@@ -157,7 +166,7 @@ class XLNetSpanLabeler(tf.keras.Model):
self._dropout_rate = dropout_rate
self._activation = span_labeling_activation
self.span_labeling = networks.XLNetSpanLabeling(
input_width=network.get_config()['inner_size'],
input_width=input_width,
start_n_top=self._start_n_top,
end_n_top=self._end_n_top,
activation=self._activation,
......@@ -165,17 +174,25 @@ class XLNetSpanLabeler(tf.keras.Model):
initializer=self._initializer)
def call(self, inputs: Mapping[str, Any]):
input_ids = inputs['input_word_ids']
segment_ids = inputs['input_type_ids']
input_word_ids = inputs['input_word_ids']
input_type_ids = inputs['input_type_ids']
input_mask = inputs['input_mask']
class_index = inputs['class_index']
paragraph_mask = inputs['paragraph_mask']
start_positions = inputs.get('start_positions', None)
attention_output, _ = self._network(
input_ids=input_ids,
segment_ids=segment_ids,
input_mask=input_mask)
if self._xlnet_base:
attention_output, _ = self._network(
input_ids=input_word_ids,
segment_ids=input_type_ids,
input_mask=input_mask)
else:
network_output_dict = self._network(dict(
input_word_ids=input_word_ids,
input_type_ids=input_type_ids,
input_mask=input_mask))
attention_output = network_output_dict['sequence_output']
outputs = self.span_labeling(
sequence_data=attention_output,
class_index=class_index,
......
......@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Question answering task."""
import functools
import json
import os
from typing import List, Optional
......@@ -143,6 +144,9 @@ class QuestionAnsweringTask(base_task.Task):
eval_features.append(feature)
eval_writer.process_feature(feature)
# XLNet preprocesses SQuAD examples in a P, Q, class order whereas
# BERT preprocesses in a class, Q, P order.
xlnet_ordering = self.task_config.model.encoder.type == 'xlnet'
kwargs = dict(
examples=eval_examples,
max_seq_length=params.seq_length,
......@@ -150,14 +154,14 @@ class QuestionAnsweringTask(base_task.Task):
max_query_length=params.query_length,
is_training=False,
output_fn=_append_feature,
batch_size=params.global_batch_size)
batch_size=params.global_batch_size,
xlnet_format=xlnet_ordering)
if params.tokenization == 'SentencePiece':
# squad_lib_sp requires one more argument 'do_lower_case'.
kwargs['do_lower_case'] = params.do_lower_case
kwargs['tokenizer'] = tokenization.FullSentencePieceTokenizer(
sp_model_file=params.vocab_file)
kwargs['xlnet_format'] = self.task_config.model.encoder.type == 'xlnet'
elif params.tokenization == 'WordPiece':
kwargs['tokenizer'] = tokenization.FullTokenizer(
vocab_file=params.vocab_file, do_lower_case=params.do_lower_case)
......@@ -175,24 +179,25 @@ class QuestionAnsweringTask(base_task.Task):
return eval_writer.filename, eval_examples, eval_features
def _dummy_data(self, params, _):
  """Builds one fake (features, labels) example for unit tests.

  Feature tensors are all-zero int32 ids of shape (1, seq_length); labels
  hold scalar start/end positions and an impossibility flag.
  """
  zero_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
  features = {
      'input_word_ids': zero_ids,
      'input_mask': zero_ids,
      'input_type_ids': zero_ids,
  }
  labels = {
      'start_positions': tf.constant(0, dtype=tf.int32),
      'end_positions': tf.constant(1, dtype=tf.int32),
      'is_impossible': tf.constant(0, dtype=tf.int32),
  }
  return features, labels
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for sentence_prediction task."""
if params.input_path == 'dummy':
# Dummy training data for unit test.
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
x = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
y = dict(
start_positions=tf.constant(0, dtype=tf.int32),
end_positions=tf.constant(1, dtype=tf.int32),
is_impossible=tf.constant(0, dtype=tf.int32))
return (x, y)
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dummy_data = functools.partial(self._dummy_data, params)
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
......@@ -278,6 +283,7 @@ class QuestionAnsweringTask(base_task.Task):
self.task_config.validation_data.version_2_with_negative),
null_score_diff_threshold=(
self.task_config.null_score_diff_threshold),
xlnet_format=self.task_config.validation_data.xlnet_format,
verbose=False))
with tf.io.gfile.GFile(self.task_config.validation_data.input_path,
......@@ -382,6 +388,24 @@ class XLNetQuestionAnsweringTask(QuestionAnsweringTask):
'end_positions': end_logits,
})
def _dummy_data(self, params, _):
  """Builds one fake (features, labels) XLNet QA example for unit tests.

  In addition to the standard inputs, the feature dict carries the
  XLNet-specific `class_index`, `paragraph_mask`, `is_impossible` and
  `start_positions` entries consumed by the XLNet span-labeling head.
  """
  seq_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
  scalar_zero = tf.constant(0, dtype=tf.int32)
  features = {
      'input_word_ids': seq_ids,
      'input_mask': seq_ids,
      'input_type_ids': seq_ids,
      'class_index': scalar_zero,
      'is_impossible': scalar_zero,
      'paragraph_mask': seq_ids,
      'start_positions': tf.zeros((1), dtype=tf.int32),
  }
  labels = {
      'start_positions': tf.zeros((1), dtype=tf.int32),
      'end_positions': tf.ones((1), dtype=tf.int32),
      'is_impossible': scalar_zero,
  }
  return features, labels
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
features, _ = inputs
unique_ids = features.pop('unique_ids')
......@@ -468,5 +492,6 @@ def predict(task: QuestionAnsweringTask, params: cfg.DataConfig,
task.task_config.validation_data.do_lower_case,
version_2_with_negative=(params.version_2_with_negative),
null_score_diff_threshold=task.task_config.null_score_diff_threshold,
xlnet_format=task.task_config.validation_data.xlnet_format,
verbose=False))
return all_predictions, all_nbest, scores_diff
......@@ -186,5 +186,93 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
self.assertEmpty(scores_diff)
class XLNetQuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end smoke tests for the XLNet question answering task."""

  def setUp(self):
    # Python 3 zero-argument super() instead of the legacy two-argument form.
    super().setUp()
    self._encoder_config = encoders.EncoderConfig(
        type="xlnet",
        xlnet=encoders.XLNetEncoderConfig(vocab_size=30522, num_layers=1))
    self._train_data_config = question_answering_dataloader.QADataConfig(
        input_path="dummy", seq_length=128,
        global_batch_size=2, xlnet_format=True)

    # Minimal SQuAD-v2.0-style validation file with a single QA pair.
    val_data = {
        "version":
            "2.0",
        "data": [{
            "paragraphs": [{
                "context":
                    "Sky is blue.",
                "qas": [{
                    "question":
                        "What is blue?",
                    "id":
                        "1234",
                    "answers": [{
                        "text": "Sky",
                        "answer_start": 0
                    }, {
                        "text": "Sky",
                        "answer_start": 0
                    }, {
                        "text": "Sky",
                        "answer_start": 0
                    }]
                }]
            }]
        }]
    }
    self._val_input_path = os.path.join(self.get_temp_dir(), "val_data.json")
    with tf.io.gfile.GFile(self._val_input_path, "w") as writer:
      writer.write(json.dumps(val_data, indent=4) + "\n")

    # Tiny WordPiece vocab covering the validation context and question.
    self._test_vocab = os.path.join(self.get_temp_dir(), "vocab.txt")
    with tf.io.gfile.GFile(self._test_vocab, "w") as writer:
      writer.write("[PAD]\n[UNK]\n[CLS]\n[SEP]\n[MASK]\nsky\nis\nblue\n")

  def _get_validation_data_config(self):
    """Returns a QADataConfig pointing at the temp-dir validation fixtures."""
    return question_answering_dataloader.QADataConfig(
        is_training=False,
        input_path=self._val_input_path,
        input_preprocessed_data_path=self.get_temp_dir(),
        seq_length=128,
        global_batch_size=2,
        version_2_with_negative=True,
        vocab_file=self._test_vocab,
        tokenization="WordPiece",
        do_lower_case=True,
        xlnet_format=True)

  def _run_task(self, config):
    """Builds the task, runs one train and one eval step, checks metrics."""
    task = question_answering.XLNetQuestionAnsweringTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    task.initialize(model)

    train_dataset = task.build_inputs(config.train_data)
    train_iterator = iter(train_dataset)
    # `lr` is a deprecated alias in tf.keras optimizers; use `learning_rate`.
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    task.train_step(next(train_iterator), model, optimizer, metrics=metrics)

    val_dataset = task.build_inputs(config.validation_data)
    val_iterator = iter(val_dataset)
    logs = task.validation_step(next(val_iterator), model, metrics=metrics)
    # Mock that `logs` is from one replica.
    logs = {x: (logs[x],) for x in logs}
    logs = task.aggregate_logs(step_outputs=logs)
    metrics = task.reduce_aggregated_logs(logs)
    self.assertIn("final_f1", metrics)

  def test_task(self):
    config = question_answering.XLNetQuestionAnsweringConfig(
        init_checkpoint="",
        n_best_size=5,
        model=question_answering.ModelConfig(encoder=self._encoder_config),
        train_data=self._train_data_config,
        validation_data=self._get_validation_data_config())
    self._run_task(config)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment