Commit 7bf81db8 authored by Jose Baiocchi's avatar Jose Baiocchi Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 300399639
parent 1fdfd973
...@@ -28,7 +28,7 @@ import sys ...@@ -28,7 +28,7 @@ import sys
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from absl import app as absl_app # pylint: disable=unused-import from absl import app as absl_app # pylint: disable=unused-import
import tensorflow as tf import tensorflow.compat.v1 as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
# For open source environment, add grandparent directory for import # For open source environment, add grandparent directory for import
...@@ -98,7 +98,7 @@ def model_fn(features, labels, mode, params): ...@@ -98,7 +98,7 @@ def model_fn(features, labels, mode, params):
'class_ids': tf.argmax(logits, axis=1), 'class_ids': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits), 'probabilities': tf.nn.softmax(logits),
} }
return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(mode, predictions=predictions) return tf.estimator.tpu.TPUEstimatorSpec(mode, predictions=predictions)
logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN)) logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
...@@ -111,14 +111,14 @@ def model_fn(features, labels, mode, params): ...@@ -111,14 +111,14 @@ def model_fn(features, labels, mode, params):
decay_rate=0.96) decay_rate=0.96)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate) optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
if FLAGS.use_tpu: if FLAGS.use_tpu:
optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer) optimizer = tf.tpu.CrossShardOptimizer(optimizer)
return tf.compat.v1.estimator.tpu.TPUEstimatorSpec( return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode, mode=mode,
loss=loss, loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_global_step())) train_op=optimizer.minimize(loss, tf.train.get_global_step()))
if mode == tf.estimator.ModeKeys.EVAL: if mode == tf.estimator.ModeKeys.EVAL:
return tf.compat.v1.estimator.tpu.TPUEstimatorSpec( return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits])) mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
...@@ -128,7 +128,7 @@ def train_input_fn(params): ...@@ -128,7 +128,7 @@ def train_input_fn(params):
data_dir = params["data_dir"] data_dir = params["data_dir"]
# Retrieves the batch size for the current shard. The # of shards is # Retrieves the batch size for the current shard. The # of shards is
# computed according to the input pipeline deployment. See # computed according to the input pipeline deployment. See
# `tf.compat.v1.estimator.tpu.RunConfig` for details. # `tf.estimator.tpu.RunConfig` for details.
ds = dataset.train(data_dir).cache().repeat().shuffle( ds = dataset.train(data_dir).cache().repeat().shuffle(
buffer_size=50000).batch(batch_size, drop_remainder=True) buffer_size=50000).batch(batch_size, drop_remainder=True)
return ds return ds
...@@ -159,16 +159,15 @@ def main(argv): ...@@ -159,16 +159,15 @@ def main(argv):
project=FLAGS.gcp_project project=FLAGS.gcp_project
) )
run_config = tf.compat.v1.estimator.tpu.RunConfig( run_config = tf.estimator.tpu.RunConfig(
cluster=tpu_cluster_resolver, cluster=tpu_cluster_resolver,
model_dir=FLAGS.model_dir, model_dir=FLAGS.model_dir,
session_config=tf.ConfigProto( session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True), allow_soft_placement=True, log_device_placement=True),
tpu_config=tf.compat.v1.estimator.tpu.TPUConfig( tpu_config=tf.estimator.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
FLAGS.iterations, FLAGS.num_shards),
) )
estimator = tf.compat.v1.estimator.tpu.TPUEstimator( estimator = tf.estimator.tpu.TPUEstimator(
model_fn=model_fn, model_fn=model_fn,
use_tpu=FLAGS.use_tpu, use_tpu=FLAGS.use_tpu,
train_batch_size=FLAGS.batch_size, train_batch_size=FLAGS.batch_size,
...@@ -199,4 +198,5 @@ def main(argv): ...@@ -199,4 +198,5 @@ def main(argv):
if __name__ == "__main__": if __name__ == "__main__":
tf.disable_v2_behavior()
absl_app.run(main) absl_app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment