Commit 7e810001 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Access TPUEstimator and CrossShardOptimizer from tf namesspace.

PiperOrigin-RevId: 192226678
parent b0c5c3b5
...@@ -22,8 +22,6 @@ import functools ...@@ -22,8 +22,6 @@ import functools
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from object_detection import eval_util from object_detection import eval_util
from object_detection import inputs from object_detection import inputs
from object_detection.builders import model_builder from object_detection.builders import model_builder
...@@ -291,7 +289,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False): ...@@ -291,7 +289,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
if use_tpu: if use_tpu:
training_optimizer = tpu_optimizer.CrossShardOptimizer( training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
training_optimizer) training_optimizer)
# Optionally freeze some layers by setting their gradients to be zero. # Optionally freeze some layers by setting their gradients to be zero.
...@@ -490,7 +488,7 @@ def create_estimator_and_inputs(run_config, ...@@ -490,7 +488,7 @@ def create_estimator_and_inputs(run_config,
model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu) model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu)
if use_tpu_estimator: if use_tpu_estimator:
estimator = tpu_estimator.TPUEstimator( estimator = tf.contrib.tpu.TPUEstimator(
model_fn=model_fn, model_fn=model_fn,
train_batch_size=train_config.batch_size, train_batch_size=train_config.batch_size,
# For each core, only batch size 1 is supported for eval. # For each core, only batch size 1 is supported for eval.
......
...@@ -7,7 +7,9 @@ import "object_detection/protos/preprocessor.proto"; ...@@ -7,7 +7,9 @@ import "object_detection/protos/preprocessor.proto";
// Message for configuring DetectionModel training jobs (train.py). // Message for configuring DetectionModel training jobs (train.py).
message TrainConfig { message TrainConfig {
// Input queue batch size. // Effective batch size to use for training.
// For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
// `batch_size` / number of cores (or `batch_size` / number of GPUs).
optional uint32 batch_size = 1 [default=32]; optional uint32 batch_size = 1 [default=32];
// Data augmentation options. // Data augmentation options.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment