Commit d8588a7e authored by Toby Boyd's avatar Toby Boyd
Browse files

Merge branch 'cmlesupport' of https://github.com/elibixby/models

parents 5d06cfcf 423fd778
......@@ -29,6 +29,7 @@ from __future__ import division
from __future__ import print_function
import argparse
import collections
import functools
import itertools
import os
......@@ -352,7 +353,7 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
eval_input_fn=eval_input_fn,
train_steps=train_steps,
eval_steps=eval_steps)
# Adding hooks to be used by the estimator on training mode.
# Adding hooks to be used by the estimator on training modes
experiment.extend_train_hooks(hooks)
return experiment
return _experiment_fn
......@@ -379,7 +380,7 @@ def main(job_dir,
)
)
config = tf.contrib.learn.RunConfig(
config = cifar10_utils.RunConfig(
session_config=sess_config,
model_dir=job_dir)
tf.contrib.learn.learn_runner.run(
......
......@@ -8,6 +8,43 @@ from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training import device_setter
from tensorflow.contrib.learn.python.learn import run_config
# TODO(b/64848083) Remove once uid bug is fixed
class RunConfig(tf.contrib.learn.RunConfig):
def uid(self, whitelist=None):
"""Generates a 'Unique Identifier' based on all internal fields.
Caller should use the uid string to check `RunConfig` instance integrity
in one session use, but should not rely on the implementation details, which
is subject to change.
Args:
whitelist: A list of the string names of the properties uid should not
include. If `None`, defaults to `_DEFAULT_UID_WHITE_LIST`, which
includes most properties user allowes to change.
Returns:
A uid string.
"""
if whitelist is None:
whitelist = run_config._DEFAULT_UID_WHITE_LIST
state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')}
# Pop out the keys in whitelist.
for k in whitelist:
state.pop('_' + k, None)
ordered_state = collections.OrderedDict(
sorted(state.items(), key=lambda t: t[0]))
# For class instance without __repr__, some special cares are required.
# Otherwise, the object address will be used.
if '_cluster_spec' in ordered_state:
ordered_state['_cluster_spec'] = collections.OrderedDict(
sorted(ordered_state['_cluster_spec'].as_dict().items(),
key=lambda t: t[0])
)
return ', '.join(
'%s=%r' % (k, v) for (k, v) in six.iteritems(ordered_state))
class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
"""Hook to print out examples per second.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment