Commit fc2056bc authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Clean up primary_cpu_task. It is not necessary.

PiperOrigin-RevId: 277649735
parent 53eff257
......@@ -131,13 +131,6 @@ def hparam_flags_dict():
}
def primary_cpu_task(use_remote_tpu=False):
"""Returns primary CPU task to which input pipeline Ops are put."""
# Remote Eager Borg job configures the TPU worker with job name 'worker'.
return '/job:worker' if use_remote_tpu else ''
def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
"""Saves model to model_dir with provided checkpoint prefix."""
......@@ -204,7 +197,6 @@ class DistributedExecutor(object):
metric_fn: metric function. Signature: () -> tf.keras.metrics.Metric.
is_multi_host: Set to True when using multi hosts for training, like multi
worker GPU or TPU pod (slice). Otherwise, False.
use_remote_tpu: If True, run on remote TPU mode.
"""
def __init__(self,
......@@ -212,14 +204,12 @@ class DistributedExecutor(object):
params,
model_fn,
loss_fn,
is_multi_host=False,
use_remote_tpu=False):
is_multi_host=False):
self._params = params
self._model_fn = model_fn
self._loss_fn = loss_fn
self._strategy = strategy
self._use_remote_tpu = use_remote_tpu
self._checkpoint_name = 'ctl_step_{step}.ckpt'
self._is_multi_host = is_multi_host
......@@ -867,11 +857,6 @@ class ExecutorBuilder(object):
raise ValueError('`strategy` should not be None. You need to specify '
'`strategy_type` in the builder contructor or directly '
'set the `strategy` property of the builder.')
if 'use_remote_tpu' not in kwargs:
use_remote_tpu = (
isinstance(self._strategy, tf.distribute.experimental.TPUStrategy) and
bool(self._strategy_config.tpu))
kwargs['use_remote_tpu'] = use_remote_tpu
return class_ctor(
strategy=self._strategy,
params=params,
......
......@@ -135,9 +135,6 @@ def get_raw_results(predictions):
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
"""Make predictions using a Bert-based squad model."""
primary_cpu_task = '/job:worker' if FLAGS.tpu else ''
with tf.device(primary_cpu_task):
predict_dataset = input_pipeline.create_squad_dataset(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment