Commit 7bf81db8 authored by Jose Baiocchi's avatar Jose Baiocchi Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 300399639
parent 1fdfd973
......@@ -28,7 +28,7 @@ import sys
# pylint: disable=g-bad-import-order
from absl import app as absl_app # pylint: disable=unused-import
import tensorflow as tf
import tensorflow.compat.v1 as tf
# pylint: enable=g-bad-import-order
# For open source environment, add grandparent directory for import
......@@ -98,7 +98,7 @@ def model_fn(features, labels, mode, params):
'class_ids': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits),
}
return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(mode, predictions=predictions)
return tf.estimator.tpu.TPUEstimatorSpec(mode, predictions=predictions)
logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
......@@ -111,14 +111,14 @@ def model_fn(features, labels, mode, params):
decay_rate=0.96)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
if FLAGS.use_tpu:
optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)
return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
optimizer = tf.tpu.CrossShardOptimizer(optimizer)
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode,
loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_global_step()))
if mode == tf.estimator.ModeKeys.EVAL:
return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
......@@ -128,7 +128,7 @@ def train_input_fn(params):
data_dir = params["data_dir"]
# Retrieves the batch size for the current shard. The # of shards is
# computed according to the input pipeline deployment. See
# `tf.compat.v1.estimator.tpu.RunConfig` for details.
# `tf.estimator.tpu.RunConfig` for details.
ds = dataset.train(data_dir).cache().repeat().shuffle(
buffer_size=50000).batch(batch_size, drop_remainder=True)
return ds
......@@ -159,16 +159,15 @@ def main(argv):
project=FLAGS.gcp_project
)
run_config = tf.compat.v1.estimator.tpu.RunConfig(
run_config = tf.estimator.tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=FLAGS.model_dir,
session_config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=True),
tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
FLAGS.iterations, FLAGS.num_shards),
tpu_config=tf.estimator.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
)
estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
estimator = tf.estimator.tpu.TPUEstimator(
model_fn=model_fn,
use_tpu=FLAGS.use_tpu,
train_batch_size=FLAGS.batch_size,
......@@ -199,4 +198,5 @@ def main(argv):
if __name__ == "__main__":
tf.disable_v2_behavior()
absl_app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment