Commit 24c619ff authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 277793274
parent fc2056bc
@@ -168,7 +168,7 @@ def create_squad_dataset(file_path, seq_length, batch_size, is_training):
   return dataset


-def _get_input_iterator(input_fn, strategy):
+def get_input_iterator(input_fn, strategy):
   """Returns distributed dataset iterator."""
   # When training with TPU pods, datasets needs to be cloned across
...
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function

 import os
-import pickle
 import random

 from absl import app
@@ -54,16 +53,11 @@ flags.DEFINE_bool(

 FLAGS = flags.FLAGS


-def _get_spm_basename():
-  spm_basename = os.path.basename(FLAGS.spiece_model_file)
-  return spm_basename


 def preprocess():
   """Preprocesses SQUAD data."""
   sp_model = spm.SentencePieceProcessor()
   sp_model.Load(FLAGS.spiece_model_file)
-  spm_basename = _get_spm_basename()
+  spm_basename = os.path.basename(FLAGS.spiece_model_file)
   if FLAGS.create_train_data:
     train_rec_file = os.path.join(
         FLAGS.output_dir,
@@ -97,39 +91,10 @@ def preprocess():
   if FLAGS.create_eval_data:
     eval_examples = squad_utils.read_squad_examples(
         FLAGS.predict_file, is_training=False)
-    eval_rec_file = os.path.join(
-        FLAGS.output_dir,
-        "{}.slen-{}.qlen-{}.eval.tf_record".format(spm_basename,
-                                                   FLAGS.max_seq_length,
-                                                   FLAGS.max_query_length))
-    eval_feature_file = os.path.join(
-        FLAGS.output_dir,
-        "{}.slen-{}.qlen-{}.eval.features.pkl".format(spm_basename,
-                                                      FLAGS.max_seq_length,
-                                                      FLAGS.max_query_length))
-    eval_writer = squad_utils.FeatureWriter(
-        filename=eval_rec_file, is_training=False)
-    eval_features = []
-
-    def append_feature(feature):
-      eval_features.append(feature)
-      eval_writer.process_feature(feature)
-
-    squad_utils.convert_examples_to_features(
-        examples=eval_examples,
-        sp_model=sp_model,
-        max_seq_length=FLAGS.max_seq_length,
-        doc_stride=FLAGS.doc_stride,
-        max_query_length=FLAGS.max_query_length,
-        is_training=False,
-        output_fn=append_feature,
-        uncased=FLAGS.uncased)
-    eval_writer.close()
-    with tf.io.gfile.GFile(eval_feature_file, "wb") as fout:
-      pickle.dump(eval_features, fout)
+    squad_utils.create_eval_data(spm_basename, sp_model, eval_examples,
+                                 FLAGS.max_seq_length, FLAGS.max_query_length,
+                                 FLAGS.doc_stride, FLAGS.uncased,
+                                 FLAGS.output_dir)


 def main(_):
...
@@ -30,6 +30,7 @@ from absl import logging
 import tensorflow as tf
 # pylint: disable=unused-import
+import sentencepiece as spm
 from official.nlp import xlnet_config
 from official.nlp import xlnet_modeling as modeling
 from official.nlp.xlnet import common_flags
@@ -51,6 +52,12 @@ flags.DEFINE_string(
 flags.DEFINE_integer(
     "n_best_size", default=5, help="n best size for predictions.")
 flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length.")
+# Data preprocessing config
+flags.DEFINE_string(
+    "spiece_model_file", default=None, help="Sentence Piece model path.")
+flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length.")
+flags.DEFINE_integer("max_query_length", default=64, help="Max query length.")
+flags.DEFINE_integer("doc_stride", default=128, help="Doc stride.")

 FLAGS = flags.FLAGS
@@ -92,23 +99,23 @@ class InputFeatures(object):

 # pylint: disable=unused-argument
-def run_evaluation(strategy,
-                   test_input_fn,
-                   eval_steps,
-                   input_meta_data,
-                   model,
-                   step,
-                   eval_summary_writer=None):
+def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
+                   original_data, eval_steps, input_meta_data, model,
+                   current_step, eval_summary_writer):
   """Run evaluation for SQUAD task.

   Args:
     strategy: distribution strategy.
     test_input_fn: input function for evaluation data.
+    eval_examples: tf.Examples of the evaluation set.
+    eval_features: Feature objects of the evaluation set.
+    original_data: The original json data for the evaluation set.
     eval_steps: total number of evaluation steps.
     input_meta_data: input meta data.
     model: keras model object.
-    step: current training step.
+    current_step: current training step.
     eval_summary_writer: summary writer used to record evaluation metrics.

   Returns:
     A float metric, F1 score.
   """
@@ -127,15 +134,8 @@ def run_evaluation(strategy,
         _test_step_fn, args=(next(test_iterator),))
     return res, unique_ids

-  # pylint: disable=protected-access
-  test_iterator = data_utils._get_input_iterator(test_input_fn, strategy)
-  # pylint: enable=protected-access
+  test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
   cur_results = []
-  eval_examples = squad_utils.read_squad_examples(
-      input_meta_data["predict_file"], is_training=False)
-  with tf.io.gfile.GFile(input_meta_data["predict_file"]) as f:
-    orig_data = json.load(f)["data"]
   for _ in range(eval_steps):
     results, unique_ids = _run_evaluation(test_iterator)
     unique_ids = strategy.experimental_local_results(unique_ids)
@@ -187,20 +187,19 @@ def run_evaluation(strategy,
                                               "null_odds.json")

   results = squad_utils.write_predictions(
-      eval_examples, input_meta_data["eval_features"], cur_results,
-      input_meta_data["n_best_size"], input_meta_data["max_answer_length"],
-      output_prediction_file, output_nbest_file, output_null_log_odds_file,
-      orig_data, input_meta_data["start_n_top"], input_meta_data["end_n_top"])
+      eval_examples, eval_features, cur_results, input_meta_data["n_best_size"],
+      input_meta_data["max_answer_length"], output_prediction_file,
+      output_nbest_file, output_null_log_odds_file, original_data,
+      input_meta_data["start_n_top"], input_meta_data["end_n_top"])

   # Log current results.
   log_str = "Result | "
   for key, val in results.items():
     log_str += "{} {} | ".format(key, val)
   logging.info(log_str)

-  if eval_summary_writer:
-    with eval_summary_writer.as_default():
-      tf.summary.scalar("best_f1", results["best_f1"], step=step)
-      tf.summary.scalar("best_exact", results["best_exact"], step=step)
-      eval_summary_writer.flush()
+  with eval_summary_writer.as_default():
+    tf.summary.scalar("best_f1", results["best_f1"], step=current_step)
+    tf.summary.scalar("best_exact", results["best_exact"], step=current_step)
+    eval_summary_writer.flush()

   return results["best_f1"]
@@ -254,24 +253,33 @@ def main(unused_argv):
   input_meta_data["end_n_top"] = FLAGS.end_n_top
   input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
   input_meta_data["predict_dir"] = FLAGS.predict_dir
-  input_meta_data["predict_file"] = FLAGS.predict_file
   input_meta_data["n_best_size"] = FLAGS.n_best_size
   input_meta_data["max_answer_length"] = FLAGS.max_answer_length
-  input_meta_data["test_feature_path"] = FLAGS.test_feature_path
   input_meta_data["test_batch_size"] = FLAGS.test_batch_size
   input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                strategy.num_replicas_in_sync)
   input_meta_data["mem_len"] = FLAGS.mem_len
   model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                                FLAGS.start_n_top, FLAGS.end_n_top)
-  logging.info("start reading pickle file...")
-  with tf.io.gfile.GFile(input_meta_data["test_feature_path"], "rb") as f:
-    eval_features = pickle.load(f)
-  logging.info("finishing reading pickle file...")
-  input_meta_data["eval_features"] = eval_features
+  eval_examples = squad_utils.read_squad_examples(
+      FLAGS.predict_file, is_training=False)
+  if FLAGS.test_feature_path:
+    logging.info("start reading pickle file...")
+    with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
+      eval_features = pickle.load(f)
+    logging.info("finishing reading pickle file...")
+  else:
+    sp_model = spm.SentencePieceProcessor()
+    sp_model.Load(FLAGS.spiece_model_file)
+    spm_basename = os.path.basename(FLAGS.spiece_model_file)
+    eval_features = squad_utils.create_eval_data(
+        spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
+        FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)
+  with tf.io.gfile.GFile(FLAGS.predict_file) as f:
+    original_data = json.load(f)["data"]
   eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
+                              eval_examples, eval_features, original_data,
                               eval_steps, input_meta_data)

   training_utils.train(
...
@@ -23,6 +23,8 @@ import collections
 import gc
 import json
 import math
+import os
+import pickle
 import re
 import string
@@ -922,3 +924,50 @@ class FeatureWriter(object):

   def close(self):
     self._writer.close()
+
+
+def create_eval_data(spm_basename,
+                     sp_model,
+                     eval_examples,
+                     max_seq_length,
+                     max_query_length,
+                     doc_stride,
+                     uncased,
+                     output_dir=None):
+  """Creates evaluation tfrecords."""
+  eval_features = []
+  eval_writer = None
+  if output_dir:
+    eval_rec_file = os.path.join(
+        output_dir,
+        "{}.slen-{}.qlen-{}.eval.tf_record".format(spm_basename, max_seq_length,
+                                                   max_query_length))
+    eval_feature_file = os.path.join(
+        output_dir,
+        "{}.slen-{}.qlen-{}.eval.features.pkl".format(spm_basename,
+                                                      max_seq_length,
+                                                      max_query_length))
+    eval_writer = FeatureWriter(filename=eval_rec_file, is_training=False)
+
+  def append_feature(feature):
+    eval_features.append(feature)
+    if eval_writer:
+      eval_writer.process_feature(feature)
+
+  convert_examples_to_features(
+      examples=eval_examples,
+      sp_model=sp_model,
+      max_seq_length=max_seq_length,
+      doc_stride=doc_stride,
+      max_query_length=max_query_length,
+      is_training=False,
+      output_fn=append_feature,
+      uncased=uncased)
+  if eval_writer:
+    eval_writer.close()
+    with tf.io.gfile.GFile(eval_feature_file, "wb") as fout:
+      pickle.dump(eval_features, fout)
+  return eval_features
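
For context, a minimal usage sketch of the new create_eval_data helper introduced above. It assumes squad_utils is importable from official.nlp.xlnet, as the other imports in this commit suggest; the SentencePiece model path, eval json path, output directory, and the uncased value are hypothetical placeholders rather than part of the commit. Parameter values mirror the defaults of the new flags added to run_squad.

# Sketch only: exercises the create_eval_data helper added in this commit.
# All paths below are hypothetical examples, not taken from the commit.
import os

import sentencepiece as spm
from official.nlp.xlnet import squad_utils  # assumed import path

spiece_model_file = "/tmp/spiece.model"  # hypothetical SentencePiece model
predict_file = "/tmp/dev-v2.0.json"      # hypothetical SQuAD eval json
output_dir = "/tmp/squad_eval"           # optional; omit to skip file output

sp_model = spm.SentencePieceProcessor()
sp_model.Load(spiece_model_file)

eval_examples = squad_utils.read_squad_examples(predict_file, is_training=False)

# Returns the in-memory feature list; when output_dir is given it also writes
# the eval tf_record and a pickled copy of the features.
eval_features = squad_utils.create_eval_data(
    spm_basename=os.path.basename(spiece_model_file),
    sp_model=sp_model,
    eval_examples=eval_examples,
    max_seq_length=512,
    max_query_length=64,
    doc_stride=128,
    uncased=False,  # assumed value; set to match the model's casing
    output_dir=output_dir)
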
@@ -110,9 +110,7 @@ def train(
                     "`learning_rate_fn` are required parameters.")
   if not model_dir:
     raise TypeError("Model directory must be specified.")
-  # pylint: disable=protected-access
-  train_iterator = data_utils._get_input_iterator(train_input_fn, strategy)
-  # pylint: enable=protected-access
+  train_iterator = data_utils.get_input_iterator(train_input_fn, strategy)
   if not tf.io.gfile.exists(model_dir):
     tf.io.gfile.mkdir(model_dir)
   # Create summary writers
...
@@ -8,8 +8,9 @@ pandas>=0.22.0
 psutil>=5.4.3
 py-cpuinfo>=3.3.0
 scipy>=0.19.1
+tensorflow-hub>=0.6.0
 typing
-tensorflow-hub
+sentencepiece
 Cython
 matplotlib
 opencv-python-headless
...