Commit cd85fd8a authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Use get_primary_cpu_task from tpu_lib

PiperOrigin-RevId: 263874363
parent b1d9ac5b
...@@ -25,18 +25,12 @@ from absl import logging ...@@ -25,18 +25,12 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow.python.util import object_identity from tensorflow.python.util import object_identity
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import tpu_lib
_SUMMARY_TXT = 'training_summary.txt' _SUMMARY_TXT = 'training_summary.txt'
_MIN_SUMMARY_STEPS = 10 _MIN_SUMMARY_STEPS = 10
def get_primary_cpu_task(use_remote_tpu=False):
  """Returns the primary CPU task where input pipeline ops should be placed.

  Args:
    use_remote_tpu: Whether training runs against a remote TPU worker.

  Returns:
    The device string for the primary CPU task ('/job:worker' when a remote
    TPU is used), or the empty string for local placement.
  """
  # The remote Eager Borg job names its TPU worker task 'worker', so input
  # pipeline ops must be pinned to that job when running remotely.
  if use_remote_tpu:
    return '/job:worker'
  return ''
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix): def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
"""Saves model to with provided checkpoint prefix.""" """Saves model to with provided checkpoint prefix."""
...@@ -195,7 +189,7 @@ def run_customized_training_loop( ...@@ -195,7 +189,7 @@ def run_customized_training_loop(
# To reduce unnecessary send/receive input pipeline operation, we place input # To reduce unnecessary send/receive input pipeline operation, we place input
# pipeline ops in worker task. # pipeline ops in worker task.
with tf.device(get_primary_cpu_task(use_remote_tpu)): with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
train_iterator = _get_input_iterator(train_input_fn, strategy) train_iterator = _get_input_iterator(train_input_fn, strategy)
with distribution_utils.get_strategy_scope(strategy): with distribution_utils.get_strategy_scope(strategy):
......
...@@ -210,7 +210,7 @@ def run_bert(strategy, input_meta_data): ...@@ -210,7 +210,7 @@ def run_bert(strategy, input_meta_data):
run_eagerly=FLAGS.run_eagerly) run_eagerly=FLAGS.run_eagerly)
if FLAGS.model_export_path: if FLAGS.model_export_path:
with tf.device(model_training_utils.get_primary_cpu_task(use_remote_tpu)): with tf.device(tpu_lib.get_primary_cpu_task(use_remote_tpu)):
model_saving_utils.export_bert_model( model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model) FLAGS.model_export_path, model=trained_model)
return trained_model return trained_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment