"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "8dcfe9adacb8302f85935f62ef55325b60a62a06"
Commit e00e0e13 authored by dreamdragon's avatar dreamdragon
Browse files

Merge remote-tracking branch 'upstream/master'

parents b915db4e 402b561b
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from astronet.ops import input_ops from astronet.ops import input_ops
from astronet.util import configdict from tf_util import configdict
class InputOpsTest(tf.test.TestCase): class InputOpsTest(tf.test.TestCase):
......
...@@ -27,9 +27,9 @@ import tensorflow as tf ...@@ -27,9 +27,9 @@ import tensorflow as tf
from astronet import models from astronet import models
from astronet.data import preprocess from astronet.data import preprocess
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_util from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -102,8 +102,9 @@ def _process_tce(feature_config): ...@@ -102,8 +102,9 @@ def _process_tce(feature_config):
"Only 'global_view' and 'local_view' features are supported.") "Only 'global_view' and 'local_view' features are supported.")
# Read and process the light curve. # Read and process the light curve.
time, flux = preprocess.read_and_process_light_curve(FLAGS.kepler_id, all_time, all_flux = preprocess.read_light_curve(FLAGS.kepler_id,
FLAGS.kepler_data_dir) FLAGS.kepler_data_dir)
time, flux = preprocess.process_light_curve(all_time, all_flux)
time, flux = preprocess.phase_fold_and_sort_light_curve( time, flux = preprocess.phase_fold_and_sort_light_curve(
time, flux, FLAGS.period, FLAGS.t0) time, flux, FLAGS.period, FLAGS.t0)
...@@ -158,11 +159,7 @@ def main(_): ...@@ -158,11 +159,7 @@ def main(_):
# Create an input function. # Create an input function.
def input_fn(): def input_fn():
return { return tf.data.Dataset.from_tensors({"time_series_features": features})
"time_series_features":
tf.estimator.inputs.numpy_input_fn(
features, batch_size=1, shuffle=False, queue_capacity=1)()
}
# Generate the predictions. # Generate the predictions.
for predictions in estimator.predict(input_fn): for predictions in estimator.predict(input_fn):
......
...@@ -24,10 +24,10 @@ import sys ...@@ -24,10 +24,10 @@ import sys
import tensorflow as tf import tensorflow as tf
from astronet import models from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -68,7 +68,7 @@ parser.add_argument( ...@@ -68,7 +68,7 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
"--train_steps", "--train_steps",
type=int, type=int,
default=10000, default=625,
help="Total number of steps to train the model for.") help="Total number of steps to train the model for.")
parser.add_argument( parser.add_argument(
......
...@@ -2,42 +2,6 @@ package(default_visibility = ["//visibility:public"]) ...@@ -2,42 +2,6 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
py_library(
name = "configdict",
srcs = ["configdict.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "configdict_test",
size = "small",
srcs = ["configdict_test.py"],
srcs_version = "PY2AND3",
deps = [":configdict"],
)
py_library(
name = "config_util",
srcs = ["config_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "config_util_test",
size = "small",
srcs = ["config_util_test.py"],
srcs_version = "PY2AND3",
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library( py_library(
name = "estimator_util", name = "estimator_util",
srcs = ["estimator_util.py"], srcs = ["estimator_util.py"],
...@@ -48,18 +12,3 @@ py_library( ...@@ -48,18 +12,3 @@ py_library(
"//astronet/ops:training", "//astronet/ops:training",
], ],
) )
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
...@@ -11,12 +11,12 @@ py_binary( ...@@ -11,12 +11,12 @@ py_binary(
deps = [ deps = [
":astrowavenet_model", ":astrowavenet_model",
":configurations", ":configurations",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astrowavenet/data:kepler_light_curves", "//astrowavenet/data:kepler_light_curves",
"//astrowavenet/data:synthetic_transits", "//astrowavenet/data:synthetic_transits",
"//astrowavenet/util:estimator_util", "//astrowavenet/util:estimator_util",
"//tf_util:config_util",
"//tf_util:configdict",
"//tf_util:estimator_runner",
], ],
) )
...@@ -44,6 +44,6 @@ py_test( ...@@ -44,6 +44,6 @@ py_test(
deps = [ deps = [
":astrowavenet_model", ":astrowavenet_model",
":configurations", ":configurations",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -15,13 +15,10 @@ Chris Shallue: [@cshallue](https://github.com/cshallue) ...@@ -15,13 +15,10 @@ Chris Shallue: [@cshallue](https://github.com/cshallue)
## Additional Dependencies ## Additional Dependencies
This package requires TensorFlow 1.12 or greater. As of October 2018, this In addition to the [required packages](../README.md#required-packages) listed in
requires the **TensorFlow nightly build** the top-level README, this package requires:
([instructions](https://www.tensorflow.org/install/pip)).
In addition to the dependencies listed in the top-level README, this package
requires:
* **TensorFlow 1.12 or greater** ([instructions](https://www.tensorflow.org/install/))
* **TensorFlow Probability** ([instructions](https://www.tensorflow.org/probability/install)) * **TensorFlow Probability** ([instructions](https://www.tensorflow.org/probability/install))
* **Six** ([instructions](https://pypi.org/project/six/)) * **Six** ([instructions](https://pypi.org/project/six/))
......
...@@ -21,8 +21,8 @@ from __future__ import print_function ...@@ -21,8 +21,8 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from astronet.util import configdict
from astrowavenet import astrowavenet_model from astrowavenet import astrowavenet_model
from tf_util import configdict
class AstrowavenetTest(tf.test.TestCase): class AstrowavenetTest(tf.test.TestCase):
......
...@@ -9,7 +9,7 @@ py_library( ...@@ -9,7 +9,7 @@ py_library(
], ],
deps = [ deps = [
"//astronet/ops:dataset_ops", "//astronet/ops:dataset_ops",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -28,7 +28,7 @@ py_library( ...@@ -28,7 +28,7 @@ py_library(
], ],
deps = [ deps = [
":base", ":base",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
...@@ -40,7 +40,7 @@ py_library( ...@@ -40,7 +40,7 @@ py_library(
deps = [ deps = [
":base", ":base",
":synthetic_transit_maker", ":synthetic_transit_maker",
"//astronet/util:configdict", "//tf_util:configdict",
], ],
) )
......
...@@ -23,7 +23,7 @@ import six ...@@ -23,7 +23,7 @@ import six
import tensorflow as tf import tensorflow as tf
from astronet.util import configdict from tf_util import configdict
from astronet.ops import dataset_ops from astronet.ops import dataset_ops
......
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from astronet.util import configdict from tf_util import configdict
from astrowavenet.data import base from astrowavenet.data import base
from astrowavenet.data import synthetic_transit_maker from astrowavenet.data import synthetic_transit_maker
......
...@@ -24,15 +24,14 @@ import os.path ...@@ -24,15 +24,14 @@ import os.path
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astrowavenet import astrowavenet_model from astrowavenet import astrowavenet_model
from astrowavenet import configurations from astrowavenet import configurations
from astrowavenet.data import kepler_light_curves from astrowavenet.data import kepler_light_curves
from astrowavenet.data import synthetic_transits from astrowavenet.data import synthetic_transits
from astrowavenet.util import estimator_util from astrowavenet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
# Light Curve Operations
## Code Author
Chris Shallue: [@cshallue](https://github.com/cshallue)
## Python modules
* `kepler_io`: Functions for reading Kepler data.
* `median_filter`: Utility for smoothing data using a median filter.
* `periodic_event`: Event class, which represents a periodic event in a light curve.
* `util`: Light curve utility functions.
## Fast ops
The [fast_ops](fast_ops/) subdirectory contains optimized C++ light curve
operations. These operations can be compiled for Python using
[CLIF](https://github.com/google/clif). The [fast_ops/python](fast_ops/python/)
directory contains CLIF API description files.
...@@ -12,8 +12,8 @@ See the License for the specific language governing permissions and ...@@ -12,8 +12,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_H_ #ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_H_ #define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_H_
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
...@@ -70,4 +70,4 @@ typename std::iterator_traits<ForwardIterator>::value_type Median( ...@@ -70,4 +70,4 @@ typename std::iterator_traits<ForwardIterator>::value_type Median(
} // namespace astronet } // namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_H_ #endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_H_
...@@ -12,10 +12,10 @@ See the License for the specific language governing permissions and ...@@ -12,10 +12,10 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "light_curve_util/cc/median_filter.h" #include "light_curve/fast_ops/median_filter.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "light_curve_util/cc/median.h" #include "light_curve/fast_ops/median.h"
using absl::Substitute; using absl::Substitute;
using std::min; using std::min;
......
...@@ -12,8 +12,8 @@ See the License for the specific language governing permissions and ...@@ -12,8 +12,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_FILTER_H_ #ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_FILTER_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_FILTER_H_ #define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_FILTER_H_
#include <iostream> #include <iostream>
...@@ -56,4 +56,4 @@ bool MedianFilter(const std::vector<double>& x, const std::vector<double>& y, ...@@ -56,4 +56,4 @@ bool MedianFilter(const std::vector<double>& x, const std::vector<double>& y,
} // namespace astronet } // namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_FILTER_H_ #endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_MEDIAN_FILTER_H_
...@@ -12,11 +12,11 @@ See the License for the specific language governing permissions and ...@@ -12,11 +12,11 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "light_curve_util/cc/median_filter.h" #include "light_curve/fast_ops/median_filter.h"
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "light_curve_util/cc/test_util.h" #include "light_curve/fast_ops/test_util.h"
using std::vector; using std::vector;
using testing::Pointwise; using testing::Pointwise;
......
...@@ -12,7 +12,7 @@ See the License for the specific language governing permissions and ...@@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "light_curve_util/cc/median.h" #include "light_curve/fast_ops/median.h"
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment