"src/graph/vscode:/vscode.git/clone" did not exist on "c51cc82e4b4ae68fb2444b6d9ad57ce574e79bb0"
Commit 227f41e9 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Returning eval_on_train_input_fn from create_estimator_and_inputs(), rather than using train_input_fn in EVAL mode (which will still have data augmentation).

Returning eval_on_train_input_fn from create_estimator_and_inputs(), rather than using train_input_fn in EVAL mode (which will still have data augmentation).

PiperOrigin-RevId: 192320460
parent 7e810001
...@@ -437,6 +437,7 @@ def create_estimator_and_inputs(run_config, ...@@ -437,6 +437,7 @@ def create_estimator_and_inputs(run_config,
'estimator': An `Estimator` or `TPUEstimator`. 'estimator': An `Estimator` or `TPUEstimator`.
'train_input_fn': A training input function. 'train_input_fn': A training input function.
'eval_input_fn': An evaluation input function. 'eval_input_fn': An evaluation input function.
'eval_on_train_input_fn': An evaluation-on-train input function.
'predict_input_fn': A prediction input function. 'predict_input_fn': A prediction input function.
'train_steps': Number of training steps. Either directly from input or from 'train_steps': Number of training steps. Either directly from input or from
configuration. configuration.
...@@ -484,6 +485,10 @@ def create_estimator_and_inputs(run_config, ...@@ -484,6 +485,10 @@ def create_estimator_and_inputs(run_config,
eval_config=eval_config, eval_config=eval_config,
eval_input_config=eval_input_config, eval_input_config=eval_input_config,
model_config=model_config) model_config=model_config)
eval_on_train_input_fn = create_eval_input_fn(
eval_config=eval_config,
eval_input_config=train_input_config,
model_config=model_config)
predict_input_fn = create_predict_input_fn(model_config=model_config) predict_input_fn = create_predict_input_fn(model_config=model_config)
model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu) model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu)
...@@ -509,6 +514,7 @@ def create_estimator_and_inputs(run_config, ...@@ -509,6 +514,7 @@ def create_estimator_and_inputs(run_config,
estimator=estimator, estimator=estimator,
train_input_fn=train_input_fn, train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn, eval_input_fn=eval_input_fn,
eval_on_train_input_fn=eval_on_train_input_fn,
predict_input_fn=predict_input_fn, predict_input_fn=predict_input_fn,
train_steps=train_steps, train_steps=train_steps,
eval_steps=eval_steps) eval_steps=eval_steps)
...@@ -516,6 +522,7 @@ def create_estimator_and_inputs(run_config, ...@@ -516,6 +522,7 @@ def create_estimator_and_inputs(run_config,
def create_train_and_eval_specs(train_input_fn, def create_train_and_eval_specs(train_input_fn,
eval_input_fn, eval_input_fn,
eval_on_train_input_fn,
predict_input_fn, predict_input_fn,
train_steps, train_steps,
eval_steps, eval_steps,
...@@ -527,6 +534,8 @@ def create_train_and_eval_specs(train_input_fn, ...@@ -527,6 +534,8 @@ def create_train_and_eval_specs(train_input_fn,
Args: Args:
train_input_fn: Function that produces features and labels on train data. train_input_fn: Function that produces features and labels on train data.
eval_input_fn: Function that produces features and labels on eval data. eval_input_fn: Function that produces features and labels on eval data.
eval_on_train_input_fn: Function that produces features and labels for
evaluation on train data.
predict_input_fn: Function that produces features for inference. predict_input_fn: Function that produces features for inference.
train_steps: Number of training steps. train_steps: Number of training steps.
eval_steps: Number of eval steps. eval_steps: Number of eval steps.
...@@ -558,7 +567,8 @@ def create_train_and_eval_specs(train_input_fn, ...@@ -558,7 +567,8 @@ def create_train_and_eval_specs(train_input_fn,
if eval_on_train_data: if eval_on_train_data:
eval_specs.append( eval_specs.append(
tf.estimator.EvalSpec( tf.estimator.EvalSpec(
name='eval_on_train', input_fn=train_input_fn, steps=eval_steps)) name='eval_on_train', input_fn=eval_on_train_input_fn,
steps=eval_steps))
return train_spec, eval_specs return train_spec, eval_specs
......
...@@ -83,17 +83,26 @@ class ModelLibTest(tf.test.TestCase): ...@@ -83,17 +83,26 @@ class ModelLibTest(tf.test.TestCase):
model_config = configs['model'] model_config = configs['model']
train_config = configs['train_config'] train_config = configs['train_config']
with tf.Graph().as_default(): with tf.Graph().as_default():
if mode == tf.estimator.ModeKeys.TRAIN: if mode == 'train':
features, labels = inputs.create_train_input_fn( features, labels = inputs.create_train_input_fn(
configs['train_config'], configs['train_config'],
configs['train_input_config'], configs['train_input_config'],
configs['model'])() configs['model'])()
model_mode = tf.estimator.ModeKeys.TRAIN
batch_size = train_config.batch_size batch_size = train_config.batch_size
else: elif mode == 'eval':
features, labels = inputs.create_eval_input_fn( features, labels = inputs.create_eval_input_fn(
configs['eval_config'], configs['eval_config'],
configs['eval_input_config'], configs['eval_input_config'],
configs['model'])() configs['model'])()
model_mode = tf.estimator.ModeKeys.EVAL
batch_size = 1
elif mode == 'eval_on_train':
features, labels = inputs.create_eval_input_fn(
configs['eval_config'],
configs['train_input_config'],
configs['model'])()
model_mode = tf.estimator.ModeKeys.EVAL
batch_size = 1 batch_size = 1
detection_model_fn = functools.partial( detection_model_fn = functools.partial(
...@@ -103,7 +112,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -103,7 +112,7 @@ class ModelLibTest(tf.test.TestCase):
hparams_overrides='load_pretrained=false') hparams_overrides='load_pretrained=false')
model_fn = model_lib.create_model_fn(detection_model_fn, configs, hparams) model_fn = model_lib.create_model_fn(detection_model_fn, configs, hparams)
estimator_spec = model_fn(features, labels, mode) estimator_spec = model_fn(features, labels, model_mode)
self.assertIsNotNone(estimator_spec.loss) self.assertIsNotNone(estimator_spec.loss)
self.assertIsNotNone(estimator_spec.predictions) self.assertIsNotNone(estimator_spec.predictions)
...@@ -121,7 +130,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -121,7 +130,7 @@ class ModelLibTest(tf.test.TestCase):
self.assertEqual(batch_size, detection_scores.shape.as_list()[0]) self.assertEqual(batch_size, detection_scores.shape.as_list()[0])
self.assertEqual(tf.float32, detection_scores.dtype) self.assertEqual(tf.float32, detection_scores.dtype)
self.assertEqual(tf.float32, num_detections.dtype) self.assertEqual(tf.float32, num_detections.dtype)
if mode == tf.estimator.ModeKeys.TRAIN: if model_mode == tf.estimator.ModeKeys.TRAIN:
self.assertIsNotNone(estimator_spec.train_op) self.assertIsNotNone(estimator_spec.train_op)
return estimator_spec return estimator_spec
...@@ -152,12 +161,17 @@ class ModelLibTest(tf.test.TestCase): ...@@ -152,12 +161,17 @@ class ModelLibTest(tf.test.TestCase):
def test_model_fn_in_train_mode(self): def test_model_fn_in_train_mode(self):
"""Tests the model function in TRAIN mode.""" """Tests the model function in TRAIN mode."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST) configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, tf.estimator.ModeKeys.TRAIN) self._assert_model_fn_for_train_eval(configs, 'train')
def test_model_fn_in_eval_mode(self): def test_model_fn_in_eval_mode(self):
"""Tests the model function in EVAL mode.""" """Tests the model function in EVAL mode."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST) configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, tf.estimator.ModeKeys.EVAL) self._assert_model_fn_for_train_eval(configs, 'eval')
def test_model_fn_in_eval_on_train_mode(self):
"""Tests the model function in EVAL mode with train data."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, 'eval_on_train')
def test_model_fn_in_predict_mode(self): def test_model_fn_in_predict_mode(self):
"""Tests the model function in PREDICT mode.""" """Tests the model function in PREDICT mode."""
...@@ -181,10 +195,12 @@ class ModelLibTest(tf.test.TestCase): ...@@ -181,10 +195,12 @@ class ModelLibTest(tf.test.TestCase):
estimator = train_and_eval_dict['estimator'] estimator = train_and_eval_dict['estimator']
train_steps = train_and_eval_dict['train_steps'] train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps'] eval_steps = train_and_eval_dict['eval_steps']
self.assertIsInstance(estimator, tf.estimator.Estimator) self.assertIsInstance(estimator, tf.estimator.Estimator)
self.assertEqual(20, train_steps) self.assertEqual(20, train_steps)
self.assertEqual(10, eval_steps) self.assertEqual(10, eval_steps)
self.assertIn('train_input_fn', train_and_eval_dict)
self.assertIn('eval_input_fn', train_and_eval_dict)
self.assertIn('eval_on_train_input_fn', train_and_eval_dict)
def test_create_estimator_with_default_train_eval_steps(self): def test_create_estimator_with_default_train_eval_steps(self):
"""Tests that number of train/eval defaults to config values.""" """Tests that number of train/eval defaults to config values."""
...@@ -245,6 +261,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -245,6 +261,7 @@ class ModelLibTest(tf.test.TestCase):
eval_steps=eval_steps) eval_steps=eval_steps)
train_input_fn = train_and_eval_dict['train_input_fn'] train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fn = train_and_eval_dict['eval_input_fn'] eval_input_fn = train_and_eval_dict['eval_input_fn']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn'] predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps'] train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps'] eval_steps = train_and_eval_dict['eval_steps']
...@@ -252,6 +269,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -252,6 +269,7 @@ class ModelLibTest(tf.test.TestCase):
train_spec, eval_specs = model_lib.create_train_and_eval_specs( train_spec, eval_specs = model_lib.create_train_and_eval_specs(
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
eval_on_train_input_fn,
predict_input_fn, predict_input_fn,
train_steps, train_steps,
eval_steps, eval_steps,
......
...@@ -54,6 +54,7 @@ def main(unused_argv): ...@@ -54,6 +54,7 @@ def main(unused_argv):
estimator = train_and_eval_dict['estimator'] estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn'] train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fn = train_and_eval_dict['eval_input_fn'] eval_input_fn = train_and_eval_dict['eval_input_fn']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn'] predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps'] train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps'] eval_steps = train_and_eval_dict['eval_steps']
...@@ -61,6 +62,7 @@ def main(unused_argv): ...@@ -61,6 +62,7 @@ def main(unused_argv):
train_spec, eval_specs = model_lib.create_train_and_eval_specs( train_spec, eval_specs = model_lib.create_train_and_eval_specs(
train_input_fn, train_input_fn,
eval_input_fn, eval_input_fn,
eval_on_train_input_fn,
predict_input_fn, predict_input_fn,
train_steps, train_steps,
eval_steps, eval_steps,
......
...@@ -130,6 +130,7 @@ def main(unused_argv): ...@@ -130,6 +130,7 @@ def main(unused_argv):
estimator = train_and_eval_dict['estimator'] estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn'] train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fn = train_and_eval_dict['eval_input_fn'] eval_input_fn = train_and_eval_dict['eval_input_fn']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
train_steps = train_and_eval_dict['train_steps'] train_steps = train_and_eval_dict['train_steps']
eval_steps = train_and_eval_dict['eval_steps'] eval_steps = train_and_eval_dict['eval_steps']
...@@ -158,7 +159,7 @@ def main(unused_argv): ...@@ -158,7 +159,7 @@ def main(unused_argv):
tf.logging.info('Starting to evaluate.') tf.logging.info('Starting to evaluate.')
if FLAGS.eval_training_data: if FLAGS.eval_training_data:
name = 'training_data' name = 'training_data'
input_fn = train_input_fn input_fn = eval_on_train_input_fn
else: else:
name = 'validation_data' name = 'validation_data'
input_fn = eval_input_fn input_fn = eval_input_fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.