Commit 313d0c41 authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Refactor estimator_util.{create_input_fn,create_model_fn} to use a callable class object.

PiperOrigin-RevId: 212909744
parent bfa9364a
...@@ -27,71 +27,104 @@ from astronet.ops import metrics ...@@ -27,71 +27,104 @@ from astronet.ops import metrics
from astronet.ops import training from astronet.ops import training
def create_input_fn(file_pattern, class _InputFn(object):
input_config, """Class that acts as a callable input function for Estimator train / eval."""
mode,
shuffle_values_buffer=0, def __init__(self,
repeat=1): file_pattern,
"""Creates an input_fn that reads a dataset from sharded TFRecord files. input_config,
mode,
Args: shuffle_values_buffer=0,
file_pattern: File pattern matching input TFRecord files, e.g. repeat=1):
"""Initializes the input function.
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file "/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns. patterns.
input_config: ConfigDict containing feature and label specifications. input_config: ConfigDict containing feature and label specifications.
mode: A tf.estimator.ModeKeys. mode: A tf.estimator.ModeKeys.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size. shuffle_values_buffer: If > 0, shuffle examples using a buffer of this
repeat: The number of times to repeat the dataset. If None or -1 the size.
repeat: The number of times to repeat the dataset. If None or -1 the
elements will be repeated indefinitely. elements will be repeated indefinitely.
"""
Returns: self._file_pattern = file_pattern
A callable that builds an input pipeline and returns (features, labels). self._input_config = input_config
""" self._mode = mode
include_labels = ( self._shuffle_values_buffer = shuffle_values_buffer
mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]) self._repeat = repeat
reverse_time_series_prob = 0.5 if mode == tf.estimator.ModeKeys.TRAIN else 0
shuffle_filenames = (mode == tf.estimator.ModeKeys.TRAIN) def __call__(self, config, params):
"""Builds the input pipeline."""
def input_fn(config, params):
"""Builds an input pipeline that reads a dataset from TFRecord files."""
# Infer whether this input_fn was called by Estimator or TPUEstimator using # Infer whether this input_fn was called by Estimator or TPUEstimator using
# the config type. # the config type.
use_tpu = isinstance(config, tf.contrib.tpu.RunConfig) use_tpu = isinstance(config, tf.contrib.tpu.RunConfig)
mode = self._mode
include_labels = (
mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL])
reverse_time_series_prob = 0.5 if mode == tf.estimator.ModeKeys.TRAIN else 0
shuffle_filenames = (mode == tf.estimator.ModeKeys.TRAIN)
dataset = dataset_ops.build_dataset( dataset = dataset_ops.build_dataset(
file_pattern=file_pattern, file_pattern=self._file_pattern,
input_config=input_config, input_config=self._input_config,
batch_size=params["batch_size"], batch_size=params["batch_size"],
include_labels=include_labels, include_labels=include_labels,
reverse_time_series_prob=reverse_time_series_prob, reverse_time_series_prob=reverse_time_series_prob,
shuffle_filenames=shuffle_filenames, shuffle_filenames=shuffle_filenames,
shuffle_values_buffer=shuffle_values_buffer, shuffle_values_buffer=self._shuffle_values_buffer,
repeat=repeat, repeat=self._repeat,
use_tpu=use_tpu) use_tpu=use_tpu)
return dataset return dataset
return input_fn
def create_model_fn(model_class, hparams, use_tpu=False): def create_input_fn(file_pattern,
"""Wraps model_class as an Estimator or TPUEstimator model_fn. input_config,
mode,
shuffle_values_buffer=0,
repeat=1):
"""Creates an input_fn that reads a dataset from sharded TFRecord files.
Args: Args:
model_class: AstroModel or a subclass. file_pattern: File pattern matching input TFRecord files, e.g.
hparams: ConfigDict of configuration parameters for building the model. "/tmp/train-?????-of-00100". May also be a comma-separated list of file
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an patterns.
Estimator model_fn is returned. input_config: ConfigDict containing feature and label specifications.
mode: A tf.estimator.ModeKeys.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the
elements will be repeated indefinitely.
Returns: Returns:
model_fn: A callable that constructs the model and returns a A callable that builds the input pipeline and returns a tf.data.Dataset
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec. object.
""" """
hparams = copy.deepcopy(hparams) return _InputFn(file_pattern, input_config, mode, shuffle_values_buffer,
repeat)
class _ModelFn(object):
"""Class that acts as a callable model function for Estimator train / eval."""
def __init__(self, model_class, hparams, use_tpu=False):
"""Initializes the model function.
Args:
model_class: Model class.
hparams: ConfigDict containing hyperparameters for building and training
the model.
use_tpu: If True, a TPUEstimator will be returned. Otherwise an Estimator
will be returned.
"""
self._model_class = model_class
self._base_hparams = hparams
self._use_tpu = use_tpu
def model_fn(features, labels, mode, params): def __call__(self, features, labels, mode, params):
"""Builds the model and returns an EstimatorSpec or TPUEstimatorSpec.""" """Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
# For TPUEstimator, params contains the batch size per TPU core. hparams = copy.deepcopy(self._base_hparams)
if "batch_size" in params: if "batch_size" in params:
hparams.batch_size = params["batch_size"] hparams.batch_size = params["batch_size"]
...@@ -103,10 +136,11 @@ def create_model_fn(model_class, hparams, use_tpu=False): ...@@ -103,10 +136,11 @@ def create_model_fn(model_class, hparams, use_tpu=False):
(features["labels"], labels)) (features["labels"], labels))
labels = features.pop("labels") labels = features.pop("labels")
model = model_class(features, labels, hparams, mode) model = self._model_class(features, labels, hparams, mode)
model.build() model.build()
# Possibly create train_op. # Possibly create train_op.
use_tpu = self._use_tpu
train_op = None train_op = None
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
learning_rate = training.create_learning_rate(hparams, model.global_step) learning_rate = training.create_learning_rate(hparams, model.global_step)
...@@ -137,7 +171,21 @@ def create_model_fn(model_class, hparams, use_tpu=False): ...@@ -137,7 +171,21 @@ def create_model_fn(model_class, hparams, use_tpu=False):
return estimator return estimator
return model_fn
def create_model_fn(model_class, hparams, use_tpu=False):
"""Wraps model_class as an Estimator or TPUEstimator model_fn.
Args:
model_class: AstroModel or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
Estimator model_fn is returned.
Returns:
model_fn: A callable that constructs the model and returns a
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
"""
return _ModelFn(model_class, hparams, use_tpu)
def create_estimator(model_class, def create_estimator(model_class,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment