Commit 2a6e342e authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Move shared TensorFlow utilities from astronet/util to tf_util/.

PiperOrigin-RevId: 223433850
parent 03612984
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from astronet.util import configdict from tf_util import configdict
from astrowavenet.data import base from astrowavenet.data import base
from astrowavenet.data import synthetic_transit_maker from astrowavenet.data import synthetic_transit_maker
......
...@@ -24,15 +24,14 @@ import os.path ...@@ -24,15 +24,14 @@ import os.path
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astrowavenet import astrowavenet_model from astrowavenet import astrowavenet_model
from astrowavenet import configurations from astrowavenet import configurations
from astrowavenet.data import kepler_light_curves from astrowavenet.data import kepler_light_curves
from astrowavenet.data import synthetic_transits from astrowavenet.data import synthetic_transits
from astrowavenet.util import estimator_util from astrowavenet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "configdict",
srcs = ["configdict.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "configdict_test",
size = "small",
srcs = ["configdict_test.py"],
srcs_version = "PY2AND3",
deps = [":configdict"],
)
py_library(
name = "config_util",
srcs = ["config_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "config_util_test",
size = "small",
srcs = ["config_util_test.py"],
srcs_version = "PY2AND3",
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from astronet.util import config_util from tf_util import config_util
class ConfigUtilTest(tf.test.TestCase): class ConfigUtilTest(tf.test.TestCase):
......
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
from absl.testing import absltest from absl.testing import absltest
from astronet.util import configdict from tf_util import configdict
class ConfigDictTest(absltest.TestCase): class ConfigDictTest(absltest.TestCase):
......
...@@ -45,6 +45,8 @@ def evaluate(estimator, eval_args): ...@@ -45,6 +45,8 @@ def evaluate(estimator, eval_args):
latest_checkpoint = estimator.latest_checkpoint() latest_checkpoint = estimator.latest_checkpoint()
if not latest_checkpoint: if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint. # This is expected if the training job has not yet saved a checkpoint.
tf.logging.info("No checkpoint in %s, skipping evaluation.",
estimator.model_dir)
return global_step, values return global_step, values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint) tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
...@@ -54,7 +56,7 @@ def evaluate(estimator, eval_args): ...@@ -54,7 +56,7 @@ def evaluate(estimator, eval_args):
input_fn, steps=eval_steps, name=eval_name) input_fn, steps=eval_steps, name=eval_name)
if global_step is None: if global_step is None:
global_step = values[eval_name].get("global_step") global_step = values[eval_name].get("global_step")
except (tf.errors.NotFoundError, ValueError): except tf.errors.NotFoundError:
# Expected under some conditions, e.g. checkpoint is already deleted by the # Expected under some conditions, e.g. checkpoint is already deleted by the
# trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this # trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
# in some cases. # in some cases.
......
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from astronet.util import example_util from tf_util import example_util
class ExampleUtilTest(tf.test.TestCase): class ExampleUtilTest(tf.test.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment