Commit 227f41e9 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Returning eval_on_train_input_fn from create_estimator_and_inputs(), rather...

Returning eval_on_train_input_fn from create_estimator_and_inputs(), rather than using train_input_fn in EVAL mode (which will still have data augmentation).

PiperOrigin-RevId: 192320460
parent 7e810001
......@@ -437,6 +437,7 @@ def create_estimator_and_inputs(run_config,
'estimator': An `Estimator` or `TPUEstimator`.
'train_input_fn': A training input function.
'eval_input_fn': An evaluation input function.
'eval_on_train_input_fn': An evaluation-on-train input function.
'predict_input_fn': A prediction input function.
'train_steps': Number of training steps. Either directly from input or from
configuration.
......@@ -484,6 +485,10 @@ def create_estimator_and_inputs(run_config,
eval_config=eval_config,
eval_input_config=eval_input_config,
model_config=model_config)
eval_on_train_input_fn = create_eval_input_fn(
eval_config=eval_config,
eval_input_config=train_input_config,
model_config=model_config)
predict_input_fn = create_predict_input_fn(model_config=model_config)
model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu)
......@@ -509,6 +514,7 @@ def create_estimator_and_inputs(run_config,
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
eval_on_train_input_fn=eval_on_train_input_fn,
predict_input_fn=predict_input_fn,
train_steps=train_steps,
eval_steps=eval_steps)
......@@ -516,6 +522,7 @@ def create_estimator_and_inputs(run_config,
def create_train_and_eval_specs(train_input_fn,
eval_input_fn,
eval_on_train_input_fn,
predict_input_fn,
train_steps,
eval_steps,
......@@ -527,6 +534,8 @@ def create_train_and_eval_specs(train_input_fn,
Args:
train_input_fn: Function that produces features and labels on train data.
eval_input_fn: Function that produces features and labels on eval data.
eval_on_train_input_fn: Function that produces features and labels for
evaluation on train data.
predict_input_fn: Function that produces features for inference.
train_steps: Number of training steps.
eval_steps: Number of eval steps.
......@@ -558,7 +567,8 @@ def create_train_and_eval_specs(train_input_fn,
if eval_on_train_data:
eval_specs.append(
tf.estimator.EvalSpec(
name='eval_on_train', input_fn=train_input_fn, steps=eval_steps))
name='eval_on_train', input_fn=eval_on_train_input_fn,
steps=eval_steps))
return train_spec, eval_specs
......
......@@ -83,17 +83,26 @@ class ModelLibTest(tf.test.TestCase):
model_config = configs['model']
train_config = configs['train_config']
with tf.Graph().as_default():
if mode == tf.estimator.ModeKeys.TRAIN:
if mode == 'train':
features, labels = inputs.create_train_input_fn(
configs['train_config'],
configs['train_input_config'],
configs['model'])()
model_mode = tf.estimator.ModeKeys.TRAIN
batch_size = train_config.batch_size
else:
elif mode == 'eval':
features, labels = inputs.create_eval_input_fn(
configs['eval_config'],
configs['eval_input_config'],
configs['model'])()
model_mode = tf.estimator.ModeKeys.EVAL
batch_size = 1
elif mode == 'eval_on_train':
features, labels = inputs.create_eval_input_fn(
configs['eval_config'],
configs['train_input_config'],
configs['model'])()
model_mode = tf.estimator.ModeKeys.EVAL
batch_size = 1
detection_model_fn = functools.partial(
......@@ -103,7 +112,7 @@ class ModelLibTest(tf.test.TestCase):
hparams_overrides='load_pretrained=false')
model_fn = model_lib.create_model_fn(detection_model_fn, configs, hparams)
estimator_spec = model_fn(features, labels, mode)
estimator_spec = model_fn(features, labels, model_mode)
self.assertIsNotNone(estimator_spec.loss)
self.assertIsNotNone(estimator_spec.predictions)
......@@ -121,7 +130,7 @@ class ModelLibTest(tf.test.TestCase):
self.assertEqual(batch_size, detection_scores.shape.as_list()[0])
self.assertEqual(tf.float32, detection_scores.dtype)
self.assertEqual(tf.float32, num_detections.dtype)
if mode == tf.estimator.ModeKeys.TRAIN:
if model_mode == tf.estimator.ModeKeys.TRAIN:
self.assertIsNotNone(estimator_spec.train_op)
return estimator_spec
......@@ -152,12 +161,17 @@ class ModelLibTest(tf.test.TestCase):
def test_model_fn_in_train_mode(self):
"""Tests the model function in TRAIN mode."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, tf.estimator.ModeKeys.TRAIN)
self._assert_model_fn_for_train_eval(configs, 'train')
def test_model_fn_in_eval_mode(self):
"""Tests the model function in EVAL mode."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, tf.estimator.ModeKeys.EVAL)
self._assert_model_fn_for_train_eval(configs, 'eval')
def test_model_fn_in_eval_on_train_mode(self):
"""Tests the model function in EVAL mode with train data."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, 'eval_on_train')
def test_model_fn_in_predict_mode(self):
"""Tests the model function in PREDICT mode."""
......@@ -181,10 +195,12 @@ class ModelLibTest(tf.test.TestCase):
estimator = train_and_eval_dict['estimator']
train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps']
self.assertIsInstance(estimator, tf.estimator.Estimator)
self.assertEqual(20, train_steps)
self.assertEqual(10, eval_steps)
self.assertIn('train_input_fn', train_and_eval_dict)
self.assertIn('eval_input_fn', train_and_eval_dict)
self.assertIn('eval_on_train_input_fn', train_and_eval_dict)
def test_create_estimator_with_default_train_eval_steps(self):
"""Tests that number of train/eval defaults to config values."""
......@@ -245,6 +261,7 @@ class ModelLibTest(tf.test.TestCase):
eval_steps=eval_steps)
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fn = train_and_eval_dict['eval_input_fn']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps']
......@@ -252,6 +269,7 @@ class ModelLibTest(tf.test.TestCase):
train_spec, eval_specs = model_lib.create_train_and_eval_specs(
train_input_fn,
eval_input_fn,
eval_on_train_input_fn,
predict_input_fn,
train_steps,
eval_steps,
......
......@@ -54,6 +54,7 @@ def main(unused_argv):
estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fn = train_and_eval_dict['eval_input_fn']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps']
......@@ -61,6 +62,7 @@ def main(unused_argv):
train_spec, eval_specs = model_lib.create_train_and_eval_specs(
train_input_fn,
eval_input_fn,
eval_on_train_input_fn,
predict_input_fn,
train_steps,
eval_steps,
......
......@@ -130,6 +130,7 @@ def main(unused_argv):
estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fn = train_and_eval_dict['eval_input_fn']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps']
......@@ -158,7 +159,7 @@ def main(unused_argv):
tf.logging.info('Starting to evaluate.')
if FLAGS.eval_training_data:
name = 'training_data'
input_fn = train_input_fn
input_fn = eval_on_train_input_fn
else:
name = 'validation_data'
input_fn = eval_input_fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment