ModelZoo / ResNet50_tensorflow · Commits

Commit 8793267f
Authored Jul 18, 2018 by Chris Shallue; committed by Christopher Shallue on Oct 16, 2018
Parent: 75d592e9

Support evaluating over multiple datasets.

PiperOrigin-RevId: 205168785
Showing 3 changed files, with 57 additions and 42 deletions:

research/astronet/README.md (+2, -2)
research/astronet/astronet/train.py (+4, -1)
research/astronet/astronet/util/estimator_util.py (+51, -39)
research/astronet/README.md

@@ -207,7 +207,7 @@ the second deepest transits).
 To train a model to identify exoplanets, you will need to provide TensorFlow
 with training data in
-[TFRecord](https://www.tensorflow.org/guide/datasets) format. The
+[TFRecord](https://www.tensorflow.org/programmers_guide/datasets) format. The
 TFRecord format consists of a set of sharded files containing serialized
 `tf.Example` [protocol buffers](https://developers.google.com/protocol-buffers/).

@@ -343,7 +343,7 @@ bazel-bin/astronet/train \
   --model_dir=${MODEL_DIR}
 ```
-Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard)
+Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard)
 server in a separate process for real-time
 monitoring of training progress and evaluation metrics.
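The README passage above only names the TFRecord input format. As a rough, self-contained illustration (not part of this commit, and with made-up feature names), writing serialized `tf.Example` protos to one shard of a TFRecord dataset with the TF 1.x API looks roughly like this:

```python
import tensorflow as tf

def make_example(flux_values, label):
  # Hypothetical features; the real astronet data pipeline defines its own
  # feature names and preprocessing.
  return tf.train.Example(features=tf.train.Features(feature={
      "light_curve": tf.train.Feature(
          float_list=tf.train.FloatList(value=flux_values)),
      "label": tf.train.Feature(
          int64_list=tf.train.Int64List(value=[label])),
  }))

# One shard of a sharded TFRecord dataset.
with tf.python_io.TFRecordWriter("/tmp/train-00000-of-00001") as writer:
  writer.write(make_example([0.1, 0.2, 0.3], 1).SerializeToString())
```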
research/astronet/astronet/train.py

@@ -112,11 +112,14 @@ def main(_):
       file_pattern=FLAGS.eval_files,
       input_config=config.inputs,
       mode=tf.estimator.ModeKeys.EVAL)
+  eval_args = {
+      "val": (eval_input_fn, None)  # eval_name: (input_fn, eval_steps)
+  }
 
   for _ in estimator_util.continuous_train_and_eval(
       estimator=estimator,
       train_input_fn=train_input_fn,
-      eval_input_fn=eval_input_fn,
+      eval_args=eval_args,
       train_steps=FLAGS.train_steps):
     # continuous_train_and_eval() yields evaluation metrics after each
     # training epoch. We don't do anything here.
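The change above registers a single "val" entry in the new eval_args dictionary. To actually evaluate over multiple datasets (the point of this commit), a caller could register several named sets. The sketch below is hypothetical: it reuses the estimator, config, train_input_fn and eval_input_fn built earlier in main(), and it assumes a FLAGS.train_files pattern and the same input-function helper that builds eval_input_fn; none of this extra wiring is part of the diff.

```python
# Sketch only: add an extra evaluation pass over (a sample of) the training set.
train_eval_input_fn = estimator_util.create_input_fn(  # assumed helper
    file_pattern=FLAGS.train_files,  # assumed flag
    input_config=config.inputs,
    mode=tf.estimator.ModeKeys.EVAL)

eval_args = {
    # eval_name: (input_fn, eval_steps)
    "train": (train_eval_input_fn, 1000),  # stop after 1000 batches
    "val": (eval_input_fn, None),          # run until end of input
}

for global_step, values in estimator_util.continuous_train_and_eval(
    estimator=estimator,
    train_input_fn=train_input_fn,
    eval_args=eval_args,
    train_steps=FLAGS.train_steps):
  # values is keyed by eval_name, e.g. values["val"] and values["train"].
  pass
```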
research/astronet/astronet/util/estimator_util.py

@@ -204,94 +204,106 @@ def create_estimator(model_class,
   return estimator
 
 
-def evaluate(estimator, input_fn, eval_steps=None, eval_name="val"):
+def evaluate(estimator, eval_args):
   """Runs evaluation on the latest model checkpoint.
 
   Args:
     estimator: Instance of tf.Estimator.
-    input_fn: Input function returning a tuple (features, labels).
-    eval_steps: The number of steps for which to evaluate the model. If None,
-      evaluates until input_fn raises an end-of-input exception.
-    eval_name: Name of the evaluation set, e.g. "train" or "val".
+    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
+      eval_name is the name of the evaluation set (e.g. "train" or "val"),
+      input_fn is an input function returning a tuple (features, labels), and
+      eval_steps is the number of steps for which to evaluate the model (if
+      None, evaluates until input_fn raises an end-of-input exception).
 
   Returns:
-    A dict of metric values from the evaluation. May be empty, e.g. if the
-    training job has not yet saved a checkpoint or the checkpoint is deleted by
-    the time the TPU worker initializes.
+    global_step: The global step of the checkpoint evaluated.
+    values: A dict of metric values from the evaluation. May be empty, e.g. if
+      the training job has not yet saved a checkpoint or the checkpoint is
+      deleted by the time the TPU worker initializes.
   """
-  values = {}  # Default return value if evaluation fails.
+  # Default return values if evaluation fails.
+  global_step = None
+  values = {}
 
   latest_checkpoint = tf.train.latest_checkpoint(estimator.model_dir)
   if not latest_checkpoint:
     # This is expected if the training job has not yet saved a checkpoint.
-    return values
+    return global_step, values
 
   tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
   try:
-    values = estimator.evaluate(input_fn, steps=eval_steps, name=eval_name)
+    for eval_name, (input_fn, eval_steps) in eval_args.items():
+      values[eval_name] = estimator.evaluate(
+          input_fn, steps=eval_steps, name=eval_name)
+      if global_step is None:
+        global_step = values[eval_name].get("global_step")
   except tf.errors.NotFoundError:
-    # Expected under some conditions, e.g. TPU worker does not finish
-    # initializing until long after the CPU job tells it to start evaluating
-    # and the checkpoint file is deleted already.
+    # Expected under some conditions, e.g. checkpoint is already deleted by the
+    # trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
+    # in some cases.
    tf.logging.info("Checkpoint %s no longer exists, skipping evaluation",
                    latest_checkpoint)
 
-  return values
+  return global_step, values
 
 
 def continuous_eval(estimator,
-                    input_fn,
+                    eval_args,
                     train_steps=None,
-                    eval_steps=None,
-                    eval_name="val"):
+                    timeout_secs=None,
+                    timeout_fn=None):
   """Runs evaluation whenever there's a new checkpoint.
 
   Args:
     estimator: Instance of tf.Estimator.
-    input_fn: Input function returning a tuple (features, labels).
+    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
+      eval_name is the name of the evaluation set (e.g. "train" or "val"),
+      input_fn is an input function returning a tuple (features, labels), and
+      eval_steps is the number of steps for which to evaluate the model (if
+      None, evaluates until input_fn raises an end-of-input exception).
     train_steps: The number of steps the model will train for. This function
-      will terminate once the model has finished training. If None, this
-      function will run forever.
-    eval_steps: The number of steps for which to evaluate the model. If None,
-      evaluates until input_fn raises an end-of-input exception.
-    eval_name: Name of the evaluation set, e.g. "train" or "val".
+      will terminate once the model has finished training.
+    timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
+      indefinitely.
+    timeout_fn: Optional function to call after timeout. The iterator will exit
+      if and only if the function returns True.
 
   Yields:
     A dict of metric values from each evaluation. May be empty, e.g. if the
     training job has not yet saved a checkpoint or the checkpoint is deleted by
     the time the TPU worker initializes.
   """
-  for _ in tf.contrib.training.checkpoints_iterator(estimator.model_dir):
-    values = evaluate(estimator, input_fn, eval_steps, eval_name)
-    yield values
+  for _ in tf.contrib.training.checkpoints_iterator(
+      estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
+    global_step, values = evaluate(estimator, eval_args)
+    yield global_step, values
 
-    global_step = values.get("global_step", 0)
+    global_step = global_step or 0  # Ensure global_step is not None.
     if train_steps and global_step >= train_steps:
       break
 
 
 def continuous_train_and_eval(estimator,
                               train_input_fn,
-                              eval_input_fn,
+                              eval_args,
                               local_eval_frequency=None,
                               train_hooks=None,
-                              train_steps=None,
-                              eval_steps=None,
-                              eval_name="val"):
+                              train_steps=None):
   """Alternates training and evaluation.
 
   Args:
     estimator: Instance of tf.Estimator.
     train_input_fn: Input function returning a tuple (features, labels).
-    eval_input_fn: Input function returning a tuple (features, labels).
+    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
+      eval_name is the name of the evaluation set (e.g. "train" or "val"),
+      input_fn is an input function returning a tuple (features, labels), and
+      eval_steps is the number of steps for which to evaluate the model (if
+      None, evaluates until input_fn raises an end-of-input exception).
     local_eval_frequency: The number of training steps between evaluations. If
       None, trains until train_input_fn raises an end-of-input exception.
     train_hooks: List of SessionRunHook subclass instances. Used for callbacks
       inside the training call.
     train_steps: The total number of steps to train the model for.
-    eval_steps: The number of steps for which to evaluate the model. If None,
-      evaluates until eval_input_fn raises an end-of-input exception.
-    eval_name: Name of the evaluation set, e.g. "train" or "val".
 
   Yields:
     A dict of metric values from each evaluation. May be empty, e.g. if the

@@ -301,10 +313,10 @@ def continuous_train_and_eval(estimator,
   while True:
     # We run evaluation before training in this loop to prevent evaluation from
     # being skipped if the process is interrupted.
-    values = evaluate(estimator, eval_input_fn, eval_steps, eval_name)
-    yield values
+    global_step, values = evaluate(estimator, eval_args)
+    yield global_step, values
 
-    global_step = values.get("global_step", 0)
+    global_step = global_step or 0  # Ensure global_step is not None.
     if train_steps and global_step >= train_steps:
       break
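With the reworked interface, a standalone evaluation job can watch the model directory and report per-dataset metrics, exiting either once training reaches train_steps or once no new checkpoint appears for a while. A minimal, hypothetical consumer of continuous_eval is sketched below; the one-hour timeout, the flag name, and the surrounding setup (estimator, eval_input_fn built as in train.py) are assumptions, not part of the diff.

```python
import tensorflow as tf
from astronet.util import estimator_util

def _timeout_fn():
  # Called after timeout_secs pass without a new checkpoint; returning True
  # makes checkpoints_iterator (and therefore continuous_eval) stop.
  tf.logging.info("No new checkpoint for an hour; stopping evaluation.")
  return True

# estimator and eval_input_fn are assumed to be built as in train.py.
eval_args = {"val": (eval_input_fn, None)}  # eval_name: (input_fn, eval_steps)

for global_step, values in estimator_util.continuous_eval(
    estimator=estimator,
    eval_args=eval_args,
    train_steps=FLAGS.train_steps,  # assumed flag
    timeout_secs=3600,
    timeout_fn=_timeout_fn):
  if values:  # Empty if no checkpoint was available yet.
    tf.logging.info("global_step %s: val metrics = %s",
                    global_step, values["val"])
```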