Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
313d0c41
Commit
313d0c41
authored
Sep 13, 2018
by
Chris Shallue
Committed by
Christopher Shallue
Oct 16, 2018
Browse files
Refactor estimator_util.{create_input_fn,create_model_fn} to use a callable class object.
PiperOrigin-RevId: 212909744
parent
bfa9364a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
91 additions
and
43 deletions
+91
-43
research/astronet/astronet/util/estimator_util.py
research/astronet/astronet/util/estimator_util.py
+91
-43
No files found.
research/astronet/astronet/util/estimator_util.py
View file @
313d0c41
...
...
@@ -27,71 +27,104 @@ from astronet.ops import metrics
from
astronet.ops
import
training
def
create_input_fn
(
file_pattern
,
input_config
,
mode
,
shuffle_values_buffer
=
0
,
repeat
=
1
):
"""Creates an input_fn that reads a dataset from sharded TFRecord files.
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
class
_InputFn
(
object
):
"""Class that acts as a callable input function for Estimator train / eval."""
def
__init__
(
self
,
file_pattern
,
input_config
,
mode
,
shuffle_values_buffer
=
0
,
repeat
=
1
):
"""Initializes the input function.
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
input_config: ConfigDict containing feature and label specifications.
mode: A tf.estimator.ModeKeys.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the
input_config: ConfigDict containing feature and label specifications.
mode: A tf.estimator.ModeKeys.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this
size.
repeat: The number of times to repeat the dataset. If None or -1 the
elements will be repeated indefinitely.
Returns:
A callable that builds an input pipeline and returns (features, labels).
"""
include_labels
=
(
mode
in
[
tf
.
estimator
.
ModeKeys
.
TRAIN
,
tf
.
estimator
.
ModeKeys
.
EVAL
])
reverse_time_series_prob
=
0.5
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
else
0
shuffle_filenames
=
(
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
)
def
input_fn
(
config
,
params
):
"""Builds an input pipeline that reads a dataset from TFRecord files."""
"""
self
.
_file_pattern
=
file_pattern
self
.
_input_config
=
input_config
self
.
_mode
=
mode
self
.
_shuffle_values_buffer
=
shuffle_values_buffer
self
.
_repeat
=
repeat
def
__call__
(
self
,
config
,
params
):
"""Builds the input pipeline."""
# Infer whether this input_fn was called by Estimator or TPUEstimator using
# the config type.
use_tpu
=
isinstance
(
config
,
tf
.
contrib
.
tpu
.
RunConfig
)
mode
=
self
.
_mode
include_labels
=
(
mode
in
[
tf
.
estimator
.
ModeKeys
.
TRAIN
,
tf
.
estimator
.
ModeKeys
.
EVAL
])
reverse_time_series_prob
=
0.5
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
else
0
shuffle_filenames
=
(
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
)
dataset
=
dataset_ops
.
build_dataset
(
file_pattern
=
file_pattern
,
input_config
=
input_config
,
file_pattern
=
self
.
_
file_pattern
,
input_config
=
self
.
_
input_config
,
batch_size
=
params
[
"batch_size"
],
include_labels
=
include_labels
,
reverse_time_series_prob
=
reverse_time_series_prob
,
shuffle_filenames
=
shuffle_filenames
,
shuffle_values_buffer
=
shuffle_values_buffer
,
repeat
=
repeat
,
shuffle_values_buffer
=
self
.
_
shuffle_values_buffer
,
repeat
=
self
.
_
repeat
,
use_tpu
=
use_tpu
)
return
dataset
return
input_fn
def
create_model_fn
(
model_class
,
hparams
,
use_tpu
=
False
):
"""Wraps model_class as an Estimator or TPUEstimator model_fn.
def
create_input_fn
(
file_pattern
,
input_config
,
mode
,
shuffle_values_buffer
=
0
,
repeat
=
1
):
"""Creates an input_fn that reads a dataset from sharded TFRecord files.
Args:
model_class: AstroModel or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
Estimator model_fn is returned.
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
input_config: ConfigDict containing feature and label specifications.
mode: A tf.estimator.ModeKeys.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the
elements will be repeated indefinitely.
Returns:
model_fn:
A callable that
constructs the model and returns a
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec
.
A callable that
builds the input pipeline and returns a tf.data.Dataset
object
.
"""
hparams
=
copy
.
deepcopy
(
hparams
)
return
_InputFn
(
file_pattern
,
input_config
,
mode
,
shuffle_values_buffer
,
repeat
)
class
_ModelFn
(
object
):
"""Class that acts as a callable model function for Estimator train / eval."""
def
__init__
(
self
,
model_class
,
hparams
,
use_tpu
=
False
):
"""Initializes the model function.
Args:
model_class: Model class.
hparams: ConfigDict containing hyperparameters for building and training
the model.
use_tpu: If True, a TPUEstimator will be returned. Otherwise an Estimator
will be returned.
"""
self
.
_model_class
=
model_class
self
.
_base_hparams
=
hparams
self
.
_use_tpu
=
use_tpu
def
model_fn
(
features
,
labels
,
mode
,
params
):
def
__call__
(
self
,
features
,
labels
,
mode
,
params
):
"""Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
# For TPUEstimator, params contains the batch size per TPU core.
hparams
=
copy
.
deepcopy
(
self
.
_base_hparams
)
if
"batch_size"
in
params
:
hparams
.
batch_size
=
params
[
"batch_size"
]
...
...
@@ -103,10 +136,11 @@ def create_model_fn(model_class, hparams, use_tpu=False):
(
features
[
"labels"
],
labels
))
labels
=
features
.
pop
(
"labels"
)
model
=
model_class
(
features
,
labels
,
hparams
,
mode
)
model
=
self
.
_
model_class
(
features
,
labels
,
hparams
,
mode
)
model
.
build
()
# Possibly create train_op.
use_tpu
=
self
.
_use_tpu
train_op
=
None
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
learning_rate
=
training
.
create_learning_rate
(
hparams
,
model
.
global_step
)
...
...
@@ -137,7 +171,21 @@ def create_model_fn(model_class, hparams, use_tpu=False):
return
estimator
return
model_fn
def
create_model_fn
(
model_class
,
hparams
,
use_tpu
=
False
):
"""Wraps model_class as an Estimator or TPUEstimator model_fn.
Args:
model_class: AstroModel or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
Estimator model_fn is returned.
Returns:
model_fn: A callable that constructs the model and returns a
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
"""
return
_ModelFn
(
model_class
,
hparams
,
use_tpu
)
def
create_estimator
(
model_class
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment