Commit f66efa5d authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Fix squad script.

PiperOrigin-RevId: 269640109
parent bc984e50
...@@ -258,8 +258,6 @@ export MODEL_DIR=gs://some_bucket/my_output_dir ...@@ -258,8 +258,6 @@ export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_VERSION=v1.1 export SQUAD_VERSION=v1.1
python run_squad.py \ python run_squad.py \
--do_train=true \
--do_predict=true \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \ --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \ --train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \ --predict_file=${SQUAD_DIR}/dev-v1.1.json \
...@@ -285,8 +283,6 @@ export SQUAD_DIR=gs://some_bucket/datasets ...@@ -285,8 +283,6 @@ export SQUAD_DIR=gs://some_bucket/datasets
export SQUAD_VERSION=v1.1 export SQUAD_VERSION=v1.1
python run_squad.py \ python run_squad.py \
--do_train=true \
--do_predict=true \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \ --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \ --train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \ --predict_file=${SQUAD_DIR}/dev-v1.1.json \
......
...@@ -41,9 +41,11 @@ from official.utils.misc import keras_utils ...@@ -41,9 +41,11 @@ from official.utils.misc import keras_utils
from official.utils.misc import tpu_lib from official.utils.misc import tpu_lib
flags.DEFINE_enum( flags.DEFINE_enum(
'mode', 'train', ['train', 'predict', 'export_only'], 'mode', 'train_and_predict',
'One of {"train", "predict", "export_only"}. `train`: ' ['train_and_predict', 'train', 'predict', 'export_only'],
'trains the model and evaluates in the meantime. ' 'One of {"train_and_predict", "train", "predict", "export_only"}. '
'`train_and_predict`: both train and predict to a json file. '
'`train`: only trains the model. '
'`predict`: predict answers from the squad json file. ' '`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside ' '`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.') 'model_dir and export a `SavedModel`.')
...@@ -370,9 +372,9 @@ def main(_): ...@@ -370,9 +372,9 @@ def main(_):
else: else:
raise ValueError('The distribution strategy type is not supported: %s' % raise ValueError('The distribution strategy type is not supported: %s' %
FLAGS.strategy_type) FLAGS.strategy_type)
if FLAGS.mode == 'train': if FLAGS.mode in ('train', 'train_and_predict'):
train_squad(strategy, input_meta_data) train_squad(strategy, input_meta_data)
if FLAGS.mode == 'predict': if FLAGS.mode in ('predict', 'train_and_predict'):
predict_squad(strategy, input_meta_data) predict_squad(strategy, input_meta_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment