Unverified Commit 5324fc66 authored by Chris Shallue's avatar Chris Shallue Committed by GitHub
Browse files

Merge pull request #5838 from cshallue/master

Reorganize astronet directory structure
parents 03612984 17c2f0cc
...@@ -12,11 +12,11 @@ See the License for the specific language governing permissions and ...@@ -12,11 +12,11 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "light_curve_util/cc/view_generator.h" #include "light_curve/fast_ops/view_generator.h"
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "light_curve_util/cc/test_util.h" #include "light_curve/fast_ops/test_util.h"
using std::vector; using std::vector;
using testing::Pointwise; using testing::Pointwise;
...@@ -49,7 +49,7 @@ TEST(ViewGenerator, GenerateViews) { ...@@ -49,7 +49,7 @@ TEST(ViewGenerator, GenerateViews) {
vector<double> result; vector<double> result;
// Error: t_max <= t_min. We do not test all failure cases here since they // Error: t_max <= t_min. We do not test all failure cases here since they
// are tested in light_curve_util_test.cc. // are covered by the median filter's tests.
EXPECT_FALSE(generator->GenerateView(10, 1, -1, -1, false, &result, &error)); EXPECT_FALSE(generator->GenerateView(10, 1, -1, -1, false, &result, &error));
EXPECT_FALSE(error.empty()); EXPECT_FALSE(error.empty());
error.clear(); error.clear();
......
...@@ -23,7 +23,7 @@ import os.path ...@@ -23,7 +23,7 @@ import os.path
from astropy.io import fits from astropy.io import fits
import numpy as np import numpy as np
from light_curve_util import util from light_curve import util
from tensorflow import gfile from tensorflow import gfile
# Quarter index to filename prefix for long cadence Kepler data. # Quarter index to filename prefix for long cadence Kepler data.
......
...@@ -24,16 +24,17 @@ from absl import flags ...@@ -24,16 +24,17 @@ from absl import flags
from absl.testing import absltest from absl.testing import absltest
import numpy as np import numpy as np
from light_curve_util import kepler_io from light_curve import kepler_io
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
_DATA_DIR = "light_curve_util/test_data/" _DATA_DIR = "light_curve/test_data/"
class KeplerIoTest(absltest.TestCase): class KeplerIoTest(absltest.TestCase):
def setUp(self): def setUp(self):
super(KeplerIoTest, self).setUp()
self.data_dir = os.path.join(FLAGS.test_srcdir, _DATA_DIR) self.data_dir = os.path.join(FLAGS.test_srcdir, _DATA_DIR)
def testScrambleLightCurve(self): def testScrambleLightCurve(self):
...@@ -49,8 +50,8 @@ class KeplerIoTest(absltest.TestCase): ...@@ -49,8 +50,8 @@ class KeplerIoTest(absltest.TestCase):
gold_flux = [[41, 42], [np.nan, np.nan, 33], [11, 12], [21]] gold_flux = [[41, 42], [np.nan, np.nan, 33], [11, 12], [21]]
gold_time = [[101, 102], [201, 301, 302], [303, 401], [402]] gold_time = [[101, 102], [201, 301, 302], [303, 401], [402]]
self.assertEqual(len(gold_flux), len(scr_flux)) self.assertLen(gold_flux, len(scr_flux))
self.assertEqual(len(gold_time), len(scr_time)) self.assertLen(gold_time, len(scr_time))
for i in range(len(gold_flux)): for i in range(len(gold_flux)):
np.testing.assert_array_equal(gold_flux[i], scr_flux[i]) np.testing.assert_array_equal(gold_flux[i], scr_flux[i])
...@@ -60,7 +61,7 @@ class KeplerIoTest(absltest.TestCase): ...@@ -60,7 +61,7 @@ class KeplerIoTest(absltest.TestCase):
# All quarters. # All quarters.
filenames = kepler_io.kepler_filenames( filenames = kepler_io.kepler_filenames(
"/my/dir/", 1234567, check_existence=False) "/my/dir/", 1234567, check_existence=False)
self.assertItemsEqual([ self.assertCountEqual([
"/my/dir/0012/001234567/kplr001234567-2009131105131_llc.fits", "/my/dir/0012/001234567/kplr001234567-2009131105131_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2009166043257_llc.fits", "/my/dir/0012/001234567/kplr001234567-2009166043257_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2009259160929_llc.fits", "/my/dir/0012/001234567/kplr001234567-2009259160929_llc.fits",
...@@ -85,7 +86,7 @@ class KeplerIoTest(absltest.TestCase): ...@@ -85,7 +86,7 @@ class KeplerIoTest(absltest.TestCase):
# Subset of quarters. # Subset of quarters.
filenames = kepler_io.kepler_filenames( filenames = kepler_io.kepler_filenames(
"/my/dir/", 1234567, quarters=[3, 4], check_existence=False) "/my/dir/", 1234567, quarters=[3, 4], check_existence=False)
self.assertItemsEqual([ self.assertCountEqual([
"/my/dir/0012/001234567/kplr001234567-2009350155506_llc.fits", "/my/dir/0012/001234567/kplr001234567-2009350155506_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010078095331_llc.fits", "/my/dir/0012/001234567/kplr001234567-2010078095331_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010009091648_llc.fits" "/my/dir/0012/001234567/kplr001234567-2010009091648_llc.fits"
...@@ -99,7 +100,7 @@ class KeplerIoTest(absltest.TestCase): ...@@ -99,7 +100,7 @@ class KeplerIoTest(absltest.TestCase):
injected_group="inj1", injected_group="inj1",
check_existence=False) check_existence=False)
# pylint:disable=line-too-long # pylint:disable=line-too-long
self.assertItemsEqual([ self.assertCountEqual([
"/my/dir/0012/001234567/kplr001234567-2009350155506_INJECTED-inj1_llc.fits", "/my/dir/0012/001234567/kplr001234567-2009350155506_INJECTED-inj1_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010078095331_INJECTED-inj1_llc.fits", "/my/dir/0012/001234567/kplr001234567-2010078095331_INJECTED-inj1_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010009091648_INJECTED-inj1_llc.fits" "/my/dir/0012/001234567/kplr001234567-2010009091648_INJECTED-inj1_llc.fits"
...@@ -113,7 +114,7 @@ class KeplerIoTest(absltest.TestCase): ...@@ -113,7 +114,7 @@ class KeplerIoTest(absltest.TestCase):
long_cadence=False, long_cadence=False,
quarters=[0, 1], quarters=[0, 1],
check_existence=False) check_existence=False)
self.assertItemsEqual([ self.assertCountEqual([
"/my/dir/0012/001234567/kplr001234567-2009131110544_slc.fits", "/my/dir/0012/001234567/kplr001234567-2009131110544_slc.fits",
"/my/dir/0012/001234567/kplr001234567-2009166044711_slc.fits" "/my/dir/0012/001234567/kplr001234567-2009166044711_slc.fits"
], filenames) ], filenames)
...@@ -126,7 +127,7 @@ class KeplerIoTest(absltest.TestCase): ...@@ -126,7 +127,7 @@ class KeplerIoTest(absltest.TestCase):
"0114/011442793/kplr011442793-{}_llc.fits".format(q)) "0114/011442793/kplr011442793-{}_llc.fits".format(q))
for q in ["2009350155506", "2010009091648", "2010174085026"] for q in ["2009350155506", "2010009091648", "2010174085026"]
] ]
self.assertItemsEqual(expected_filenames, filenames) self.assertCountEqual(expected_filenames, filenames)
def testReadKeplerLightCurve(self): def testReadKeplerLightCurve(self):
filenames = [ filenames = [
......
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
from absl.testing import absltest from absl.testing import absltest
import numpy as np import numpy as np
from light_curve_util import median_filter from light_curve import median_filter
class MedianFilterTest(absltest.TestCase): class MedianFilterTest(absltest.TestCase):
......
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
from absl.testing import absltest from absl.testing import absltest
from light_curve_util.periodic_event import Event from light_curve.periodic_event import Event
class EventTest(absltest.TestCase): class EventTest(absltest.TestCase):
......
...@@ -21,9 +21,9 @@ from __future__ import print_function ...@@ -21,9 +21,9 @@ from __future__ import print_function
from absl.testing import absltest from absl.testing import absltest
import numpy as np import numpy as np
from light_curve_util import periodic_event from light_curve import periodic_event
from light_curve_util import util from light_curve import util
class LightCurveUtilTest(absltest.TestCase): class LightCurveUtilTest(absltest.TestCase):
...@@ -89,13 +89,13 @@ class LightCurveUtilTest(absltest.TestCase): ...@@ -89,13 +89,13 @@ class LightCurveUtilTest(absltest.TestCase):
] ]
all_flux = [np.ones(25), np.ones(10)] all_flux = [np.ones(25), np.ones(10)]
self.assertEqual(len(all_time), 2) self.assertLen(all_time, 2)
self.assertEqual(len(all_time[0]), 25) self.assertLen(all_time[0], 25)
self.assertEqual(len(all_time[1]), 10) self.assertLen(all_time[1], 10)
self.assertEqual(len(all_flux), 2) self.assertLen(all_flux, 2)
self.assertEqual(len(all_flux[0]), 25) self.assertLen(all_flux[0], 25)
self.assertEqual(len(all_flux[1]), 10) self.assertLen(all_flux[1], 10)
# Gap width 0.5. # Gap width 0.5.
split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5) split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
...@@ -268,7 +268,7 @@ class LightCurveUtilTest(absltest.TestCase): ...@@ -268,7 +268,7 @@ class LightCurveUtilTest(absltest.TestCase):
np.array([80, 90]), np.array([80, 90]),
] ]
reshard_xs = util.reshard_arrays(xs, ys) reshard_xs = util.reshard_arrays(xs, ys)
self.assertEqual(5, len(reshard_xs)) self.assertLen(reshard_xs, 5)
np.testing.assert_array_equal([], reshard_xs[0]) np.testing.assert_array_equal([], reshard_xs[0])
np.testing.assert_array_equal([1, 2], reshard_xs[1]) np.testing.assert_array_equal([1, 2], reshard_xs[1])
np.testing.assert_array_equal([3, 4, 5, 6], reshard_xs[2]) np.testing.assert_array_equal([3, 4, 5, 6], reshard_xs[2])
......
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "configdict",
srcs = ["configdict.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "configdict_test",
size = "small",
srcs = ["configdict_test.py"],
srcs_version = "PY2AND3",
deps = [":configdict"],
)
py_library(
name = "config_util",
srcs = ["config_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "config_util_test",
size = "small",
srcs = ["config_util_test.py"],
srcs_version = "PY2AND3",
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from astronet.util import config_util from tf_util import config_util
class ConfigUtilTest(tf.test.TestCase): class ConfigUtilTest(tf.test.TestCase):
......
...@@ -20,7 +20,7 @@ from __future__ import print_function ...@@ -20,7 +20,7 @@ from __future__ import print_function
from absl.testing import absltest from absl.testing import absltest
from astronet.util import configdict from tf_util import configdict
class ConfigDictTest(absltest.TestCase): class ConfigDictTest(absltest.TestCase):
......
...@@ -45,6 +45,8 @@ def evaluate(estimator, eval_args): ...@@ -45,6 +45,8 @@ def evaluate(estimator, eval_args):
latest_checkpoint = estimator.latest_checkpoint() latest_checkpoint = estimator.latest_checkpoint()
if not latest_checkpoint: if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint. # This is expected if the training job has not yet saved a checkpoint.
tf.logging.info("No checkpoint in %s, skipping evaluation.",
estimator.model_dir)
return global_step, values return global_step, values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint) tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
...@@ -54,7 +56,7 @@ def evaluate(estimator, eval_args): ...@@ -54,7 +56,7 @@ def evaluate(estimator, eval_args):
input_fn, steps=eval_steps, name=eval_name) input_fn, steps=eval_steps, name=eval_name)
if global_step is None: if global_step is None:
global_step = values[eval_name].get("global_step") global_step = values[eval_name].get("global_step")
except (tf.errors.NotFoundError, ValueError): except tf.errors.NotFoundError:
# Expected under some conditions, e.g. checkpoint is already deleted by the # Expected under some conditions, e.g. checkpoint is already deleted by the
# trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this # trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
# in some cases. # in some cases.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment