ModelZoo / ResNet50_tensorflow · Commits

Commit 8793267f
Authored Jul 18, 2018 by Chris Shallue
Committed Oct 16, 2018 by Christopher Shallue
Parent: 75d592e9

Support evaluating over multiple datasets.

PiperOrigin-RevId: 205168785
Showing 3 changed files with 57 additions and 42 deletions (+57 -42):

research/astronet/README.md (+2 -2)
research/astronet/astronet/train.py (+4 -1)
research/astronet/astronet/util/estimator_util.py (+51 -39)
research/astronet/README.md

@@ -207,7 +207,7 @@ the second deepest transits).
 To train a model to identify exoplanets, you will need to provide TensorFlow
-with training data in [TFRecord](https://www.tensorflow.org/programmers_guide/datasets) format. The
+with training data in [TFRecord](https://www.tensorflow.org/guide/datasets) format. The
 TFRecord format consists of a set of sharded files containing serialized
 `tf.Example` [protocol buffers](https://developers.google.com/protocol-buffers/).
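As background for the TFRecord requirement mentioned in this hunk, the sketch below shows one way to serialize `tf.Example` protocol buffers into a TFRecord shard with the TF 1.x API this repository targets. It is illustrative only and not part of this commit; the feature names and shard file name are made up.

```python
import tensorflow as tf

# Hypothetical records: a short light-curve segment and a binary label each.
examples = [
    {"flux": [1.00, 0.99, 0.97, 1.00], "label": 1},
    {"flux": [1.00, 1.01, 1.00, 0.99], "label": 0},
]

# Write a single shard; a real dataset would be split across many shard files
# (e.g. train-00000-of-00008, ...).
with tf.python_io.TFRecordWriter("train-00000-of-00001") as writer:
  for ex in examples:
    proto = tf.train.Example(features=tf.train.Features(feature={
        "flux": tf.train.Feature(
            float_list=tf.train.FloatList(value=ex["flux"])),
        "label": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[ex["label"]])),
    }))
    writer.write(proto.SerializeToString())  # One serialized tf.Example per record.
```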
@@ -343,7 +343,7 @@ bazel-bin/astronet/train \
   --model_dir=${MODEL_DIR}
 ```

-Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard)
+Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard)
 server in a separate process for real-time
 monitoring of training progress and evaluation metrics.
research/astronet/astronet/train.py

@@ -112,11 +112,14 @@ def main(_):
       file_pattern=FLAGS.eval_files,
       input_config=config.inputs,
       mode=tf.estimator.ModeKeys.EVAL)
+  eval_args = {
+      "val": (eval_input_fn, None)  # eval_name: (input_fn, eval_steps)
+  }

   for _ in estimator_util.continuous_train_and_eval(
       estimator=estimator,
       train_input_fn=train_input_fn,
-      eval_input_fn=eval_input_fn,
+      eval_args=eval_args,
       train_steps=FLAGS.train_steps):
     # continuous_train_and_eval() yields evaluation metrics after each
     # training epoch. We don't do anything here.
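The new `eval_args` dictionary is what makes multi-dataset evaluation possible: each key names an evaluation set and maps to an `(input_fn, eval_steps)` pair. As a hypothetical extension of the hunk above, one could also evaluate on the training files by adding a second entry; `estimator_util.create_input_fn`, `FLAGS.train_files`, and the 1000-step cap are assumptions made for illustration, not code from this commit.

```python
# Hypothetical sketch: evaluate on both the validation and the training files.
# The input-function factory and flag names are assumed to match those already
# used in train.py; only the eval_args structure is defined by this commit.
train_eval_input_fn = estimator_util.create_input_fn(
    file_pattern=FLAGS.train_files,
    input_config=config.inputs,
    mode=tf.estimator.ModeKeys.EVAL)

eval_args = {
    # eval_name: (input_fn, eval_steps)
    "val": (eval_input_fn, None),          # Full pass over the validation set.
    "train": (train_eval_input_fn, 1000),  # Cap the training-set eval at 1000 steps.
}
```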
research/astronet/astronet/util/estimator_util.py

@@ -204,94 +204,106 @@ def create_estimator(model_class,
   return estimator


-def evaluate(estimator, input_fn, eval_steps=None, eval_name="val"):
+def evaluate(estimator, eval_args):
   """Runs evaluation on the latest model checkpoint.

   Args:
     estimator: Instance of tf.Estimator.
-    input_fn: Input function returning a tuple (features, labels).
-    eval_steps: The number of steps for which to evaluate the model. If None,
-      evaluates until input_fn raises an end-of-input exception.
-    eval_name: Name of the evaluation set, e.g. "train" or "val".
+    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
+      eval_name is the name of the evaluation set (e.g. "train" or "val"),
+      input_fn is an input function returning a tuple (features, labels), and
+      eval_steps is the number of steps for which to evaluate the model (if
+      None, evaluates until input_fn raises an end-of-input exception).

   Returns:
-    A dict of metric values from the evaluation. May be empty, e.g. if the
-    training job has not yet saved a checkpoint or the checkpoint is deleted by
-    the time the TPU worker initializes.
+    global_step: The global step of the checkpoint evaluated.
+    values: A dict of metric values from the evaluation. May be empty, e.g. if
+      the training job has not yet saved a checkpoint or the checkpoint is
+      deleted by the time the TPU worker initializes.
   """
-  values = {}  # Default return value if evaluation fails.
+  # Default return values if evaluation fails.
+  global_step = None
+  values = {}
+
   latest_checkpoint = tf.train.latest_checkpoint(estimator.model_dir)
   if not latest_checkpoint:
     # This is expected if the training job has not yet saved a checkpoint.
-    return values
+    return global_step, values

   tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
   try:
-    values = estimator.evaluate(input_fn, steps=eval_steps, name=eval_name)
+    for eval_name, (input_fn, eval_steps) in eval_args.items():
+      values[eval_name] = estimator.evaluate(
+          input_fn, steps=eval_steps, name=eval_name)
+      if global_step is None:
+        global_step = values[eval_name].get("global_step")
   except tf.errors.NotFoundError:
-    # Expected under some conditions, e.g. TPU worker does not finish
-    # initializing until long after the CPU job tells it to start evaluating
-    # and the checkpoint file is deleted already.
+    # Expected under some conditions, e.g. checkpoint is already deleted by the
+    # trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
+    # in some cases.
    tf.logging.info("Checkpoint %s no longer exists, skipping evaluation",
                    latest_checkpoint)

-  return values
+  return global_step, values


 def continuous_eval(estimator,
-                    input_fn,
+                    eval_args,
                     train_steps=None,
-                    eval_steps=None,
-                    eval_name="val"):
+                    timeout_secs=None,
+                    timeout_fn=None):
   """Runs evaluation whenever there's a new checkpoint.

   Args:
     estimator: Instance of tf.Estimator.
-    input_fn: Input function returning a tuple (features, labels).
+    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
+      eval_name is the name of the evaluation set (e.g. "train" or "val"),
+      input_fn is an input function returning a tuple (features, labels), and
+      eval_steps is the number of steps for which to evaluate the model (if
+      None, evaluates until input_fn raises an end-of-input exception).
     train_steps: The number of steps the model will train for. This function
-      will terminate once the model has finished training. If None, this
-      function will run forever.
-    eval_steps: The number of steps for which to evaluate the model. If None,
-      evaluates until input_fn raises an end-of-input exception.
-    eval_name: Name of the evaluation set, e.g. "train" or "val".
+      will terminate once the model has finished training.
+    timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
+      indefinitely.
+    timeout_fn: Optional function to call after timeout. The iterator will exit
+      if and only if the function returns True.

   Yields:
     A dict of metric values from each evaluation. May be empty, e.g. if the
     training job has not yet saved a checkpoint or the checkpoint is deleted by
     the time the TPU worker initializes.
   """
-  for _ in tf.contrib.training.checkpoints_iterator(estimator.model_dir):
-    values = evaluate(estimator, input_fn, eval_steps, eval_name)
-    yield values
+  for _ in tf.contrib.training.checkpoints_iterator(
+      estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
+    global_step, values = evaluate(estimator, eval_args)
+    yield global_step, values

-    global_step = values.get("global_step", 0)
+    global_step = global_step or 0  # Ensure global_step is not None.
     if train_steps and global_step >= train_steps:
       break


 def continuous_train_and_eval(estimator,
                               train_input_fn,
-                              eval_input_fn,
+                              eval_args,
                               local_eval_frequency=None,
                               train_hooks=None,
-                              train_steps=None,
-                              eval_steps=None,
-                              eval_name="val"):
+                              train_steps=None):
   """Alternates training and evaluation.

   Args:
     estimator: Instance of tf.Estimator.
     train_input_fn: Input function returning a tuple (features, labels).
-    eval_input_fn: Input function returning a tuple (features, labels).
+    eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where
+      eval_name is the name of the evaluation set (e.g. "train" or "val"),
+      input_fn is an input function returning a tuple (features, labels), and
+      eval_steps is the number of steps for which to evaluate the model (if
+      None, evaluates until input_fn raises an end-of-input exception).
     local_eval_frequency: The number of training steps between evaluations. If
       None, trains until train_input_fn raises an end-of-input exception.
     train_hooks: List of SessionRunHook subclass instances. Used for callbacks
       inside the training call.
     train_steps: The total number of steps to train the model for.
-    eval_steps: The number of steps for which to evaluate the model. If None,
-      evaluates until eval_input_fn raises an end-of-input exception.
-    eval_name: Name of the evaluation set, e.g. "train" or "val".

   Yields:
     A dict of metric values from each evaluation. May be empty, e.g. if the
 ...
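To illustrate the reworked `evaluate()` above, a minimal, hypothetical call site might look like the following; `make_val_input_fn` and `make_train_input_fn` are stand-in input-function factories, not functions from this repository.

```python
# Hypothetical usage of the new evaluate(estimator, eval_args) signature.
eval_args = {
    "val": (make_val_input_fn(), None),     # Evaluate until end of input.
    "train": (make_train_input_fn(), 500),  # Evaluate only 500 batches.
}
global_step, values = estimator_util.evaluate(estimator, eval_args)
if values:
  # values maps each eval_name to its own metrics dict.
  tf.logging.info("Metrics at step %s: %s", global_step, values)
```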
@@ -301,10 +313,10 @@ def continuous_train_and_eval(estimator,
   while True:
     # We run evaluation before training in this loop to prevent evaluation from
     # being skipped if the process is interrupted.
-    values = evaluate(estimator, eval_input_fn, eval_steps, eval_name)
-    yield values
+    global_step, values = evaluate(estimator, eval_args)
+    yield global_step, values

-    global_step = values.get("global_step", 0)
+    global_step = global_step or 0  # Ensure global_step is not None.
     if train_steps and global_step >= train_steps:
       break
 ...
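A practical consequence of these hunks is that `continuous_eval` and `continuous_train_and_eval` now yield `(global_step, values)` pairs rather than a bare metrics dict, with `values` keyed by evaluation-set name. A minimal caller sketch, assuming `estimator`, `train_input_fn`, and `eval_args` are built as in train.py and using an arbitrary step budget:

```python
# Sketch of consuming the new (global_step, values) yield contract.
for global_step, values in estimator_util.continuous_train_and_eval(
    estimator=estimator,
    train_input_fn=train_input_fn,
    eval_args=eval_args,
    train_steps=10000):  # Arbitrary step budget for illustration.
  for eval_name, metrics in values.items():
    tf.logging.info(
        "step %s, %s metrics: %s", global_step, eval_name, metrics)
```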