"fmoe/git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "a12ad553bc706003a9aaa969bb117a3e4ff1bee3"
Commit f93229b9 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 265510206
parent 1e48a60a
...@@ -27,6 +27,10 @@ def define_common_bert_flags(): ...@@ -27,6 +27,10 @@ def define_common_bert_flags():
flags.DEFINE_string('model_dir', None, ( flags.DEFINE_string('model_dir', None, (
'The directory where the model weights and training/evaluation summaries ' 'The directory where the model weights and training/evaluation summaries '
'are stored. If not specified, save to /tmp/bert20/.')) 'are stored. If not specified, save to /tmp/bert20/.'))
flags.DEFINE_string(
'model_export_path', None,
'Path to the directory, where trainined model will be '
'exported.')
flags.DEFINE_string('tpu', '', 'TPU address to connect to.') flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
flags.DEFINE_string( flags.DEFINE_string(
'init_checkpoint', None, 'init_checkpoint', None,
......
...@@ -48,10 +48,6 @@ flags.DEFINE_string('train_data_path', None, ...@@ -48,10 +48,6 @@ flags.DEFINE_string('train_data_path', None,
'Path to training data for BERT classifier.') 'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None, flags.DEFINE_string('eval_data_path', None,
'Path to evaluation data for BERT classifier.') 'Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
'model_export_path', None,
'Path to the directory, where trainined model will be '
'exported.')
# Model training specific flags. # Model training specific flags.
flags.DEFINE_string( flags.DEFINE_string(
'input_meta_data_path', None, 'input_meta_data_path', None,
......
...@@ -31,6 +31,7 @@ import tensorflow as tf ...@@ -31,6 +31,7 @@ import tensorflow as tf
from official.bert import bert_models from official.bert import bert_models
from official.bert import common_flags from official.bert import common_flags
from official.bert import input_pipeline from official.bert import input_pipeline
from official.bert import model_saving_utils
from official.bert import model_training_utils from official.bert import model_training_utils
from official.bert import modeling from official.bert import modeling
from official.bert import optimization from official.bert import optimization
...@@ -39,8 +40,13 @@ from official.bert import tokenization ...@@ -39,8 +40,13 @@ from official.bert import tokenization
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import tpu_lib from official.utils.misc import tpu_lib
flags.DEFINE_bool('do_train', False, 'Whether to run training.') flags.DEFINE_enum(
flags.DEFINE_bool('do_predict', False, 'Whether to run eval on the dev set.') 'mode', 'train', ['train', 'predict', 'export_only'],
'One of {"train", "predict", "export_only"}. `train`: '
'trains the model and evaluates in the meantime. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', '', flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.') 'Training data path with train tfrecords.')
flags.DEFINE_string( flags.DEFINE_string(
...@@ -311,6 +317,26 @@ def predict_squad(strategy, input_meta_data): ...@@ -311,6 +317,26 @@ def predict_squad(strategy, input_meta_data):
verbose=FLAGS.verbose_logging) verbose=FLAGS.verbose_logging)
def export_squad(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
Export path is not specified, got an empty string or None.
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
squad_model, _ = bert_models.squad_model(
bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
def main(_): def main(_):
# Users should always run this script under TF 2.x # Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.') assert tf.version.VERSION.startswith('2.')
...@@ -318,6 +344,10 @@ def main(_): ...@@ -318,6 +344,10 @@ def main(_):
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
if FLAGS.mode == 'export_only':
export_squad(FLAGS.model_export_path, input_meta_data)
return
strategy = None strategy = None
if FLAGS.strategy_type == 'mirror': if FLAGS.strategy_type == 'mirror':
strategy = tf.distribute.MirroredStrategy() strategy = tf.distribute.MirroredStrategy()
...@@ -330,9 +360,9 @@ def main(_): ...@@ -330,9 +360,9 @@ def main(_):
else: else:
raise ValueError('The distribution strategy type is not supported: %s' % raise ValueError('The distribution strategy type is not supported: %s' %
FLAGS.strategy_type) FLAGS.strategy_type)
if FLAGS.do_train: if FLAGS.mode == 'train':
train_squad(strategy, input_meta_data) train_squad(strategy, input_meta_data)
if FLAGS.do_predict: if FLAGS.mode == 'predict':
predict_squad(strategy, input_meta_data) predict_squad(strategy, input_meta_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment