"vscode:/vscode.git/clone" did not exist on "7dc8c47596dab0eb3ae53edead6399fc8b96b720"
Commit ca2e6ae0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 312923051
parent 2222cefc
......@@ -297,7 +297,7 @@ def squad_model(bert_config,
def classifier_model(bert_config,
num_labels,
max_seq_length,
max_seq_length=None,
final_layer_initializer=None,
hub_module_url=None,
hub_module_trainable=True):
......
......@@ -37,22 +37,23 @@ from official.nlp.bert import model_training_utils
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_enum(
'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
'One of {"train_and_eval", "export_only"}. `train_and_eval`: '
'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
'trains the model and evaluates in the meantime. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
'model_dir and export a `SavedModel`. `predict`: takes a checkpoint and '
'restores the model to output predictions on the test set.')
flags.DEFINE_string('train_data_path', None,
'Path to training data for BERT classifier.')
flags.DEFINE_string('eval_data_path', None,
'Path to evaluation data for BERT classifier.')
# Model training specific flags.
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
flags.DEFINE_string('predict_checkpoint_path', None,
'Path to the checkpoint for predictions.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
......@@ -125,9 +126,10 @@ def run_bert_classifier(strategy,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable))
optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps,
FLAGS.end_lr, FLAGS.optimizer_type)
optimizer = optimization.create_optimizer(initial_lr,
steps_per_epoch * epochs,
warmup_steps, FLAGS.end_lr,
FLAGS.optimizer_type)
classifier_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
......@@ -214,9 +216,14 @@ def run_keras_compile_fit(model_dir,
summary_dir = os.path.join(model_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
checkpoint_path = os.path.join(model_dir, 'checkpoint')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True)
checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=model_dir,
max_to_keep=None,
step_counter=optimizer.iterations,
checkpoint_interval=0)
checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
if custom_callbacks is not None:
custom_callbacks += [summary_callback, checkpoint_callback]
......@@ -234,8 +241,10 @@ def run_keras_compile_fit(model_dir,
return bert_model
def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
eval_steps):
def get_predictions_and_labels(strategy,
trained_model,
eval_input_fn,
return_probs=False):
"""Obtains predictions of trained model on evaluation data.
Note that list of labels is returned along with the predictions because the
......@@ -245,7 +254,7 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
strategy: Distribution strategy.
trained_model: Trained model with preloaded weights.
eval_input_fn: Input function for evaluation data.
eval_steps: Number of evaluation steps.
return_probs: Whether to return probabilities of classes.
Returns:
predictions: List of predictions.
......@@ -259,11 +268,11 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
def _test_step_fn(inputs):
"""Replicated predictions."""
inputs, labels = inputs
model_outputs = trained_model(inputs, training=False)
return model_outputs, labels
logits = trained_model(inputs, training=False)
probabilities = tf.nn.softmax(logits)
return probabilities, labels
outputs, labels = strategy.run(
_test_step_fn, args=(next(iterator),))
outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
# outputs: current batch logits as a tuple of shard logits
outputs = tf.nest.map_structure(strategy.experimental_local_results,
outputs)
......@@ -273,11 +282,18 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
def _run_evaluation(test_iterator):
"""Runs evaluation steps."""
preds, golds = list(), list()
for _ in range(eval_steps):
logits, labels = test_step(test_iterator)
for cur_logits, cur_labels in zip(logits, labels):
preds.extend(tf.math.argmax(cur_logits, axis=1).numpy())
try:
with tf.experimental.async_scope():
while True:
probabilities, labels = test_step(test_iterator)
for cur_probs, cur_labels in zip(probabilities, labels):
if return_probs:
preds.extend(cur_probs.numpy().tolist())
else:
preds.extend(tf.math.argmax(cur_probs, axis=1).numpy())
golds.extend(cur_labels.numpy().tolist())
except (StopIteration, tf.errors.OutOfRangeError):
tf.experimental.async_clear_error()
return preds, golds
test_iter = iter(
......@@ -287,21 +303,13 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
return predictions, labels
def export_classifier(model_export_path, input_meta_data,
restore_model_using_load_weights, bert_config, model_dir):
def export_classifier(model_export_path, input_meta_data, bert_config,
model_dir):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
restore_model_using_load_weights: Whether to use checkpoint.restore() API
for custom checkpoint or to use model.load_weights() API. There are 2
different ways to save checkpoints. One is using tf.train.Checkpoint and
another is using Keras model.save_weights(). Custom training loop
implementation uses tf.train.Checkpoint API and Keras ModelCheckpoint
callback internally uses model.save_weights() API. Since these two API's
cannot be used together, model loading logic must be take into account how
model checkpoint was saved.
bert_config: Bert configuration file to define core bert layers.
model_dir: The directory where the model weights and training/evaluation
summaries are stored.
......@@ -317,14 +325,10 @@ def export_classifier(model_export_path, input_meta_data,
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
classifier_model = bert_models.classifier_model(
bert_config, input_meta_data['num_labels'],
input_meta_data['max_seq_length'])[0]
bert_config, input_meta_data['num_labels'])[0]
model_saving_utils.export_bert_model(
model_export_path,
model=classifier_model,
checkpoint_dir=model_dir,
restore_model_using_load_weights=restore_model_using_load_weights)
model_export_path, model=classifier_model, checkpoint_dir=model_dir)
def run_bert(strategy,
......@@ -335,17 +339,6 @@ def run_bert(strategy,
init_checkpoint=None,
custom_callbacks=None):
"""Run BERT training."""
if FLAGS.mode == 'export_only':
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
# internally uses model.save_weights() to save checkpoints, we must
# use model.load_weights() when Keras compile/fit() is used.
export_classifier(FLAGS.model_export_path, input_meta_data,
FLAGS.use_keras_compile_fit,
model_config, FLAGS.model_dir)
return
if FLAGS.mode != 'train_and_eval':
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_session_config(FLAGS.enable_xla)
performance.set_mixed_precision_policy(common_flags.dtype())
......@@ -364,7 +357,8 @@ def run_bert(strategy,
custom_callbacks = []
if FLAGS.log_steps:
custom_callbacks.append(keras_utils.TimeHistory(
custom_callbacks.append(
keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
log_steps=FLAGS.log_steps,
logdir=FLAGS.model_dir))
......@@ -388,13 +382,8 @@ def run_bert(strategy,
custom_callbacks=custom_callbacks)
if FLAGS.model_export_path:
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
# internally uses model.save_weights() to save checkpoints, we must
# use model.load_weights() when Keras compile/fit() is used.
model_saving_utils.export_bert_model(
FLAGS.model_export_path,
model=trained_model,
restore_model_using_load_weights=FLAGS.use_keras_compile_fit)
FLAGS.model_export_path, model=trained_model)
return trained_model
......@@ -412,25 +401,62 @@ def custom_main(custom_callbacks=None):
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.mode == 'export_only':
export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
FLAGS.model_dir)
return
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
max_seq_length = input_meta_data['max_seq_length']
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
eval_input_fn = get_dataset_fn(
FLAGS.eval_data_path,
max_seq_length,
input_meta_data['max_seq_length'],
FLAGS.eval_batch_size,
is_training=False)
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_bert(strategy, input_meta_data, bert_config, train_input_fn,
eval_input_fn, custom_callbacks=custom_callbacks)
if FLAGS.mode == 'predict':
with strategy.scope():
classifier_model = bert_models.classifier_model(
bert_config, input_meta_data['num_labels'])[0]
checkpoint = tf.train.Checkpoint(model=classifier_model)
latest_checkpoint_file = (
FLAGS.predict_checkpoint_path or
tf.train.latest_checkpoint(FLAGS.model_dir))
assert latest_checkpoint_file
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched()
preds, _ = get_predictions_and_labels(
strategy, classifier_model, eval_input_fn, return_probs=True)
output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
logging.info('***** Predict results *****')
for probabilities in preds:
output_line = '\t'.join(
str(class_probability)
for class_probability in probabilities) + '\n'
writer.write(output_line)
return
if FLAGS.mode != 'train_and_eval':
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
input_meta_data['max_seq_length'],
FLAGS.train_batch_size,
is_training=True)
run_bert(
strategy,
input_meta_data,
bert_config,
train_input_fn,
eval_input_fn,
custom_callbacks=custom_callbacks)
def main(_):
......
Markdown is supported
Attach a file by dragging & dropping or by selecting one.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment