"src/graph/transform/to_block.h" did not exist on "bcd37684268a919f25aa5b9eb88f4e59aca1e7b4"
Commit 31e4a64d authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 342129726
parent 82f46dc7
...@@ -262,9 +262,15 @@ def generate_squad_dataset(): ...@@ -262,9 +262,15 @@ def generate_squad_dataset():
assert FLAGS.squad_data_file assert FLAGS.squad_data_file
if FLAGS.tokenization == "WordPiece": if FLAGS.tokenization == "WordPiece":
return squad_lib_wp.generate_tf_record_from_json_file( return squad_lib_wp.generate_tf_record_from_json_file(
FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path, input_file_path=FLAGS.squad_data_file,
FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length, vocab_file_path=FLAGS.vocab_file,
FLAGS.doc_stride, FLAGS.version_2_with_negative) output_path=FLAGS.train_data_output_path,
max_seq_length=FLAGS.max_seq_length,
do_lower_case=FLAGS.do_lower_case,
max_query_length=FLAGS.max_query_length,
doc_stride=FLAGS.doc_stride,
version_2_with_negative=FLAGS.version_2_with_negative,
xlnet_format=FLAGS.xlnet_format)
else: else:
assert FLAGS.tokenization == "SentencePiece" assert FLAGS.tokenization == "SentencePiece"
return squad_lib_sp.generate_tf_record_from_json_file( return squad_lib_sp.generate_tf_record_from_json_file(
......
...@@ -92,6 +92,8 @@ class InputFeatures(object): ...@@ -92,6 +92,8 @@ class InputFeatures(object):
input_ids, input_ids,
input_mask, input_mask,
segment_ids, segment_ids,
paragraph_mask=None,
class_index=None,
start_position=None, start_position=None,
end_position=None, end_position=None,
is_impossible=None): is_impossible=None):
...@@ -107,6 +109,8 @@ class InputFeatures(object): ...@@ -107,6 +109,8 @@ class InputFeatures(object):
self.start_position = start_position self.start_position = start_position
self.end_position = end_position self.end_position = end_position
self.is_impossible = is_impossible self.is_impossible = is_impossible
self.paragraph_mask = paragraph_mask
self.class_index = class_index
class FeatureWriter(object): class FeatureWriter(object):
...@@ -134,6 +138,11 @@ class FeatureWriter(object): ...@@ -134,6 +138,11 @@ class FeatureWriter(object):
features["input_mask"] = create_int_feature(feature.input_mask) features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids) features["segment_ids"] = create_int_feature(feature.segment_ids)
if feature.paragraph_mask is not None:
features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
if feature.class_index is not None:
features["class_index"] = create_int_feature([feature.class_index])
if self.is_training: if self.is_training:
features["start_positions"] = create_int_feature([feature.start_position]) features["start_positions"] = create_int_feature([feature.start_position])
features["end_positions"] = create_int_feature([feature.end_position]) features["end_positions"] = create_int_feature([feature.end_position])
...@@ -238,6 +247,7 @@ def convert_examples_to_features(examples, ...@@ -238,6 +247,7 @@ def convert_examples_to_features(examples,
max_query_length, max_query_length,
is_training, is_training,
output_fn, output_fn,
xlnet_format=False,
batch_size=None): batch_size=None):
"""Loads a data file into a list of `InputBatch`s.""" """Loads a data file into a list of `InputBatch`s."""
...@@ -299,25 +309,54 @@ def convert_examples_to_features(examples, ...@@ -299,25 +309,54 @@ def convert_examples_to_features(examples,
token_to_orig_map = {} token_to_orig_map = {}
token_is_max_context = {} token_is_max_context = {}
segment_ids = [] segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0) # Paragraph mask used in XLNet.
for token in query_tokens: # 1 represents paragraph and class tokens.
tokens.append(token) # 0 represents query and other special tokens.
segment_ids.append(0) paragraph_mask = []
tokens.append("[SEP]")
segment_ids.append(0) # pylint: disable=cell-var-from-loop
def process_query(seg_q):
for i in range(doc_span.length): for token in query_tokens:
split_token_index = doc_span.start + i tokens.append(token)
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] segment_ids.append(seg_q)
paragraph_mask.append(0)
is_max_context = _check_is_max_context(doc_spans, doc_span_index, tokens.append("[SEP]")
split_token_index) segment_ids.append(seg_q)
token_is_max_context[len(tokens)] = is_max_context paragraph_mask.append(0)
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(1) def process_paragraph(seg_p):
tokens.append("[SEP]") for i in range(doc_span.length):
segment_ids.append(1) split_token_index = doc_span.start + i
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(seg_p)
paragraph_mask.append(1)
tokens.append("[SEP]")
segment_ids.append(seg_p)
paragraph_mask.append(0)
def process_class(seg_class):
class_index = len(segment_ids)
tokens.append("[CLS]")
segment_ids.append(seg_class)
paragraph_mask.append(1)
return class_index
if xlnet_format:
seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
process_paragraph(seg_p)
process_query(seg_q)
class_index = process_class(seg_class)
else:
seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
class_index = process_class(seg_class)
process_query(seg_q)
process_paragraph(seg_p)
input_ids = tokenizer.convert_tokens_to_ids(tokens) input_ids = tokenizer.convert_tokens_to_ids(tokens)
...@@ -329,11 +368,13 @@ def convert_examples_to_features(examples, ...@@ -329,11 +368,13 @@ def convert_examples_to_features(examples,
while len(input_ids) < max_seq_length: while len(input_ids) < max_seq_length:
input_ids.append(0) input_ids.append(0)
input_mask.append(0) input_mask.append(0)
segment_ids.append(0) segment_ids.append(seg_pad)
paragraph_mask.append(0)
assert len(input_ids) == max_seq_length assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length assert len(segment_ids) == max_seq_length
assert len(paragraph_mask) == max_seq_length
start_position = None start_position = None
end_position = None end_position = None
...@@ -350,7 +391,7 @@ def convert_examples_to_features(examples, ...@@ -350,7 +391,7 @@ def convert_examples_to_features(examples,
start_position = 0 start_position = 0
end_position = 0 end_position = 0
else: else:
doc_offset = len(query_tokens) + 2 doc_offset = 0 if xlnet_format else len(query_tokens) + 2
start_position = tok_start_position - doc_start + doc_offset start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset end_position = tok_end_position - doc_start + doc_offset
...@@ -377,6 +418,9 @@ def convert_examples_to_features(examples, ...@@ -377,6 +418,9 @@ def convert_examples_to_features(examples,
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids])) logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask])) logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("paragraph_mask: %s", " ".join(
[str(x) for x in paragraph_mask]))
logging.info("class_index: %d", class_index)
if is_training and example.is_impossible: if is_training and example.is_impossible:
logging.info("impossible example") logging.info("impossible example")
if is_training and not example.is_impossible: if is_training and not example.is_impossible:
...@@ -390,6 +434,8 @@ def convert_examples_to_features(examples, ...@@ -390,6 +434,8 @@ def convert_examples_to_features(examples,
example_index=example_index, example_index=example_index,
doc_span_index=doc_span_index, doc_span_index=doc_span_index,
tokens=tokens, tokens=tokens,
paragraph_mask=paragraph_mask,
class_index=class_index,
token_to_orig_map=token_to_orig_map, token_to_orig_map=token_to_orig_map,
token_is_max_context=token_is_max_context, token_is_max_context=token_is_max_context,
input_ids=input_ids, input_ids=input_ids,
...@@ -541,6 +587,7 @@ def postprocess_output(all_examples, ...@@ -541,6 +587,7 @@ def postprocess_output(all_examples,
do_lower_case, do_lower_case,
version_2_with_negative=False, version_2_with_negative=False,
null_score_diff_threshold=0.0, null_score_diff_threshold=0.0,
xlnet_format=False,
verbose=False): verbose=False):
"""Postprocess model output, to form predicton results.""" """Postprocess model output, to form predicton results."""
...@@ -570,45 +617,50 @@ def postprocess_output(all_examples, ...@@ -570,45 +617,50 @@ def postprocess_output(all_examples,
null_end_logit = 0 # the end logit at the slice with min null score null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features): for (feature_index, feature) in enumerate(features):
result = unique_id_to_result[feature.unique_id] result = unique_id_to_result[feature.unique_id]
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant # if we could have irrelevant answers, get the min score of irrelevant
if version_2_with_negative: if version_2_with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[0] if xlnet_format:
feature_null_score = result.class_logits
else:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if feature_null_score < score_null: if feature_null_score < score_null:
score_null = feature_null_score score_null = feature_null_score
min_null_feature_index = feature_index min_null_feature_index = feature_index
null_start_logit = result.start_logits[0] null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0] null_end_logit = result.end_logits[0]
for start_index in start_indexes: for (start_index, start_logit,
for end_index in end_indexes: end_index, end_logit) in _get_best_indexes_and_logits(
# We could hypothetically create invalid predictions, e.g., predict result=result,
# that the start of the span is in the question. We throw out all n_best_size=n_best_size,
# invalid predictions. xlnet_format=xlnet_format):
if start_index >= len(feature.tokens): # We could hypothetically create invalid predictions, e.g., predict
continue # that the start of the span is in the question. We throw out all
if end_index >= len(feature.tokens): # invalid predictions.
continue if start_index >= len(feature.tokens):
if start_index not in feature.token_to_orig_map: continue
continue if end_index >= len(feature.tokens):
if end_index not in feature.token_to_orig_map: continue
continue if start_index not in feature.token_to_orig_map:
if not feature.token_is_max_context.get(start_index, False): continue
continue if end_index not in feature.token_to_orig_map:
if end_index < start_index: continue
continue if not feature.token_is_max_context.get(start_index, False):
length = end_index - start_index + 1 continue
if length > max_answer_length: if end_index < start_index:
continue continue
prelim_predictions.append( length = end_index - start_index + 1
_PrelimPrediction( if length > max_answer_length:
feature_index=feature_index, continue
start_index=start_index, prelim_predictions.append(
end_index=end_index, _PrelimPrediction(
start_logit=result.start_logits[start_index], feature_index=feature_index,
end_logit=result.end_logits[end_index])) start_index=start_index,
end_index=end_index,
if version_2_with_negative: start_logit=start_logit,
end_logit=end_logit))
if version_2_with_negative and not xlnet_format:
prelim_predictions.append( prelim_predictions.append(
_PrelimPrediction( _PrelimPrediction(
feature_index=min_null_feature_index, feature_index=min_null_feature_index,
...@@ -630,7 +682,7 @@ def postprocess_output(all_examples, ...@@ -630,7 +682,7 @@ def postprocess_output(all_examples,
if len(nbest) >= n_best_size: if len(nbest) >= n_best_size:
break break
feature = features[pred.feature_index] feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction if pred.start_index > 0 or xlnet_format: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index] orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index] orig_doc_end = feature.token_to_orig_map[pred.end_index]
...@@ -663,7 +715,7 @@ def postprocess_output(all_examples, ...@@ -663,7 +715,7 @@ def postprocess_output(all_examples,
end_logit=pred.end_logit)) end_logit=pred.end_logit))
# if we didn't inlude the empty option in the n-best, inlcude it # if we didn't inlude the empty option in the n-best, inlcude it
if version_2_with_negative: if version_2_with_negative and not xlnet_format:
if "" not in seen_predictions: if "" not in seen_predictions:
nbest.append( nbest.append(
_NbestPrediction( _NbestPrediction(
...@@ -704,13 +756,18 @@ def postprocess_output(all_examples, ...@@ -704,13 +756,18 @@ def postprocess_output(all_examples,
# pytype: disable=attribute-error # pytype: disable=attribute-error
# predict "" iff the null score - the score of best non-null > threshold # predict "" iff the null score - the score of best non-null > threshold
if best_non_null_entry is not None: if best_non_null_entry is not None:
score_diff = score_null - best_non_null_entry.start_logit - ( if xlnet_format:
best_non_null_entry.end_logit) score_diff = score_null
scores_diff_json[example.qas_id] = score_diff scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text all_predictions[example.qas_id] = best_non_null_entry.text
else:
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
else: else:
logging.warning("best_non_null_entry is None") logging.warning("best_non_null_entry is None")
scores_diff_json[example.qas_id] = score_null scores_diff_json[example.qas_id] = score_null
...@@ -822,16 +879,29 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose=False): ...@@ -822,16 +879,29 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
return output_text return output_text
def _get_best_indexes(logits, n_best_size): def _get_best_indexes_and_logits(result,
"""Get the n-best logits from a list.""" n_best_size,
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) xlnet_format=False):
"""Generates the n-best indexes and logits from a list."""
best_indexes = [] if xlnet_format:
for i in range(len(index_and_score)): # pylint: disable=consider-using-enumerate for i in range(n_best_size):
if i >= n_best_size: for j in range(n_best_size):
break j_index = i * n_best_size + j
best_indexes.append(index_and_score[i][0]) yield (result.start_indexes[i], result.start_logits[i],
return best_indexes result.end_indexes[j_index], result.end_logits[j_index])
else:
start_index_and_score = sorted(enumerate(result.start_logits),
key=lambda x: x[1], reverse=True)
end_index_and_score = sorted(enumerate(result.end_logits),
key=lambda x: x[1], reverse=True)
for i in range(len(start_index_and_score)):
if i >= n_best_size:
break
for j in range(len(end_index_and_score)):
if j >= n_best_size:
break
yield (start_index_and_score[i][0], start_index_and_score[i][1],
end_index_and_score[j][0], end_index_and_score[j][1])
def _compute_softmax(scores): def _compute_softmax(scores):
...@@ -864,7 +934,8 @@ def generate_tf_record_from_json_file(input_file_path, ...@@ -864,7 +934,8 @@ def generate_tf_record_from_json_file(input_file_path,
do_lower_case=True, do_lower_case=True,
max_query_length=64, max_query_length=64,
doc_stride=128, doc_stride=128,
version_2_with_negative=False): version_2_with_negative=False,
xlnet_format=False):
"""Generates and saves training data into a tf record file.""" """Generates and saves training data into a tf record file."""
train_examples = read_squad_examples( train_examples = read_squad_examples(
input_file=input_file_path, input_file=input_file_path,
...@@ -880,7 +951,8 @@ def generate_tf_record_from_json_file(input_file_path, ...@@ -880,7 +951,8 @@ def generate_tf_record_from_json_file(input_file_path,
doc_stride=doc_stride, doc_stride=doc_stride,
max_query_length=max_query_length, max_query_length=max_query_length,
is_training=True, is_training=True,
output_fn=train_writer.process_feature) output_fn=train_writer.process_feature,
xlnet_format=xlnet_format)
train_writer.close() train_writer.close()
meta_data = { meta_data = {
......
...@@ -645,16 +645,11 @@ def postprocess_output(all_examples, ...@@ -645,16 +645,11 @@ def postprocess_output(all_examples,
do_lower_case, do_lower_case,
version_2_with_negative=False, version_2_with_negative=False,
null_score_diff_threshold=0.0, null_score_diff_threshold=0.0,
xlnet_format=False,
verbose=False): verbose=False):
"""Postprocess model output, to form predicton results.""" """Postprocess model output, to form predicton results."""
del do_lower_case, verbose del do_lower_case, verbose
# XLNet emits further predictions for start, end indexes and impossibility
# classifications.
xlnet_format = (hasattr(all_results[0], "start_indexes")
and all_results[0].start_indexes is not None)
example_index_to_features = collections.defaultdict(list) example_index_to_features = collections.defaultdict(list)
for feature in all_features: for feature in all_features:
example_index_to_features[feature.example_index].append(feature) example_index_to_features[feature.example_index].append(feature)
...@@ -904,9 +899,9 @@ class FeatureWriter(object): ...@@ -904,9 +899,9 @@ class FeatureWriter(object):
features["input_ids"] = create_int_feature(feature.input_ids) features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask) features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids) features["segment_ids"] = create_int_feature(feature.segment_ids)
if feature.paragraph_mask: if feature.paragraph_mask is not None:
features["paragraph_mask"] = create_int_feature(feature.paragraph_mask) features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
if feature.class_index: if feature.class_index is not None:
features["class_index"] = create_int_feature([feature.class_index]) features["class_index"] = create_int_feature([feature.class_index])
if self.is_training: if self.is_training:
......
...@@ -150,6 +150,15 @@ class XLNetSpanLabeler(tf.keras.Model): ...@@ -150,6 +150,15 @@ class XLNetSpanLabeler(tf.keras.Model):
'span_labeling_activation': span_labeling_activation, 'span_labeling_activation': span_labeling_activation,
'initializer': initializer, 'initializer': initializer,
} }
network_config = network.get_config()
try:
input_width = network_config['inner_size']
self._xlnet_base = True
except KeyError:
# BertEncoder uses 'intermediate_size' due to legacy naming.
input_width = network_config['intermediate_size']
self._xlnet_base = False
self._network = network self._network = network
self._initializer = initializer self._initializer = initializer
self._start_n_top = start_n_top self._start_n_top = start_n_top
...@@ -157,7 +166,7 @@ class XLNetSpanLabeler(tf.keras.Model): ...@@ -157,7 +166,7 @@ class XLNetSpanLabeler(tf.keras.Model):
self._dropout_rate = dropout_rate self._dropout_rate = dropout_rate
self._activation = span_labeling_activation self._activation = span_labeling_activation
self.span_labeling = networks.XLNetSpanLabeling( self.span_labeling = networks.XLNetSpanLabeling(
input_width=network.get_config()['inner_size'], input_width=input_width,
start_n_top=self._start_n_top, start_n_top=self._start_n_top,
end_n_top=self._end_n_top, end_n_top=self._end_n_top,
activation=self._activation, activation=self._activation,
...@@ -165,17 +174,25 @@ class XLNetSpanLabeler(tf.keras.Model): ...@@ -165,17 +174,25 @@ class XLNetSpanLabeler(tf.keras.Model):
initializer=self._initializer) initializer=self._initializer)
def call(self, inputs: Mapping[str, Any]): def call(self, inputs: Mapping[str, Any]):
input_ids = inputs['input_word_ids'] input_word_ids = inputs['input_word_ids']
segment_ids = inputs['input_type_ids'] input_type_ids = inputs['input_type_ids']
input_mask = inputs['input_mask'] input_mask = inputs['input_mask']
class_index = inputs['class_index'] class_index = inputs['class_index']
paragraph_mask = inputs['paragraph_mask'] paragraph_mask = inputs['paragraph_mask']
start_positions = inputs.get('start_positions', None) start_positions = inputs.get('start_positions', None)
attention_output, _ = self._network( if self._xlnet_base:
input_ids=input_ids, attention_output, _ = self._network(
segment_ids=segment_ids, input_ids=input_word_ids,
input_mask=input_mask) segment_ids=input_type_ids,
input_mask=input_mask)
else:
network_output_dict = self._network(dict(
input_word_ids=input_word_ids,
input_type_ids=input_type_ids,
input_mask=input_mask))
attention_output = network_output_dict['sequence_output']
outputs = self.span_labeling( outputs = self.span_labeling(
sequence_data=attention_output, sequence_data=attention_output,
class_index=class_index, class_index=class_index,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Question answering task.""" """Question answering task."""
import functools
import json import json
import os import os
from typing import List, Optional from typing import List, Optional
...@@ -143,6 +144,9 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -143,6 +144,9 @@ class QuestionAnsweringTask(base_task.Task):
eval_features.append(feature) eval_features.append(feature)
eval_writer.process_feature(feature) eval_writer.process_feature(feature)
# XLNet preprocesses SQuAD examples in a P, Q, class order whereas
# BERT preprocesses in a class, Q, P order.
xlnet_ordering = self.task_config.model.encoder.type == 'xlnet'
kwargs = dict( kwargs = dict(
examples=eval_examples, examples=eval_examples,
max_seq_length=params.seq_length, max_seq_length=params.seq_length,
...@@ -150,14 +154,14 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -150,14 +154,14 @@ class QuestionAnsweringTask(base_task.Task):
max_query_length=params.query_length, max_query_length=params.query_length,
is_training=False, is_training=False,
output_fn=_append_feature, output_fn=_append_feature,
batch_size=params.global_batch_size) batch_size=params.global_batch_size,
xlnet_format=xlnet_ordering)
if params.tokenization == 'SentencePiece': if params.tokenization == 'SentencePiece':
# squad_lib_sp requires one more argument 'do_lower_case'. # squad_lib_sp requires one more argument 'do_lower_case'.
kwargs['do_lower_case'] = params.do_lower_case kwargs['do_lower_case'] = params.do_lower_case
kwargs['tokenizer'] = tokenization.FullSentencePieceTokenizer( kwargs['tokenizer'] = tokenization.FullSentencePieceTokenizer(
sp_model_file=params.vocab_file) sp_model_file=params.vocab_file)
kwargs['xlnet_format'] = self.task_config.model.encoder.type == 'xlnet'
elif params.tokenization == 'WordPiece': elif params.tokenization == 'WordPiece':
kwargs['tokenizer'] = tokenization.FullTokenizer( kwargs['tokenizer'] = tokenization.FullTokenizer(
vocab_file=params.vocab_file, do_lower_case=params.do_lower_case) vocab_file=params.vocab_file, do_lower_case=params.do_lower_case)
...@@ -175,24 +179,25 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -175,24 +179,25 @@ class QuestionAnsweringTask(base_task.Task):
return eval_writer.filename, eval_examples, eval_features return eval_writer.filename, eval_examples, eval_features
def _dummy_data(self, params, _):
"""Returns dummy data."""
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
x = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
y = dict(
start_positions=tf.constant(0, dtype=tf.int32),
end_positions=tf.constant(1, dtype=tf.int32),
is_impossible=tf.constant(0, dtype=tf.int32))
return x, y
def build_inputs(self, params, input_context=None): def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for sentence_prediction task.""" """Returns tf.data.Dataset for sentence_prediction task."""
if params.input_path == 'dummy': if params.input_path == 'dummy':
# Dummy training data for unit test.
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
x = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
y = dict(
start_positions=tf.constant(0, dtype=tf.int32),
end_positions=tf.constant(1, dtype=tf.int32),
is_impossible=tf.constant(0, dtype=tf.int32))
return (x, y)
dataset = tf.data.Dataset.range(1) dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat() dataset = dataset.repeat()
dummy_data = functools.partial(self._dummy_data, params)
dataset = dataset.map( dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE) dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset return dataset
...@@ -278,6 +283,7 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -278,6 +283,7 @@ class QuestionAnsweringTask(base_task.Task):
self.task_config.validation_data.version_2_with_negative), self.task_config.validation_data.version_2_with_negative),
null_score_diff_threshold=( null_score_diff_threshold=(
self.task_config.null_score_diff_threshold), self.task_config.null_score_diff_threshold),
xlnet_format=self.task_config.validation_data.xlnet_format,
verbose=False)) verbose=False))
with tf.io.gfile.GFile(self.task_config.validation_data.input_path, with tf.io.gfile.GFile(self.task_config.validation_data.input_path,
...@@ -382,6 +388,24 @@ class XLNetQuestionAnsweringTask(QuestionAnsweringTask): ...@@ -382,6 +388,24 @@ class XLNetQuestionAnsweringTask(QuestionAnsweringTask):
'end_positions': end_logits, 'end_positions': end_logits,
}) })
def _dummy_data(self, params, _):
"""Returns dummy data."""
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
zero = tf.constant(0, dtype=tf.int32)
x = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
class_index=zero,
is_impossible=zero,
paragraph_mask=dummy_ids,
start_positions=tf.zeros((1), dtype=tf.int32))
y = dict(
start_positions=tf.zeros((1), dtype=tf.int32),
end_positions=tf.ones((1), dtype=tf.int32),
is_impossible=zero)
return x, y
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
features, _ = inputs features, _ = inputs
unique_ids = features.pop('unique_ids') unique_ids = features.pop('unique_ids')
...@@ -468,5 +492,6 @@ def predict(task: QuestionAnsweringTask, params: cfg.DataConfig, ...@@ -468,5 +492,6 @@ def predict(task: QuestionAnsweringTask, params: cfg.DataConfig,
task.task_config.validation_data.do_lower_case, task.task_config.validation_data.do_lower_case,
version_2_with_negative=(params.version_2_with_negative), version_2_with_negative=(params.version_2_with_negative),
null_score_diff_threshold=task.task_config.null_score_diff_threshold, null_score_diff_threshold=task.task_config.null_score_diff_threshold,
xlnet_format=task.task_config.validation_data.xlnet_format,
verbose=False)) verbose=False))
return all_predictions, all_nbest, scores_diff return all_predictions, all_nbest, scores_diff
...@@ -186,5 +186,93 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -186,5 +186,93 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
self.assertEmpty(scores_diff) self.assertEmpty(scores_diff)
class XLNetQuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(XLNetQuestionAnsweringTaskTest, self).setUp()
self._encoder_config = encoders.EncoderConfig(
type="xlnet",
xlnet=encoders.XLNetEncoderConfig(vocab_size=30522, num_layers=1))
self._train_data_config = question_answering_dataloader.QADataConfig(
input_path="dummy", seq_length=128,
global_batch_size=2, xlnet_format=True)
val_data = {
"version":
"2.0",
"data": [{
"paragraphs": [{
"context":
"Sky is blue.",
"qas": [{
"question":
"What is blue?",
"id":
"1234",
"answers": [{
"text": "Sky",
"answer_start": 0
}, {
"text": "Sky",
"answer_start": 0
}, {
"text": "Sky",
"answer_start": 0
}]
}]
}]
}]
}
self._val_input_path = os.path.join(self.get_temp_dir(), "val_data.json")
with tf.io.gfile.GFile(self._val_input_path, "w") as writer:
writer.write(json.dumps(val_data, indent=4) + "\n")
self._test_vocab = os.path.join(self.get_temp_dir(), "vocab.txt")
with tf.io.gfile.GFile(self._test_vocab, "w") as writer:
writer.write("[PAD]\n[UNK]\n[CLS]\n[SEP]\n[MASK]\nsky\nis\nblue\n")
def _get_validation_data_config(self):
return question_answering_dataloader.QADataConfig(
is_training=False,
input_path=self._val_input_path,
input_preprocessed_data_path=self.get_temp_dir(),
seq_length=128,
global_batch_size=2,
version_2_with_negative=True,
vocab_file=self._test_vocab,
tokenization="WordPiece",
do_lower_case=True,
xlnet_format=True)
def _run_task(self, config):
task = question_answering.XLNetQuestionAnsweringTask(config)
model = task.build_model()
metrics = task.build_metrics()
task.initialize(model)
train_dataset = task.build_inputs(config.train_data)
train_iterator = iter(train_dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(train_iterator), model, optimizer, metrics=metrics)
val_dataset = task.build_inputs(config.validation_data)
val_iterator = iter(val_dataset)
logs = task.validation_step(next(val_iterator), model, metrics=metrics)
# Mock that `logs` is from one replica.
logs = {x: (logs[x],) for x in logs}
logs = task.aggregate_logs(step_outputs=logs)
metrics = task.reduce_aggregated_logs(logs)
self.assertIn("final_f1", metrics)
def test_task(self):
config = question_answering.XLNetQuestionAnsweringConfig(
init_checkpoint="",
n_best_size=5,
model=question_answering.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
validation_data=self._get_validation_data_config())
self._run_task(config)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment