Commit 2a6e342e authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Move shared TensorFlow utilities from astronet/util to tf_util/.

PiperOrigin-RevId: 223433850
parent 03612984
......@@ -13,7 +13,7 @@ py_library(
"//astronet/astro_fc_model:configurations",
"//astronet/astro_model",
"//astronet/astro_model:configurations",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -23,10 +23,10 @@ py_binary(
srcs_version = "PY2AND3",
deps = [
":models",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util",
"//tf_util:config_util",
"//tf_util:configdict",
"//tf_util:estimator_runner",
],
)
......@@ -36,10 +36,10 @@ py_binary(
srcs_version = "PY2AND3",
deps = [
":models",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util",
"//tf_util:config_util",
"//tf_util:configdict",
"//tf_util:estimator_runner",
],
)
......@@ -50,8 +50,8 @@ py_binary(
deps = [
":models",
"//astronet/data:preprocess",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_util",
"//tf_util:config_util",
"//tf_util:configdict",
],
)
......@@ -32,6 +32,6 @@ py_test(
":configurations",
"//astronet/ops:input_ops",
"//astronet/ops:testing",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -25,7 +25,7 @@ from astronet.astro_cnn_model import astro_cnn_model
from astronet.astro_cnn_model import configurations
from astronet.ops import input_ops
from astronet.ops import testing
from astronet.util import configdict
from tf_util import configdict
class AstroCNNModelTest(tf.test.TestCase):
......
......@@ -32,6 +32,6 @@ py_test(
":configurations",
"//astronet/ops:input_ops",
"//astronet/ops:testing",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -25,7 +25,7 @@ from astronet.astro_fc_model import astro_fc_model
from astronet.astro_fc_model import configurations
from astronet.ops import input_ops
from astronet.ops import testing
from astronet.util import configdict
from tf_util import configdict
class AstroFCModelTest(tf.test.TestCase):
......
......@@ -28,6 +28,6 @@ py_test(
":configurations",
"//astronet/ops:input_ops",
"//astronet/ops:testing",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -25,7 +25,7 @@ from astronet.astro_model import astro_model
from astronet.astro_model import configurations
from astronet.ops import input_ops
from astronet.ops import testing
from astronet.util import configdict
from tf_util import configdict
class AstroModelTest(tf.test.TestCase):
......
......@@ -12,10 +12,10 @@ py_library(
name = "preprocess",
srcs = ["preprocess.py"],
deps = [
"//astronet/util:example_util",
"//light_curve_util:kepler_io",
"//light_curve_util:median_filter",
"//light_curve_util:util",
"//light_curve:kepler_io",
"//light_curve:median_filter",
"//light_curve:util",
"//tf_util:example_util",
"//third_party/kepler_spline",
],
)
......@@ -21,10 +21,10 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import example_util
from light_curve_util import kepler_io
from light_curve_util import median_filter
from light_curve_util import util
from light_curve import kepler_io
from light_curve import median_filter
from light_curve import util
from tf_util import example_util
from third_party.kepler_spline import kepler_spline
......
......@@ -24,10 +24,10 @@ import sys
import tensorflow as tf
from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner
parser = argparse.ArgumentParser()
......
......@@ -15,7 +15,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":input_ops",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -33,7 +33,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":dataset_ops",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......
......@@ -25,7 +25,7 @@ import numpy as np
import tensorflow as tf
from astronet.ops import dataset_ops
from astronet.util import configdict
from tf_util import configdict
FLAGS = flags.FLAGS
......
......@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf
from astronet.ops import input_ops
from astronet.util import configdict
from tf_util import configdict
class InputOpsTest(tf.test.TestCase):
......
......@@ -27,9 +27,9 @@ import tensorflow as tf
from astronet import models
from astronet.data import preprocess
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
parser = argparse.ArgumentParser()
......@@ -102,8 +102,9 @@ def _process_tce(feature_config):
"Only 'global_view' and 'local_view' features are supported.")
# Read and process the light curve.
time, flux = preprocess.read_and_process_light_curve(FLAGS.kepler_id,
FLAGS.kepler_data_dir)
all_time, all_flux = preprocess.read_light_curve(FLAGS.kepler_id,
FLAGS.kepler_data_dir)
time, flux = preprocess.process_light_curve(all_time, all_flux)
time, flux = preprocess.phase_fold_and_sort_light_curve(
time, flux, FLAGS.period, FLAGS.t0)
......@@ -158,11 +159,7 @@ def main(_):
# Create an input function.
def input_fn():
return {
"time_series_features":
tf.estimator.inputs.numpy_input_fn(
features, batch_size=1, shuffle=False, queue_capacity=1)()
}
return tf.data.Dataset.from_tensors({"time_series_features": features})
# Generate the predictions.
for predictions in estimator.predict(input_fn):
......
......@@ -24,10 +24,10 @@ import sys
import tensorflow as tf
from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner
parser = argparse.ArgumentParser()
......@@ -68,7 +68,7 @@ parser.add_argument(
parser.add_argument(
"--train_steps",
type=int,
default=10000,
default=625,
help="Total number of steps to train the model for.")
parser.add_argument(
......
......@@ -2,42 +2,6 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "configdict",
srcs = ["configdict.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "configdict_test",
size = "small",
srcs = ["configdict_test.py"],
srcs_version = "PY2AND3",
deps = [":configdict"],
)
py_library(
name = "config_util",
srcs = ["config_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "config_util_test",
size = "small",
srcs = ["config_util_test.py"],
srcs_version = "PY2AND3",
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "estimator_util",
srcs = ["estimator_util.py"],
......@@ -48,18 +12,3 @@ py_library(
"//astronet/ops:training",
],
)
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
......@@ -11,12 +11,12 @@ py_binary(
deps = [
":astrowavenet_model",
":configurations",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astrowavenet/data:kepler_light_curves",
"//astrowavenet/data:synthetic_transits",
"//astrowavenet/util:estimator_util",
"//tf_util:config_util",
"//tf_util:configdict",
"//tf_util:estimator_runner",
],
)
......@@ -44,6 +44,6 @@ py_test(
deps = [
":astrowavenet_model",
":configurations",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -21,8 +21,8 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import configdict
from astrowavenet import astrowavenet_model
from tf_util import configdict
class AstrowavenetTest(tf.test.TestCase):
......
......@@ -9,7 +9,7 @@ py_library(
],
deps = [
"//astronet/ops:dataset_ops",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -28,7 +28,7 @@ py_library(
],
deps = [
":base",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -40,7 +40,7 @@ py_library(
deps = [
":base",
":synthetic_transit_maker",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......
......@@ -23,7 +23,7 @@ import six
import tensorflow as tf
from astronet.util import configdict
from tf_util import configdict
from astronet.ops import dataset_ops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment