Commit 89def413 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

fix tpu training issues

parent aca51294
...@@ -330,7 +330,7 @@ def load_fine_tune_checkpoint( ...@@ -330,7 +330,7 @@ def load_fine_tune_checkpoint(
labels) labels)
strategy = tf.compat.v2.distribute.get_strategy() strategy = tf.compat.v2.distribute.get_strategy()
strategy.run( strategy.experimental_run_v2(
_dummy_computation_fn, args=( _dummy_computation_fn, args=(
features, features,
labels, labels,
...@@ -570,7 +570,7 @@ def train_loop( ...@@ -570,7 +570,7 @@ def train_loop(
def _sample_and_train(strategy, train_step_fn, data_iterator): def _sample_and_train(strategy, train_step_fn, data_iterator):
features, labels = data_iterator.next() features, labels = data_iterator.next()
per_replica_losses = strategy.run( per_replica_losses = strategy.experimental_run_v2(
train_step_fn, args=(features, labels)) train_step_fn, args=(features, labels))
# TODO(anjalisridhar): explore if it is safe to remove the # TODO(anjalisridhar): explore if it is safe to remove the
## num_replicas scaling of the loss and switch this to a ReduceOp.Mean ## num_replicas scaling of the loss and switch this to a ReduceOp.Mean
......
...@@ -42,6 +42,7 @@ from object_detection import model_lib_v2 ...@@ -42,6 +42,7 @@ from object_detection import model_lib_v2
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config ' flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
'file.') 'file.')
flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.') flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
flags.DEFINE_bool('use_tpu', False, 'Whether to use TPUs')
flags.DEFINE_bool('eval_on_train_data', False, 'Enable evaluating on train ' flags.DEFINE_bool('eval_on_train_data', False, 'Enable evaluating on train '
'data (only supported in distributed training).') 'data (only supported in distributed training).')
flags.DEFINE_integer('sample_1_of_n_eval_examples', None, 'Will sample one of ' flags.DEFINE_integer('sample_1_of_n_eval_examples', None, 'Will sample one of '
...@@ -84,7 +85,7 @@ def main(unused_argv): ...@@ -84,7 +85,7 @@ def main(unused_argv):
checkpoint_dir=FLAGS.checkpoint_dir, checkpoint_dir=FLAGS.checkpoint_dir,
wait_interval=300, timeout=FLAGS.eval_timeout) wait_interval=300, timeout=FLAGS.eval_timeout)
else: else:
if tf.config.get_visible_devices('TPU'): if FLAGS.use_tpu:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver() resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver) tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver) tf.tpu.experimental.initialize_tpu_system(resolver)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment