Commit fc2056bc authored by Hongkun Yu, committed by A. Unique TensorFlower

Clean up primary_cpu_task. It is not necessary.

PiperOrigin-RevId: 277649735
parent 53eff257
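Note: the deleted `primary_cpu_task` helper existed only to pin the input pipeline to the TPU worker host via an explicit `tf.device('/job:worker')` scope. Under `tf.distribute`, `strategy.experimental_distribute_dataset` already places input-pipeline ops on the host(s) feeding each worker, which is why the helper is unnecessary. A minimal sketch of the simplified setup, assuming TF 2.0-era API names and a reachable TPU (the address below is a placeholder):

```python
import tensorflow as tf

# Placeholder TPU address; in practice this comes from a --tpu flag.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='grpc://10.0.0.2:8470')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

dataset = tf.data.Dataset.range(64).batch(8)
# No tf.device('/job:worker') scope is needed here: the strategy places
# the input pipeline on the appropriate host automatically.
dist_dataset = strategy.experimental_distribute_dataset(dataset)
```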
@@ -131,13 +131,6 @@ def hparam_flags_dict():
   }
 
 
-def primary_cpu_task(use_remote_tpu=False):
-  """Returns primary CPU task to which input pipeline Ops are put."""
-  # Remote Eager Borg job configures the TPU worker with job name 'worker'.
-  return '/job:worker' if use_remote_tpu else ''
-
-
 def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
   """Saves model to model_dir with provided checkpoint prefix."""
@@ -204,7 +197,6 @@ class DistributedExecutor(object):
     metric_fn: metric function. Signature: () -> tf.keras.metrics.Metric.
     is_multi_host: Set to True when using multi hosts for training, like multi
       worker GPU or TPU pod (slice). Otherwise, False.
-    use_remote_tpu: If True, run on remote TPU mode.
   """
 
   def __init__(self,
@@ -212,14 +204,12 @@ class DistributedExecutor(object):
                params,
                model_fn,
                loss_fn,
-               is_multi_host=False,
-               use_remote_tpu=False):
+               is_multi_host=False):
     self._params = params
     self._model_fn = model_fn
     self._loss_fn = loss_fn
     self._strategy = strategy
-    self._use_remote_tpu = use_remote_tpu
     self._checkpoint_name = 'ctl_step_{step}.ckpt'
     self._is_multi_host = is_multi_host
@@ -867,11 +857,6 @@ class ExecutorBuilder(object):
       raise ValueError('`strategy` should not be None. You need to specify '
                        '`strategy_type` in the builder constructor or directly '
                        'set the `strategy` property of the builder.')
-    if 'use_remote_tpu' not in kwargs:
-      use_remote_tpu = (
-          isinstance(self._strategy, tf.distribute.experimental.TPUStrategy) and
-          bool(self._strategy_config.tpu))
-      kwargs['use_remote_tpu'] = use_remote_tpu
     return class_ctor(
         strategy=self._strategy,
         params=params,
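The deleted builder logic auto-detected remote-TPU mode from the strategy type and the configured TPU address. A caller that still needs that check can inline it; a minimal sketch (hypothetical standalone helper, not part of the library):

```python
import tensorflow as tf

def uses_remote_tpu(strategy, tpu_address):
  # Mirrors the removed ExecutorBuilder logic: remote-TPU mode means we run
  # under TPUStrategy with a non-empty TPU address. Hypothetical helper.
  return (isinstance(strategy, tf.distribute.experimental.TPUStrategy) and
          bool(tpu_address))
```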
@@ -135,53 +135,50 @@ def get_raw_results(predictions):
 def predict_squad_customized(strategy, input_meta_data, bert_config,
                              predict_tfrecord_path, num_steps):
   """Make predictions using a Bert-based squad model."""
-  primary_cpu_task = '/job:worker' if FLAGS.tpu else ''
-
-  with tf.device(primary_cpu_task):
-    predict_dataset = input_pipeline.create_squad_dataset(
-        predict_tfrecord_path,
-        input_meta_data['max_seq_length'],
-        FLAGS.predict_batch_size,
-        is_training=False)
-    predict_iterator = iter(
-        strategy.experimental_distribute_dataset(predict_dataset))
-
-    with strategy.scope():
-      # Prediction always uses float32, even if training uses mixed precision.
-      tf.keras.mixed_precision.experimental.set_policy('float32')
-      squad_model, _ = bert_models.squad_model(
-          bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
-
-    checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
-    logging.info('Restoring checkpoints from %s', checkpoint_path)
-    checkpoint = tf.train.Checkpoint(model=squad_model)
-    checkpoint.restore(checkpoint_path).expect_partial()
-
-    @tf.function
-    def predict_step(iterator):
-      """Predicts on distributed devices."""
-
-      def _replicated_step(inputs):
-        """Replicated prediction calculation."""
-        x, _ = inputs
-        unique_ids, start_logits, end_logits = squad_model(x, training=False)
-        return dict(
-            unique_ids=unique_ids,
-            start_logits=start_logits,
-            end_logits=end_logits)
-
-      outputs = strategy.experimental_run_v2(
-          _replicated_step, args=(next(iterator),))
-      return tf.nest.map_structure(strategy.experimental_local_results, outputs)
-
-    all_results = []
-    for _ in range(num_steps):
-      predictions = predict_step(predict_iterator)
-      for result in get_raw_results(predictions):
-        all_results.append(result)
-      if len(all_results) % 100 == 0:
-        logging.info('Made predictions for %d records.', len(all_results))
-    return all_results
+  predict_dataset = input_pipeline.create_squad_dataset(
+      predict_tfrecord_path,
+      input_meta_data['max_seq_length'],
+      FLAGS.predict_batch_size,
+      is_training=False)
+  predict_iterator = iter(
+      strategy.experimental_distribute_dataset(predict_dataset))
+
+  with strategy.scope():
+    # Prediction always uses float32, even if training uses mixed precision.
+    tf.keras.mixed_precision.experimental.set_policy('float32')
+    squad_model, _ = bert_models.squad_model(
+        bert_config, input_meta_data['max_seq_length'], float_type=tf.float32)
+
+  checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
+  logging.info('Restoring checkpoints from %s', checkpoint_path)
+  checkpoint = tf.train.Checkpoint(model=squad_model)
+  checkpoint.restore(checkpoint_path).expect_partial()
+
+  @tf.function
+  def predict_step(iterator):
+    """Predicts on distributed devices."""
+
+    def _replicated_step(inputs):
+      """Replicated prediction calculation."""
+      x, _ = inputs
+      unique_ids, start_logits, end_logits = squad_model(x, training=False)
+      return dict(
+          unique_ids=unique_ids,
+          start_logits=start_logits,
+          end_logits=end_logits)
+
+    outputs = strategy.experimental_run_v2(
+        _replicated_step, args=(next(iterator),))
+    return tf.nest.map_structure(strategy.experimental_local_results, outputs)
+
+  all_results = []
+  for _ in range(num_steps):
+    predictions = predict_step(predict_iterator)
+    for result in get_raw_results(predictions):
+      all_results.append(result)
+    if len(all_results) % 100 == 0:
+      logging.info('Made predictions for %d records.', len(all_results))
+  return all_results
 
 
 def train_squad(strategy,
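The prediction loop above follows the standard tf.distribute pattern: run a replicated step with `strategy.experimental_run_v2` and unwrap the per-replica outputs with `strategy.experimental_local_results`. A self-contained sketch of the same pattern with a toy model in place of the BERT squad model (TF 2.0-era names; later releases rename `experimental_run_v2` to `strategy.run`):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(2)])

dataset = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([8, 4])).batch(4)
dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def predict_step(iterator):
  def _replicated_step(x):
    return model(x, training=False)  # per-replica forward pass

  outputs = strategy.experimental_run_v2(
      _replicated_step, args=(next(iterator),))
  # Unwrap PerReplica values into a tuple of local tensors.
  return tf.nest.map_structure(strategy.experimental_local_results, outputs)

for _ in range(2):  # 8 samples / batch size 4 = 2 steps
  logits = predict_step(dist_iterator)
```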