Commit 238321c0 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 298817372
parent 75d13042
...@@ -239,22 +239,74 @@ def run_keras_compile_fit(model_dir, ...@@ -239,22 +239,74 @@ def run_keras_compile_fit(model_dir,
return bert_model return bert_model
def get_predictions_and_labels(strategy, trained_model, eval_input_fn,
eval_steps):
"""Obtains predictions of trained model on evaluation data.
Note that list of labels is returned along with the predictions because the
order changes on distributing dataset over TPU pods.
Args:
strategy: Distribution strategy.
trained_model: Trained model with preloaded weights.
eval_input_fn: Input function for evaluation data.
eval_steps: Number of evaluation steps.
Returns:
predictions: List of predictions.
labels: List of gold labels corresponding to predictions.
"""
@tf.function
def test_step(iterator):
"""Computes predictions on distributed devices."""
def _test_step_fn(inputs):
"""Replicated predictions."""
inputs, labels = inputs
model_outputs = trained_model(inputs, training=False)
return model_outputs, labels
outputs, labels = strategy.experimental_run_v2(
_test_step_fn, args=(next(iterator),))
# outputs: current batch logits as a tuple of shard logits
outputs = tf.nest.map_structure(strategy.experimental_local_results,
outputs)
labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
return outputs, labels
def _run_evaluation(test_iterator):
"""Runs evaluation steps."""
preds, golds = list(), list()
for _ in range(eval_steps):
logits, labels = test_step(test_iterator)
for cur_logits, cur_labels in zip(logits, labels):
preds.extend(tf.math.argmax(cur_logits, axis=1).numpy())
golds.extend(cur_labels.numpy().tolist())
return preds, golds
test_iter = iter(
strategy.experimental_distribute_datasets_from_function(eval_input_fn))
predictions, labels = _run_evaluation(test_iter)
return predictions, labels
def export_classifier(model_export_path, input_meta_data, def export_classifier(model_export_path, input_meta_data,
restore_model_using_load_weights, restore_model_using_load_weights, bert_config, model_dir):
bert_config, model_dir):
"""Exports a trained model as a `SavedModel` for inference. """Exports a trained model as a `SavedModel` for inference.
Args: Args:
model_export_path: a string specifying the path to the SavedModel directory. model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model. input_meta_data: dictionary containing meta data about input and model.
restore_model_using_load_weights: Whether to use checkpoint.restore() API restore_model_using_load_weights: Whether to use checkpoint.restore() API
for custom checkpoint or to use model.load_weights() API. for custom checkpoint or to use model.load_weights() API. There are 2
There are 2 different ways to save checkpoints. One is using different ways to save checkpoints. One is using tf.train.Checkpoint and
tf.train.Checkpoint and another is using Keras model.save_weights(). another is using Keras model.save_weights(). Custom training loop
Custom training loop implementation uses tf.train.Checkpoint API implementation uses tf.train.Checkpoint API and Keras ModelCheckpoint
and Keras ModelCheckpoint callback internally uses model.save_weights() callback internally uses model.save_weights() API. Since these two API's
API. Since these two API's cannot be used together, model loading logic cannot be used together, model loading logic must be take into account how
must be take into account how model checkpoint was saved. model checkpoint was saved.
bert_config: Bert configuration file to define core bert layers. bert_config: Bert configuration file to define core bert layers.
model_dir: The directory where the model weights and training/evaluation model_dir: The directory where the model weights and training/evaluation
summaries are stored. summaries are stored.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment