Commit 5f05ce2d authored by Luke Wood's avatar Luke Wood Committed by TF Object Detection Team
Browse files

Explicitly import estimator from tensorflow as a separate import instead of...

Explicitly import estimator from tensorflow as a separate import instead of accessing it via tf.estimator and depend on the tensorflow estimator target.

PiperOrigin-RevId: 437338390
parent 7368d50e
...@@ -21,6 +21,7 @@ from __future__ import print_function ...@@ -21,6 +21,7 @@ from __future__ import print_function
import functools import functools
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from object_detection.builders import dataset_builder from object_detection.builders import dataset_builder
from object_detection.builders import image_resizer_builder from object_detection.builders import image_resizer_builder
from object_detection.builders import model_builder from object_detection.builders import model_builder
...@@ -1114,7 +1115,7 @@ def create_predict_input_fn(model_config, predict_input_config): ...@@ -1114,7 +1115,7 @@ def create_predict_input_fn(model_config, predict_input_config):
true_image_shape = tf.expand_dims( true_image_shape = tf.expand_dims(
input_dict[fields.InputDataFields.true_image_shape], axis=0) input_dict[fields.InputDataFields.true_image_shape], axis=0)
return tf.estimator.export.ServingInputReceiver( return tf_estimator.export.ServingInputReceiver(
features={ features={
fields.InputDataFields.image: images, fields.InputDataFields.image: images,
fields.InputDataFields.true_image_shape: true_image_shape}, fields.InputDataFields.true_image_shape: true_image_shape},
......
...@@ -23,6 +23,7 @@ import functools ...@@ -23,6 +23,7 @@ import functools
import os import os
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import tensorflow.compat.v2 as tf2 import tensorflow.compat.v2 as tf2
import tf_slim as slim import tf_slim as slim
...@@ -465,7 +466,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, ...@@ -465,7 +466,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
""" """
params = params or {} params = params or {}
total_loss, train_op, detections, export_outputs = None, None, None, None total_loss, train_op, detections, export_outputs = None, None, None, None
is_training = mode == tf.estimator.ModeKeys.TRAIN is_training = mode == tf_estimator.ModeKeys.TRAIN
# Make sure to set the Keras learning phase. True during training, # Make sure to set the Keras learning phase. True during training,
# False for inference. # False for inference.
...@@ -479,11 +480,11 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, ...@@ -479,11 +480,11 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
is_training=is_training, add_summaries=(not use_tpu)) is_training=is_training, add_summaries=(not use_tpu))
scaffold_fn = None scaffold_fn = None
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf_estimator.ModeKeys.TRAIN:
labels = unstack_batch( labels = unstack_batch(
labels, labels,
unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors) unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
elif mode == tf.estimator.ModeKeys.EVAL: elif mode == tf_estimator.ModeKeys.EVAL:
# For evaling on train data, it is necessary to check whether groundtruth # For evaling on train data, it is necessary to check whether groundtruth
# must be unpadded. # must be unpadded.
boxes_shape = ( boxes_shape = (
...@@ -493,7 +494,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, ...@@ -493,7 +494,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
labels = unstack_batch( labels = unstack_batch(
labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors) labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL): if mode in (tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL):
provide_groundtruth(detection_model, labels) provide_groundtruth(detection_model, labels)
preprocessed_images = features[fields.InputDataFields.image] preprocessed_images = features[fields.InputDataFields.image]
...@@ -514,7 +515,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, ...@@ -514,7 +515,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
def postprocess_wrapper(args): def postprocess_wrapper(args):
return detection_model.postprocess(args[0], args[1]) return detection_model.postprocess(args[0], args[1])
if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT): if mode in (tf_estimator.ModeKeys.EVAL, tf_estimator.ModeKeys.PREDICT):
if use_tpu and postprocess_on_cpu: if use_tpu and postprocess_on_cpu:
detections = tf.tpu.outside_compilation( detections = tf.tpu.outside_compilation(
postprocess_wrapper, postprocess_wrapper,
...@@ -525,7 +526,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, ...@@ -525,7 +526,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
prediction_dict, prediction_dict,
features[fields.InputDataFields.true_image_shape])) features[fields.InputDataFields.true_image_shape]))
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf_estimator.ModeKeys.TRAIN:
load_pretrained = hparams.load_pretrained if hparams else False load_pretrained = hparams.load_pretrained if hparams else False
if train_config.fine_tune_checkpoint and load_pretrained: if train_config.fine_tune_checkpoint and load_pretrained:
if not train_config.fine_tune_checkpoint_type: if not train_config.fine_tune_checkpoint_type:
...@@ -557,8 +558,8 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, ...@@ -557,8 +558,8 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint, tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
available_var_map) available_var_map)
if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL): if mode in (tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL):
if (mode == tf.estimator.ModeKeys.EVAL and if (mode == tf_estimator.ModeKeys.EVAL and
eval_config.use_dummy_loss_in_eval): eval_config.use_dummy_loss_in_eval):
total_loss = tf.constant(1.0) total_loss = tf.constant(1.0)
losses_dict = {'Loss/total_loss': total_loss} losses_dict = {'Loss/total_loss': total_loss}
...@@ -590,7 +591,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, ...@@ -590,7 +591,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
training_optimizer, optimizer_summary_vars = optimizer_builder.build( training_optimizer, optimizer_summary_vars = optimizer_builder.build(
train_config.optimizer) train_config.optimizer)
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf_estimator.ModeKeys.TRAIN:
if use_tpu: if use_tpu:
training_optimizer = tf.tpu.CrossShardOptimizer(training_optimizer) training_optimizer = tf.tpu.CrossShardOptimizer(training_optimizer)
...@@ -628,16 +629,16 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, ...@@ -628,16 +629,16 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
summaries=summaries, summaries=summaries,
name='') # Preventing scope prefix on all variables. name='') # Preventing scope prefix on all variables.
if mode == tf.estimator.ModeKeys.PREDICT: if mode == tf_estimator.ModeKeys.PREDICT:
exported_output = exporter_lib.add_output_tensor_nodes(detections) exported_output = exporter_lib.add_output_tensor_nodes(detections)
export_outputs = { export_outputs = {
tf.saved_model.signature_constants.PREDICT_METHOD_NAME: tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
tf.estimator.export.PredictOutput(exported_output) tf_estimator.export.PredictOutput(exported_output)
} }
eval_metric_ops = None eval_metric_ops = None
scaffold = None scaffold = None
if mode == tf.estimator.ModeKeys.EVAL: if mode == tf_estimator.ModeKeys.EVAL:
class_agnostic = ( class_agnostic = (
fields.DetectionResultFields.detection_classes not in detections) fields.DetectionResultFields.detection_classes not in detections)
groundtruth = _prepare_groundtruth_for_eval( groundtruth = _prepare_groundtruth_for_eval(
...@@ -711,8 +712,8 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, ...@@ -711,8 +712,8 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
scaffold = tf.train.Scaffold(saver=saver) scaffold = tf.train.Scaffold(saver=saver)
# EVAL executes on CPU, so use regular non-TPU EstimatorSpec. # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
if use_tpu and mode != tf.estimator.ModeKeys.EVAL: if use_tpu and mode != tf_estimator.ModeKeys.EVAL:
return tf.estimator.tpu.TPUEstimatorSpec( return tf_estimator.tpu.TPUEstimatorSpec(
mode=mode, mode=mode,
scaffold_fn=scaffold_fn, scaffold_fn=scaffold_fn,
predictions=detections, predictions=detections,
...@@ -730,7 +731,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False, ...@@ -730,7 +731,7 @@ def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
save_relative_paths=True) save_relative_paths=True)
tf.add_to_collection(tf.GraphKeys.SAVERS, saver) tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
scaffold = tf.train.Scaffold(saver=saver) scaffold = tf.train.Scaffold(saver=saver)
return tf.estimator.EstimatorSpec( return tf_estimator.EstimatorSpec(
mode=mode, mode=mode,
predictions=detections, predictions=detections,
loss=total_loss, loss=total_loss,
...@@ -895,7 +896,7 @@ def create_estimator_and_inputs(run_config, ...@@ -895,7 +896,7 @@ def create_estimator_and_inputs(run_config,
model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu, model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu,
postprocess_on_cpu) postprocess_on_cpu)
if use_tpu_estimator: if use_tpu_estimator:
estimator = tf.estimator.tpu.TPUEstimator( estimator = tf_estimator.tpu.TPUEstimator(
model_fn=model_fn, model_fn=model_fn,
train_batch_size=train_config.batch_size, train_batch_size=train_config.batch_size,
# For each core, only batch size 1 is supported for eval. # For each core, only batch size 1 is supported for eval.
...@@ -906,7 +907,7 @@ def create_estimator_and_inputs(run_config, ...@@ -906,7 +907,7 @@ def create_estimator_and_inputs(run_config,
eval_on_tpu=False, # Eval runs on CPU, so disable eval on TPU eval_on_tpu=False, # Eval runs on CPU, so disable eval on TPU
params=params if params else {}) params=params if params else {})
else: else:
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) estimator = tf_estimator.Estimator(model_fn=model_fn, config=run_config)
# Write the as-run pipeline config to disk. # Write the as-run pipeline config to disk.
if run_config.is_chief and save_final_config: if run_config.is_chief and save_final_config:
...@@ -951,7 +952,7 @@ def create_train_and_eval_specs(train_input_fn, ...@@ -951,7 +952,7 @@ def create_train_and_eval_specs(train_input_fn,
True, the last `EvalSpec` in the list will correspond to training data. The True, the last `EvalSpec` in the list will correspond to training data. The
rest EvalSpecs in the list are evaluation datas. rest EvalSpecs in the list are evaluation datas.
""" """
train_spec = tf.estimator.TrainSpec( train_spec = tf_estimator.TrainSpec(
input_fn=train_input_fn, max_steps=train_steps) input_fn=train_input_fn, max_steps=train_steps)
if eval_spec_names is None: if eval_spec_names is None:
...@@ -966,10 +967,10 @@ def create_train_and_eval_specs(train_input_fn, ...@@ -966,10 +967,10 @@ def create_train_and_eval_specs(train_input_fn,
exporter_name = final_exporter_name exporter_name = final_exporter_name
else: else:
exporter_name = '{}_{}'.format(final_exporter_name, eval_spec_name) exporter_name = '{}_{}'.format(final_exporter_name, eval_spec_name)
exporter = tf.estimator.FinalExporter( exporter = tf_estimator.FinalExporter(
name=exporter_name, serving_input_receiver_fn=predict_input_fn) name=exporter_name, serving_input_receiver_fn=predict_input_fn)
eval_specs.append( eval_specs.append(
tf.estimator.EvalSpec( tf_estimator.EvalSpec(
name=eval_spec_name, name=eval_spec_name,
input_fn=eval_input_fn, input_fn=eval_input_fn,
steps=None, steps=None,
...@@ -977,7 +978,7 @@ def create_train_and_eval_specs(train_input_fn, ...@@ -977,7 +978,7 @@ def create_train_and_eval_specs(train_input_fn,
if eval_on_train_data: if eval_on_train_data:
eval_specs.append( eval_specs.append(
tf.estimator.EvalSpec( tf_estimator.EvalSpec(
name='eval_on_train', input_fn=eval_on_train_input_fn, steps=None)) name='eval_on_train', input_fn=eval_on_train_input_fn, steps=None))
return train_spec, eval_specs return train_spec, eval_specs
......
...@@ -23,6 +23,7 @@ import os ...@@ -23,6 +23,7 @@ import os
import unittest import unittest
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from object_detection import inputs from object_detection import inputs
from object_detection import model_hparams from object_detection import model_hparams
...@@ -137,21 +138,21 @@ class ModelLibTest(tf.test.TestCase): ...@@ -137,21 +138,21 @@ class ModelLibTest(tf.test.TestCase):
inputs.create_train_input_fn(configs['train_config'], inputs.create_train_input_fn(configs['train_config'],
configs['train_input_config'], configs['train_input_config'],
configs['model'])()).get_next() configs['model'])()).get_next()
model_mode = tf.estimator.ModeKeys.TRAIN model_mode = tf_estimator.ModeKeys.TRAIN
batch_size = train_config.batch_size batch_size = train_config.batch_size
elif mode == 'eval': elif mode == 'eval':
features, labels = _make_initializable_iterator( features, labels = _make_initializable_iterator(
inputs.create_eval_input_fn(configs['eval_config'], inputs.create_eval_input_fn(configs['eval_config'],
configs['eval_input_config'], configs['eval_input_config'],
configs['model'])()).get_next() configs['model'])()).get_next()
model_mode = tf.estimator.ModeKeys.EVAL model_mode = tf_estimator.ModeKeys.EVAL
batch_size = 1 batch_size = 1
elif mode == 'eval_on_train': elif mode == 'eval_on_train':
features, labels = _make_initializable_iterator( features, labels = _make_initializable_iterator(
inputs.create_eval_input_fn(configs['eval_config'], inputs.create_eval_input_fn(configs['eval_config'],
configs['train_input_config'], configs['train_input_config'],
configs['model'])()).get_next() configs['model'])()).get_next()
model_mode = tf.estimator.ModeKeys.EVAL model_mode = tf_estimator.ModeKeys.EVAL
batch_size = 1 batch_size = 1
detection_model_fn = functools.partial( detection_model_fn = functools.partial(
...@@ -183,7 +184,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -183,7 +184,7 @@ class ModelLibTest(tf.test.TestCase):
if mode == 'eval': if mode == 'eval':
self.assertIn('Detections_Left_Groundtruth_Right/0', self.assertIn('Detections_Left_Groundtruth_Right/0',
estimator_spec.eval_metric_ops) estimator_spec.eval_metric_ops)
if model_mode == tf.estimator.ModeKeys.TRAIN: if model_mode == tf_estimator.ModeKeys.TRAIN:
self.assertIsNotNone(estimator_spec.train_op) self.assertIsNotNone(estimator_spec.train_op)
return estimator_spec return estimator_spec
...@@ -202,7 +203,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -202,7 +203,7 @@ class ModelLibTest(tf.test.TestCase):
hparams_overrides='load_pretrained=false') hparams_overrides='load_pretrained=false')
model_fn = model_lib.create_model_fn(detection_model_fn, configs, hparams) model_fn = model_lib.create_model_fn(detection_model_fn, configs, hparams)
estimator_spec = model_fn(features, None, tf.estimator.ModeKeys.PREDICT) estimator_spec = model_fn(features, None, tf_estimator.ModeKeys.PREDICT)
self.assertIsNone(estimator_spec.loss) self.assertIsNone(estimator_spec.loss)
self.assertIsNone(estimator_spec.train_op) self.assertIsNone(estimator_spec.train_op)
...@@ -279,7 +280,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -279,7 +280,7 @@ class ModelLibTest(tf.test.TestCase):
def test_create_estimator_and_inputs(self): def test_create_estimator_and_inputs(self):
"""Tests that Estimator and input function are constructed correctly.""" """Tests that Estimator and input function are constructed correctly."""
run_config = tf.estimator.RunConfig() run_config = tf_estimator.RunConfig()
hparams = model_hparams.create_hparams( hparams = model_hparams.create_hparams(
hparams_overrides='load_pretrained=false') hparams_overrides='load_pretrained=false')
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST) pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
...@@ -291,7 +292,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -291,7 +292,7 @@ class ModelLibTest(tf.test.TestCase):
train_steps=train_steps) train_steps=train_steps)
estimator = train_and_eval_dict['estimator'] estimator = train_and_eval_dict['estimator']
train_steps = train_and_eval_dict['train_steps'] train_steps = train_and_eval_dict['train_steps']
self.assertIsInstance(estimator, tf.estimator.Estimator) self.assertIsInstance(estimator, tf_estimator.Estimator)
self.assertEqual(20, train_steps) self.assertEqual(20, train_steps)
self.assertIn('train_input_fn', train_and_eval_dict) self.assertIn('train_input_fn', train_and_eval_dict)
self.assertIn('eval_input_fns', train_and_eval_dict) self.assertIn('eval_input_fns', train_and_eval_dict)
...@@ -299,7 +300,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -299,7 +300,7 @@ class ModelLibTest(tf.test.TestCase):
def test_create_estimator_and_inputs_sequence_example(self): def test_create_estimator_and_inputs_sequence_example(self):
"""Tests that Estimator and input function are constructed correctly.""" """Tests that Estimator and input function are constructed correctly."""
run_config = tf.estimator.RunConfig() run_config = tf_estimator.RunConfig()
hparams = model_hparams.create_hparams( hparams = model_hparams.create_hparams(
hparams_overrides='load_pretrained=false') hparams_overrides='load_pretrained=false')
pipeline_config_path = get_pipeline_config_path( pipeline_config_path = get_pipeline_config_path(
...@@ -312,7 +313,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -312,7 +313,7 @@ class ModelLibTest(tf.test.TestCase):
train_steps=train_steps) train_steps=train_steps)
estimator = train_and_eval_dict['estimator'] estimator = train_and_eval_dict['estimator']
train_steps = train_and_eval_dict['train_steps'] train_steps = train_and_eval_dict['train_steps']
self.assertIsInstance(estimator, tf.estimator.Estimator) self.assertIsInstance(estimator, tf_estimator.Estimator)
self.assertEqual(20, train_steps) self.assertEqual(20, train_steps)
self.assertIn('train_input_fn', train_and_eval_dict) self.assertIn('train_input_fn', train_and_eval_dict)
self.assertIn('eval_input_fns', train_and_eval_dict) self.assertIn('eval_input_fns', train_and_eval_dict)
...@@ -320,7 +321,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -320,7 +321,7 @@ class ModelLibTest(tf.test.TestCase):
def test_create_estimator_with_default_train_eval_steps(self): def test_create_estimator_with_default_train_eval_steps(self):
"""Tests that number of train/eval defaults to config values.""" """Tests that number of train/eval defaults to config values."""
run_config = tf.estimator.RunConfig() run_config = tf_estimator.RunConfig()
hparams = model_hparams.create_hparams( hparams = model_hparams.create_hparams(
hparams_overrides='load_pretrained=false') hparams_overrides='load_pretrained=false')
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST) pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
...@@ -331,12 +332,12 @@ class ModelLibTest(tf.test.TestCase): ...@@ -331,12 +332,12 @@ class ModelLibTest(tf.test.TestCase):
estimator = train_and_eval_dict['estimator'] estimator = train_and_eval_dict['estimator']
train_steps = train_and_eval_dict['train_steps'] train_steps = train_and_eval_dict['train_steps']
self.assertIsInstance(estimator, tf.estimator.Estimator) self.assertIsInstance(estimator, tf_estimator.Estimator)
self.assertEqual(config_train_steps, train_steps) self.assertEqual(config_train_steps, train_steps)
def test_create_tpu_estimator_and_inputs(self): def test_create_tpu_estimator_and_inputs(self):
"""Tests that number of train/eval defaults to config values.""" """Tests that number of train/eval defaults to config values."""
run_config = tf.estimator.tpu.RunConfig() run_config = tf_estimator.tpu.RunConfig()
hparams = model_hparams.create_hparams( hparams = model_hparams.create_hparams(
hparams_overrides='load_pretrained=false') hparams_overrides='load_pretrained=false')
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST) pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
...@@ -350,12 +351,12 @@ class ModelLibTest(tf.test.TestCase): ...@@ -350,12 +351,12 @@ class ModelLibTest(tf.test.TestCase):
estimator = train_and_eval_dict['estimator'] estimator = train_and_eval_dict['estimator']
train_steps = train_and_eval_dict['train_steps'] train_steps = train_and_eval_dict['train_steps']
self.assertIsInstance(estimator, tf.estimator.tpu.TPUEstimator) self.assertIsInstance(estimator, tf_estimator.tpu.TPUEstimator)
self.assertEqual(20, train_steps) self.assertEqual(20, train_steps)
def test_create_train_and_eval_specs(self): def test_create_train_and_eval_specs(self):
"""Tests that `TrainSpec` and `EvalSpec` is created correctly.""" """Tests that `TrainSpec` and `EvalSpec` is created correctly."""
run_config = tf.estimator.RunConfig() run_config = tf_estimator.RunConfig()
hparams = model_hparams.create_hparams( hparams = model_hparams.create_hparams(
hparams_overrides='load_pretrained=false') hparams_overrides='load_pretrained=false')
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST) pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
...@@ -390,7 +391,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -390,7 +391,7 @@ class ModelLibTest(tf.test.TestCase):
def test_experiment(self): def test_experiment(self):
"""Tests that the `Experiment` object is constructed correctly.""" """Tests that the `Experiment` object is constructed correctly."""
run_config = tf.estimator.RunConfig() run_config = tf_estimator.RunConfig()
hparams = model_hparams.create_hparams( hparams = model_hparams.create_hparams(
hparams_overrides='load_pretrained=false') hparams_overrides='load_pretrained=false')
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST) pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
......
...@@ -21,6 +21,7 @@ from __future__ import print_function ...@@ -21,6 +21,7 @@ from __future__ import print_function
from absl import flags from absl import flags
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from object_detection import model_lib from object_detection import model_lib
...@@ -59,7 +60,7 @@ FLAGS = flags.FLAGS ...@@ -59,7 +60,7 @@ FLAGS = flags.FLAGS
def main(unused_argv): def main(unused_argv):
flags.mark_flag_as_required('model_dir') flags.mark_flag_as_required('model_dir')
flags.mark_flag_as_required('pipeline_config_path') flags.mark_flag_as_required('pipeline_config_path')
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir) config = tf_estimator.RunConfig(model_dir=FLAGS.model_dir)
train_and_eval_dict = model_lib.create_estimator_and_inputs( train_and_eval_dict = model_lib.create_estimator_and_inputs(
run_config=config, run_config=config,
...@@ -101,7 +102,7 @@ def main(unused_argv): ...@@ -101,7 +102,7 @@ def main(unused_argv):
eval_on_train_data=False) eval_on_train_data=False)
# Currently only a single Eval Spec is allowed. # Currently only a single Eval Spec is allowed.
tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0]) tf_estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -24,6 +24,7 @@ from __future__ import print_function ...@@ -24,6 +24,7 @@ from __future__ import print_function
from absl import flags from absl import flags
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from object_detection import model_lib from object_detection import model_lib
...@@ -89,11 +90,11 @@ def main(unused_argv): ...@@ -89,11 +90,11 @@ def main(unused_argv):
tpu=[FLAGS.tpu_name], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)) tpu=[FLAGS.tpu_name], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
tpu_grpc_url = tpu_cluster_resolver.get_master() tpu_grpc_url = tpu_cluster_resolver.get_master()
config = tf.estimator.tpu.RunConfig( config = tf_estimator.tpu.RunConfig(
master=tpu_grpc_url, master=tpu_grpc_url,
evaluation_master=tpu_grpc_url, evaluation_master=tpu_grpc_url,
model_dir=FLAGS.model_dir, model_dir=FLAGS.model_dir,
tpu_config=tf.estimator.tpu.TPUConfig( tpu_config=tf_estimator.tpu.TPUConfig(
iterations_per_loop=FLAGS.iterations_per_loop, iterations_per_loop=FLAGS.iterations_per_loop,
num_shards=FLAGS.num_shards)) num_shards=FLAGS.num_shards))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment