"vscode:/vscode.git/clone" did not exist on "15430ccc2f514fb6c12568614b70740e1bed1bfc"
Commit e00e0e13 authored by dreamdragon's avatar dreamdragon
Browse files

Merge remote-tracking branch 'upstream/master'

parents b915db4e 402b561b
......@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf
from astronet.ops import input_ops
from astronet.util import configdict
from tf_util import configdict
class InputOpsTest(tf.test.TestCase):
......
......@@ -27,9 +27,9 @@ import tensorflow as tf
from astronet import models
from astronet.data import preprocess
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
parser = argparse.ArgumentParser()
......@@ -102,8 +102,9 @@ def _process_tce(feature_config):
"Only 'global_view' and 'local_view' features are supported.")
# Read and process the light curve.
time, flux = preprocess.read_and_process_light_curve(FLAGS.kepler_id,
FLAGS.kepler_data_dir)
all_time, all_flux = preprocess.read_light_curve(FLAGS.kepler_id,
FLAGS.kepler_data_dir)
time, flux = preprocess.process_light_curve(all_time, all_flux)
time, flux = preprocess.phase_fold_and_sort_light_curve(
time, flux, FLAGS.period, FLAGS.t0)
......@@ -158,11 +159,7 @@ def main(_):
# Create an input function.
def input_fn():
return {
"time_series_features":
tf.estimator.inputs.numpy_input_fn(
features, batch_size=1, shuffle=False, queue_capacity=1)()
}
return tf.data.Dataset.from_tensors({"time_series_features": features})
# Generate the predictions.
for predictions in estimator.predict(input_fn):
......
......@@ -24,10 +24,10 @@ import sys
import tensorflow as tf
from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner
parser = argparse.ArgumentParser()
......@@ -68,7 +68,7 @@ parser.add_argument(
parser.add_argument(
"--train_steps",
type=int,
default=10000,
default=625,
help="Total number of steps to train the model for.")
parser.add_argument(
......
......@@ -2,42 +2,6 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "configdict",
srcs = ["configdict.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "configdict_test",
size = "small",
srcs = ["configdict_test.py"],
srcs_version = "PY2AND3",
deps = [":configdict"],
)
py_library(
name = "config_util",
srcs = ["config_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "config_util_test",
size = "small",
srcs = ["config_util_test.py"],
srcs_version = "PY2AND3",
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "estimator_util",
srcs = ["estimator_util.py"],
......@@ -48,18 +12,3 @@ py_library(
"//astronet/ops:training",
],
)
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
......@@ -11,12 +11,12 @@ py_binary(
deps = [
":astrowavenet_model",
":configurations",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astrowavenet/data:kepler_light_curves",
"//astrowavenet/data:synthetic_transits",
"//astrowavenet/util:estimator_util",
"//tf_util:config_util",
"//tf_util:configdict",
"//tf_util:estimator_runner",
],
)
......@@ -44,6 +44,6 @@ py_test(
deps = [
":astrowavenet_model",
":configurations",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -15,13 +15,10 @@ Chris Shallue: [@cshallue](https://github.com/cshallue)
## Additional Dependencies
This package requires TensorFlow 1.12 or greater. As of October 2018, this
requires the **TensorFlow nightly build**
([instructions](https://www.tensorflow.org/install/pip)).
In addition to the dependencies listed in the top-level README, this package
requires:
In addition to the [required packages](../README.md#required-packages) listed in
the top-level README, this package requires:
* **TensorFlow 1.12 or greater** ([instructions](https://www.tensorflow.org/install/))
* **TensorFlow Probability** ([instructions](https://www.tensorflow.org/probability/install))
* **Six** ([instructions](https://pypi.org/project/six/))
......
......@@ -21,8 +21,8 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import configdict
from astrowavenet import astrowavenet_model
from tf_util import configdict
class AstrowavenetTest(tf.test.TestCase):
......
......@@ -9,7 +9,7 @@ py_library(
],
deps = [
"//astronet/ops:dataset_ops",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -28,7 +28,7 @@ py_library(
],
deps = [
":base",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......@@ -40,7 +40,7 @@ py_library(
deps = [
":base",
":synthetic_transit_maker",
"//astronet/util:configdict",
"//tf_util:configdict",
],
)
......
......@@ -23,7 +23,7 @@ import six
import tensorflow as tf
from astronet.util import configdict
from tf_util import configdict
from astronet.ops import dataset_ops
......
......@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import configdict
from tf_util import configdict
from astrowavenet.data import base
from astrowavenet.data import synthetic_transit_maker
......
......@@ -24,15 +24,14 @@ import os.path
from absl import flags
import tensorflow as tf
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astrowavenet import astrowavenet_model
from astrowavenet import configurations
from astrowavenet.data import kepler_light_curves
from astrowavenet.data import synthetic_transits
from astrowavenet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner
FLAGS = flags.FLAGS
......
# Light Curve Operations
## Code Author
Chris Shallue: [@cshallue](https://github.com/cshallue)
## Python modules
* `kepler_io`: Functions for reading Kepler data.
* `median_filter`: Utility for smoothing data using a median filter.
* `periodic_event`: Event class, which represents a periodic event in a light curve.
* `util`: Light curve utility functions.
## Fast ops
The [fast_ops](fast_ops/) subdirectory contains optimized C++ light curve
operations. These operations can be compiled for Python using
[CLIF](https://github.com/google/clif). The [fast_ops/python](fast_ops/python/)
directory contains CLIF API description files.
......@@ -12,8 +12,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_H_
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_H_
#include <algorithm>
#include <iterator>
......@@ -70,4 +70,4 @@ typename std::iterator_traits<ForwardIterator>::value_type Median(
} // namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_H_
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_H_
......@@ -12,10 +12,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/median_filter.h"
#include "light_curve/fast_ops/median_filter.h"
#include "absl/strings/substitute.h"
#include "light_curve_util/cc/median.h"
#include "light_curve/fast_ops/median.h"
using absl::Substitute;
using std::min;
......
......@@ -12,8 +12,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_FILTER_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_FILTER_H_
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_FILTER_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_FILTER_H_
#include <iostream>
......@@ -56,4 +56,4 @@ bool MedianFilter(const std::vector<double>& x, const std::vector<double>& y,
} // namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_FILTER_H_
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_FILTER_H_
......@@ -12,11 +12,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/median_filter.h"
#include "light_curve/fast_ops/median_filter.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "light_curve_util/cc/test_util.h"
#include "light_curve/fast_ops/test_util.h"
using std::vector;
using testing::Pointwise;
......
......@@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/median.h"
#include "light_curve/fast_ops/median.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment