Commit ca2e6ae0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 312923051
parent 2222cefc
@@ -297,7 +297,7 @@ def squad_model(bert_config,
 
 def classifier_model(bert_config,
                      num_labels,
-                     max_seq_length,
+                     max_seq_length=None,
                      final_layer_initializer=None,
                      hub_module_url=None,
                      hub_module_trainable=True):
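Making max_seq_length optional lets callers build the classifier without knowing the training-time sequence length. A minimal sketch of why that works, assuming classifier_model builds its inputs via tf.keras.layers.Input(shape=(max_seq_length,)) as elsewhere in official/nlp/bert (an assumption, not shown in this diff):

import tensorflow as tf

# Sketch only: shape=(None,) gives the sequence axis a dynamic size, so the
# model accepts any padded batch length at run time.
word_ids = tf.keras.layers.Input(
    shape=(None,), dtype=tf.int32, name='input_word_ids')
print(word_ids.shape)  # (None, None): batch and sequence are both dynamic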
@@ -37,22 +37,23 @@ from official.nlp.bert import model_training_utils
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 
 flags.DEFINE_enum(
-    'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
-    'One of {"train_and_eval", "export_only"}. `train_and_eval`: '
+    'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
+    'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
     'trains the model and evaluates in the meantime. '
     '`export_only`: will take the latest checkpoint inside '
-    'model_dir and export a `SavedModel`.')
+    'model_dir and export a `SavedModel`. `predict`: takes a checkpoint and '
+    'restores the model to output predictions on the test set.')
 flags.DEFINE_string('train_data_path', None,
                     'Path to training data for BERT classifier.')
 flags.DEFINE_string('eval_data_path', None,
                     'Path to evaluation data for BERT classifier.')
-
-# Model training specific flags.
 flags.DEFINE_string(
     'input_meta_data_path', None,
     'Path to file that contains meta data about input '
     'to be used for training and evaluation.')
+flags.DEFINE_string('predict_checkpoint_path', None,
+                    'Path to the checkpoint for predictions.')
 
 flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
 flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
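With these flags in place, the new mode can be exercised roughly as follows; the script name and all paths are placeholders, not taken from this diff:

python run_classifier.py \
  --mode=predict \
  --bert_config_file=/path/to/bert_config.json \
  --input_meta_data_path=/path/to/input_meta_data \
  --eval_data_path=/path/to/test.tf_record \
  --model_dir=/tmp/bert20/ \
  --predict_checkpoint_path=/path/to/ckpt

If predict_checkpoint_path is omitted, the latest checkpoint in model_dir is used (see the predict branch added to custom_main below).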
@@ -125,9 +126,10 @@ def run_bert_classifier(strategy,
             max_seq_length,
             hub_module_url=FLAGS.hub_module_url,
             hub_module_trainable=FLAGS.hub_module_trainable))
-  optimizer = optimization.create_optimizer(
-      initial_lr, steps_per_epoch * epochs, warmup_steps,
-      FLAGS.end_lr, FLAGS.optimizer_type)
+  optimizer = optimization.create_optimizer(initial_lr,
+                                            steps_per_epoch * epochs,
+                                            warmup_steps, FLAGS.end_lr,
+                                            FLAGS.optimizer_type)
   classifier_model.optimizer = performance.configure_optimizer(
       optimizer,
       use_float16=common_flags.use_float16(),
@@ -214,9 +216,14 @@ def run_keras_compile_fit(model_dir,
   summary_dir = os.path.join(model_dir, 'summaries')
   summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
 
-  checkpoint_path = os.path.join(model_dir, 'checkpoint')
-  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
-      checkpoint_path, save_weights_only=True)
+  checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
+  checkpoint_manager = tf.train.CheckpointManager(
+      checkpoint,
+      directory=model_dir,
+      max_to_keep=None,
+      step_counter=optimizer.iterations,
+      checkpoint_interval=0)
+  checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
 
   if custom_callbacks is not None:
     custom_callbacks += [summary_callback, checkpoint_callback]
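keras_utils.SimpleCheckpoint is not shown in this diff. A plausible minimal sketch of such a callback, assuming it simply forwards to the manager (the real implementation may differ):

import tensorflow as tf

class SimpleCheckpoint(tf.keras.callbacks.Callback):
  """Sketch only: saves through a tf.train.CheckpointManager at epoch end."""

  def __init__(self, checkpoint_manager):
    super(SimpleCheckpoint, self).__init__()
    self.checkpoint_manager = checkpoint_manager

  def on_epoch_end(self, epoch, logs=None):
    # The manager numbers checkpoints from its step_counter
    # (optimizer.iterations in the hunk above); check_interval=False forces
    # a save even with checkpoint_interval=0.
    self.checkpoint_manager.save(check_interval=False)

The switch from Keras ModelCheckpoint (which uses model.save_weights()) to tf.train.Checkpoint matters: it makes compile/fit checkpoints restorable with checkpoint.restore(), which is what allows the restore_model_using_load_weights plumbing to be deleted further down.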
@@ -234,8 +241,10 @@ def run_keras_compile_fit(model_dir,
   return bert_model
 
 
-def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
-                               eval_steps):
+def get_predictions_and_labels(strategy,
+                               trained_model,
+                               eval_input_fn,
+                               return_probs=False):
   """Obtains predictions of trained model on evaluation data.
 
   Note that list of labels is returned along with the predictions because the
@@ -245,7 +254,7 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
     strategy: Distribution strategy.
     trained_model: Trained model with preloaded weights.
     eval_input_fn: Input function for evaluation data.
-    eval_steps: Number of evaluation steps.
+    return_probs: Whether to return probabilities of classes.
 
   Returns:
     predictions: List of predictions.
@@ -259,11 +268,11 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
   def _test_step_fn(inputs):
     """Replicated predictions."""
     inputs, labels = inputs
-    model_outputs = trained_model(inputs, training=False)
-    return model_outputs, labels
+    logits = trained_model(inputs, training=False)
+    probabilities = tf.nn.softmax(logits)
+    return probabilities, labels
 
-  outputs, labels = strategy.run(
-      _test_step_fn, args=(next(iterator),))
+  outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
   # outputs: current batch logits as a tuple of shard logits
   outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                   outputs)
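Applying tf.nn.softmax inside the replicated step lets predict mode return class probabilities while leaving label predictions unchanged: softmax is monotonic, so the argmax over probabilities equals the argmax over logits. A toy check:

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])
probs = tf.nn.softmax(logits)  # approximately [[0.659, 0.242, 0.099]]
assert tf.reduce_all(
    tf.math.argmax(probs, axis=1) == tf.math.argmax(logits, axis=1))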
@@ -273,11 +282,18 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
   def _run_evaluation(test_iterator):
     """Runs evaluation steps."""
     preds, golds = list(), list()
-    for _ in range(eval_steps):
-      logits, labels = test_step(test_iterator)
-      for cur_logits, cur_labels in zip(logits, labels):
-        preds.extend(tf.math.argmax(cur_logits, axis=1).numpy())
-        golds.extend(cur_labels.numpy().tolist())
+    try:
+      with tf.experimental.async_scope():
+        while True:
+          probabilities, labels = test_step(test_iterator)
+          for cur_probs, cur_labels in zip(probabilities, labels):
+            if return_probs:
+              preds.extend(cur_probs.numpy().tolist())
+            else:
+              preds.extend(tf.math.argmax(cur_probs, axis=1).numpy())
+            golds.extend(cur_labels.numpy().tolist())
+    except (StopIteration, tf.errors.OutOfRangeError):
+      tf.experimental.async_clear_error()
     return preds, golds
 
   test_iter = iter(
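The rewritten loop no longer needs an eval_steps count: it drains the iterator until the dataset raises OutOfRangeError, and tf.experimental.async_scope defers error reporting so the host is not synchronized with the devices on every step; tf.experimental.async_clear_error then resets the async state after the expected end-of-data error. The drain pattern in isolation, without the async machinery:

import tensorflow as tf

dataset = tf.data.Dataset.range(5).batch(2)
iterator = iter(dataset)
batches = []
try:
  while True:  # no step count needed; exhaustion terminates the loop
    batches.append(next(iterator).numpy().tolist())
except (StopIteration, tf.errors.OutOfRangeError):
  pass
print(batches)  # [[0, 1], [2, 3], [4]]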
@@ -287,21 +303,13 @@ def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
   return predictions, labels
 
 
-def export_classifier(model_export_path, input_meta_data,
-                      restore_model_using_load_weights, bert_config, model_dir):
+def export_classifier(model_export_path, input_meta_data, bert_config,
+                      model_dir):
   """Exports a trained model as a `SavedModel` for inference.
 
   Args:
     model_export_path: a string specifying the path to the SavedModel directory.
     input_meta_data: dictionary containing meta data about input and model.
-    restore_model_using_load_weights: Whether to use checkpoint.restore() API
-      for custom checkpoint or to use model.load_weights() API. There are 2
-      different ways to save checkpoints. One is using tf.train.Checkpoint and
-      another is using Keras model.save_weights(). Custom training loop
-      implementation uses tf.train.Checkpoint API and Keras ModelCheckpoint
-      callback internally uses model.save_weights() API. Since these two API's
-      cannot be used together, model loading logic must be take into account how
-      model checkpoint was saved.
     bert_config: Bert configuration file to define core bert layers.
     model_dir: The directory where the model weights and training/evaluation
       summaries are stored.
@@ -317,14 +325,10 @@ def export_classifier(model_export_path, input_meta_data,
   # Export uses float32 for now, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   classifier_model = bert_models.classifier_model(
-      bert_config, input_meta_data['num_labels'],
-      input_meta_data['max_seq_length'])[0]
+      bert_config, input_meta_data['num_labels'])[0]
 
   model_saving_utils.export_bert_model(
-      model_export_path,
-      model=classifier_model,
-      checkpoint_dir=model_dir,
-      restore_model_using_load_weights=restore_model_using_load_weights)
+      model_export_path, model=classifier_model, checkpoint_dir=model_dir)
 
 
 def run_bert(strategy,
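Dropping restore_model_using_load_weights is possible because both training paths now write tf.train.Checkpoint-style checkpoints. Presumably export_bert_model restores along these lines when checkpoint_dir is set; a sketch under that assumption, mirroring the predict branch below:

import tensorflow as tf

def restore_for_export(classifier_model, model_dir):
  # Assumption: one uniform restore now covers checkpoints from both the
  # custom training loop and Keras compile/fit (via SimpleCheckpoint).
  checkpoint = tf.train.Checkpoint(model=classifier_model)
  latest = tf.train.latest_checkpoint(model_dir)
  checkpoint.restore(latest).assert_existing_objects_matched()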
@@ -335,17 +339,6 @@ def run_bert(strategy,
              init_checkpoint=None,
              custom_callbacks=None):
   """Run BERT training."""
-  if FLAGS.mode == 'export_only':
-    # As Keras ModelCheckpoint callback used with Keras compile/fit() API
-    # internally uses model.save_weights() to save checkpoints, we must
-    # use model.load_weights() when Keras compile/fit() is used.
-    export_classifier(FLAGS.model_export_path, input_meta_data,
-                      FLAGS.use_keras_compile_fit,
-                      model_config, FLAGS.model_dir)
-    return
-
-  if FLAGS.mode != 'train_and_eval':
-    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_session_config(FLAGS.enable_xla)
   performance.set_mixed_precision_policy(common_flags.dtype())
@@ -364,10 +357,11 @@ def run_bert(strategy,
     custom_callbacks = []
 
   if FLAGS.log_steps:
-    custom_callbacks.append(keras_utils.TimeHistory(
-        batch_size=FLAGS.train_batch_size,
-        log_steps=FLAGS.log_steps,
-        logdir=FLAGS.model_dir))
+    custom_callbacks.append(
+        keras_utils.TimeHistory(
+            batch_size=FLAGS.train_batch_size,
+            log_steps=FLAGS.log_steps,
+            logdir=FLAGS.model_dir))
 
   trained_model = run_bert_classifier(
       strategy,
@@ -388,13 +382,8 @@ def run_bert(strategy,
       custom_callbacks=custom_callbacks)
 
   if FLAGS.model_export_path:
-    # As Keras ModelCheckpoint callback used with Keras compile/fit() API
-    # internally uses model.save_weights() to save checkpoints, we must
-    # use model.load_weights() when Keras compile/fit() is used.
     model_saving_utils.export_bert_model(
-        FLAGS.model_export_path,
-        model=trained_model,
-        restore_model_using_load_weights=FLAGS.use_keras_compile_fit)
+        FLAGS.model_export_path, model=trained_model)
 
   return trained_model
@@ -412,25 +401,62 @@ def custom_main(custom_callbacks=None):
   if not FLAGS.model_dir:
     FLAGS.model_dir = '/tmp/bert20/'
+  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+
+  if FLAGS.mode == 'export_only':
+    export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
+                      FLAGS.model_dir)
+    return
 
   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=FLAGS.distribution_strategy,
       num_gpus=FLAGS.num_gpus,
       tpu_address=FLAGS.tpu)
-
-  max_seq_length = input_meta_data['max_seq_length']
-  train_input_fn = get_dataset_fn(
-      FLAGS.train_data_path,
-      max_seq_length,
-      FLAGS.train_batch_size,
-      is_training=True)
   eval_input_fn = get_dataset_fn(
       FLAGS.eval_data_path,
-      max_seq_length,
+      input_meta_data['max_seq_length'],
      FLAGS.eval_batch_size,
       is_training=False)
 
-  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
-  run_bert(strategy, input_meta_data, bert_config, train_input_fn,
-           eval_input_fn, custom_callbacks=custom_callbacks)
+  if FLAGS.mode == 'predict':
+    with strategy.scope():
+      classifier_model = bert_models.classifier_model(
+          bert_config, input_meta_data['num_labels'])[0]
+      checkpoint = tf.train.Checkpoint(model=classifier_model)
+      latest_checkpoint_file = (
+          FLAGS.predict_checkpoint_path or
+          tf.train.latest_checkpoint(FLAGS.model_dir))
+      assert latest_checkpoint_file
+      logging.info('Checkpoint file %s found and restoring from '
+                   'checkpoint', latest_checkpoint_file)
+      checkpoint.restore(
+          latest_checkpoint_file).assert_existing_objects_matched()
+      preds, _ = get_predictions_and_labels(
+          strategy, classifier_model, eval_input_fn, return_probs=True)
+    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
+    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
+      logging.info('***** Predict results *****')
+      for probabilities in preds:
+        output_line = '\t'.join(
+            str(class_probability)
+            for class_probability in probabilities) + '\n'
+        writer.write(output_line)
+    return
+
+  if FLAGS.mode != 'train_and_eval':
+    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
+  train_input_fn = get_dataset_fn(
+      FLAGS.train_data_path,
+      input_meta_data['max_seq_length'],
+      FLAGS.train_batch_size,
+      is_training=True)
+  run_bert(
+      strategy,
+      input_meta_data,
+      bert_config,
+      train_input_fn,
+      eval_input_fn,
+      custom_callbacks=custom_callbacks)
 
 
 def main(_):
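The predict branch writes one tab-separated row of class probabilities per test example to test_results.tsv under model_dir. Reading the file back is symmetric; the path below assumes the default model_dir:

import tensorflow as tf

with tf.io.gfile.GFile('/tmp/bert20/test_results.tsv') as reader:
  probabilities = [
      [float(p) for p in line.rstrip('\n').split('\t')] for line in reader
  ]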