Commit 2a6e342e authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Move shared TensorFlow utilities from astronet/util to tf_util/.

PiperOrigin-RevId: 223433850
parent 03612984
...@@ -13,7 +13,7 @@ py_library( ...@@ -13,7 +13,7 @@ py_library(
"//astronet/astro_fc_model:configurations", "//astronet/astro_fc_model:configurations",
"//astronet/astro_model", "//astronet/astro_model",
"//astronet/astro_model:configurations", "//astronet/astro_model:configurations",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -23,10 +23,10 @@ py_binary( ...@@ -23,10 +23,10 @@ py_binary(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":models", ":models",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util", "//astronet/util:estimator_util",
"//tf_util:config_util",
"//tf_util:configdict",
"//tf_util:estimator_runner",
], ],
) )
...@@ -36,10 +36,10 @@ py_binary( ...@@ -36,10 +36,10 @@ py_binary(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":models", ":models",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util", "//astronet/util:estimator_util",
"//tf_util:config_util",
"//tf_util:configdict",
"//tf_util:estimator_runner",
], ],
) )
...@@ -50,8 +50,8 @@ py_binary( ...@@ -50,8 +50,8 @@ py_binary(
deps = [ deps = [
":models", ":models",
"//astronet/data:preprocess", "//astronet/data:preprocess",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_util", "//astronet/util:estimator_util",
"//tf_util:config_util",
"//tf_util:configdict",
], ],
) )
...@@ -32,6 +32,6 @@ py_test( ...@@ -32,6 +32,6 @@ py_test(
":configurations", ":configurations",
"//astronet/ops:input_ops", "//astronet/ops:input_ops",
"//astronet/ops:testing", "//astronet/ops:testing",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -25,7 +25,7 @@ from astronet.astro_cnn_model import astro_cnn_model ...@@ -25,7 +25,7 @@ from astronet.astro_cnn_model import astro_cnn_model
from astronet.astro_cnn_model import configurations from astronet.astro_cnn_model import configurations
from astronet.ops import input_ops from astronet.ops import input_ops
from astronet.ops import testing from astronet.ops import testing
from astronet.util import configdict from tf_util import configdict
class AstroCNNModelTest(tf.test.TestCase): class AstroCNNModelTest(tf.test.TestCase):
......
...@@ -32,6 +32,6 @@ py_test( ...@@ -32,6 +32,6 @@ py_test(
":configurations", ":configurations",
"//astronet/ops:input_ops", "//astronet/ops:input_ops",
"//astronet/ops:testing", "//astronet/ops:testing",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -25,7 +25,7 @@ from astronet.astro_fc_model import astro_fc_model ...@@ -25,7 +25,7 @@ from astronet.astro_fc_model import astro_fc_model
from astronet.astro_fc_model import configurations from astronet.astro_fc_model import configurations
from astronet.ops import input_ops from astronet.ops import input_ops
from astronet.ops import testing from astronet.ops import testing
from astronet.util import configdict from tf_util import configdict
class AstroFCModelTest(tf.test.TestCase): class AstroFCModelTest(tf.test.TestCase):
......
...@@ -28,6 +28,6 @@ py_test( ...@@ -28,6 +28,6 @@ py_test(
":configurations", ":configurations",
"//astronet/ops:input_ops", "//astronet/ops:input_ops",
"//astronet/ops:testing", "//astronet/ops:testing",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -25,7 +25,7 @@ from astronet.astro_model import astro_model ...@@ -25,7 +25,7 @@ from astronet.astro_model import astro_model
from astronet.astro_model import configurations from astronet.astro_model import configurations
from astronet.ops import input_ops from astronet.ops import input_ops
from astronet.ops import testing from astronet.ops import testing
from astronet.util import configdict from tf_util import configdict
class AstroModelTest(tf.test.TestCase): class AstroModelTest(tf.test.TestCase):
......
...@@ -12,10 +12,10 @@ py_library( ...@@ -12,10 +12,10 @@ py_library(
name = "preprocess", name = "preprocess",
srcs = ["preprocess.py"], srcs = ["preprocess.py"],
deps = [ deps = [
"//astronet/util:example_util", "//light_curve:kepler_io",
"//light_curve_util:kepler_io", "//light_curve:median_filter",
"//light_curve_util:median_filter", "//light_curve:util",
"//light_curve_util:util", "//tf_util:example_util",
"//third_party/kepler_spline", "//third_party/kepler_spline",
], ],
) )
...@@ -21,10 +21,10 @@ from __future__ import print_function ...@@ -21,10 +21,10 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from astronet.util import example_util from light_curve import kepler_io
from light_curve_util import kepler_io from light_curve import median_filter
from light_curve_util import median_filter from light_curve import util
from light_curve_util import util from tf_util import example_util
from third_party.kepler_spline import kepler_spline from third_party.kepler_spline import kepler_spline
......
...@@ -24,10 +24,10 @@ import sys ...@@ -24,10 +24,10 @@ import sys
import tensorflow as tf import tensorflow as tf
from astronet import models from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
...@@ -15,7 +15,7 @@ py_test( ...@@ -15,7 +15,7 @@ py_test(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":input_ops", ":input_ops",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -33,7 +33,7 @@ py_test( ...@@ -33,7 +33,7 @@ py_test(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":dataset_ops", ":dataset_ops",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
......
...@@ -25,7 +25,7 @@ import numpy as np ...@@ -25,7 +25,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from astronet.ops import dataset_ops from astronet.ops import dataset_ops
from astronet.util import configdict from tf_util import configdict
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from astronet.ops import input_ops from astronet.ops import input_ops
from astronet.util import configdict from tf_util import configdict
class InputOpsTest(tf.test.TestCase): class InputOpsTest(tf.test.TestCase):
......
...@@ -27,9 +27,9 @@ import tensorflow as tf ...@@ -27,9 +27,9 @@ import tensorflow as tf
from astronet import models from astronet import models
from astronet.data import preprocess from astronet.data import preprocess
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_util from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -102,8 +102,9 @@ def _process_tce(feature_config): ...@@ -102,8 +102,9 @@ def _process_tce(feature_config):
"Only 'global_view' and 'local_view' features are supported.") "Only 'global_view' and 'local_view' features are supported.")
# Read and process the light curve. # Read and process the light curve.
time, flux = preprocess.read_and_process_light_curve(FLAGS.kepler_id, all_time, all_flux = preprocess.read_light_curve(FLAGS.kepler_id,
FLAGS.kepler_data_dir) FLAGS.kepler_data_dir)
time, flux = preprocess.process_light_curve(all_time, all_flux)
time, flux = preprocess.phase_fold_and_sort_light_curve( time, flux = preprocess.phase_fold_and_sort_light_curve(
time, flux, FLAGS.period, FLAGS.t0) time, flux, FLAGS.period, FLAGS.t0)
...@@ -158,11 +159,7 @@ def main(_): ...@@ -158,11 +159,7 @@ def main(_):
# Create an input function. # Create an input function.
def input_fn(): def input_fn():
return { return tf.data.Dataset.from_tensors({"time_series_features": features})
"time_series_features":
tf.estimator.inputs.numpy_input_fn(
features, batch_size=1, shuffle=False, queue_capacity=1)()
}
# Generate the predictions. # Generate the predictions.
for predictions in estimator.predict(input_fn): for predictions in estimator.predict(input_fn):
......
...@@ -24,10 +24,10 @@ import sys ...@@ -24,10 +24,10 @@ import sys
import tensorflow as tf import tensorflow as tf
from astronet import models from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -68,7 +68,7 @@ parser.add_argument( ...@@ -68,7 +68,7 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
"--train_steps", "--train_steps",
type=int, type=int,
default=10000, default=625,
help="Total number of steps to train the model for.") help="Total number of steps to train the model for.")
parser.add_argument( parser.add_argument(
......
...@@ -2,42 +2,6 @@ package(default_visibility = ["//visibility:public"]) ...@@ -2,42 +2,6 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
py_library(
name = "configdict",
srcs = ["configdict.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "configdict_test",
size = "small",
srcs = ["configdict_test.py"],
srcs_version = "PY2AND3",
deps = [":configdict"],
)
py_library(
name = "config_util",
srcs = ["config_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "config_util_test",
size = "small",
srcs = ["config_util_test.py"],
srcs_version = "PY2AND3",
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library( py_library(
name = "estimator_util", name = "estimator_util",
srcs = ["estimator_util.py"], srcs = ["estimator_util.py"],
...@@ -48,18 +12,3 @@ py_library( ...@@ -48,18 +12,3 @@ py_library(
"//astronet/ops:training", "//astronet/ops:training",
], ],
) )
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
...@@ -11,12 +11,12 @@ py_binary( ...@@ -11,12 +11,12 @@ py_binary(
deps = [ deps = [
":astrowavenet_model", ":astrowavenet_model",
":configurations", ":configurations",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astrowavenet/data:kepler_light_curves", "//astrowavenet/data:kepler_light_curves",
"//astrowavenet/data:synthetic_transits", "//astrowavenet/data:synthetic_transits",
"//astrowavenet/util:estimator_util", "//astrowavenet/util:estimator_util",
"//tf_util:config_util",
"//tf_util:configdict",
"//tf_util:estimator_runner",
], ],
) )
...@@ -44,6 +44,6 @@ py_test( ...@@ -44,6 +44,6 @@ py_test(
deps = [ deps = [
":astrowavenet_model", ":astrowavenet_model",
":configurations", ":configurations",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -21,8 +21,8 @@ from __future__ import print_function ...@@ -21,8 +21,8 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from astronet.util import configdict
from astrowavenet import astrowavenet_model from astrowavenet import astrowavenet_model
from tf_util import configdict
class AstrowavenetTest(tf.test.TestCase): class AstrowavenetTest(tf.test.TestCase):
......
...@@ -9,7 +9,7 @@ py_library( ...@@ -9,7 +9,7 @@ py_library(
], ],
deps = [ deps = [
"//astronet/ops:dataset_ops", "//astronet/ops:dataset_ops",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -28,7 +28,7 @@ py_library( ...@@ -28,7 +28,7 @@ py_library(
], ],
deps = [ deps = [
":base", ":base",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -40,7 +40,7 @@ py_library( ...@@ -40,7 +40,7 @@ py_library(
deps = [ deps = [
":base", ":base",
":synthetic_transit_maker", ":synthetic_transit_maker",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
......
...@@ -23,7 +23,7 @@ import six ...@@ -23,7 +23,7 @@ import six
import tensorflow as tf import tensorflow as tf
from astronet.util import configdict from tf_util import configdict
from astronet.ops import dataset_ops from astronet.ops import dataset_ops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment