Unverified Commit 69b01644 authored by Chris Shallue, committed by GitHub

Merge pull request #5546 from cshallue/master

Improvements to AstroNet and add AstroWaveNet
parents 91b2debd 763663de
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Synthetic transit inputs to the AstroWaveNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import configdict
from astrowavenet.data import base
from astrowavenet.data import synthetic_transit_maker
def _prepare_wavenet_inputs(light_curve, mask):
"""Gathers synthetic transits into the format expected by AstroWaveNet."""
return {
"autoregressive_input": tf.expand_dims(light_curve, -1),
"conditioning_stack": tf.expand_dims(mask, -1),
}
class SyntheticTransits(base.DatasetBuilder):
"""Synthetic transit inputs to the AstroWaveNet model."""
@staticmethod
def default_config():
return configdict.ConfigDict({
"period_range": (0.5, 4),
"amplitude_range": (1, 1),
"threshold_ratio_range": (0, 0.99),
"phase_range": (0, 1),
"noise_sd_range": (0.1, 0.1),
"mask_probability": 0.1,
"light_curve_time_range": (0, 100),
"light_curve_num_points": 1000
})
def build(self, batch_size):
transit_maker = synthetic_transit_maker.SyntheticTransitMaker(
period_range=self.config.period_range,
amplitude_range=self.config.amplitude_range,
threshold_ratio_range=self.config.threshold_ratio_range,
phase_range=self.config.phase_range,
noise_sd_range=self.config.noise_sd_range)
t_start, t_end = self.config.light_curve_time_range
time = np.linspace(t_start, t_end, self.config.light_curve_num_points)
dataset = tf.data.Dataset.from_generator(
transit_maker.random_light_curve_generator(
time, mask_prob=self.config.mask_probability),
output_types=(tf.float32, tf.float32),
output_shapes=(tf.TensorShape((self.config.light_curve_num_points,)),
tf.TensorShape((self.config.light_curve_num_points,))))
dataset = dataset.map(_prepare_wavenet_inputs)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(-1)
return dataset
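As a quick illustration of how this builder might be used (a sketch, not part of the commit; it assumes the base.DatasetBuilder constructor accepts None for config_overrides, as the training script's call site suggests):

builder = SyntheticTransits(None)  # Fall back to the default_config() values.
dataset = builder.build(batch_size=8)
features = dataset.make_one_shot_iterator().get_next()
# Each feature tensor has shape [8, 1000, 1], i.e.
# [batch_size, light_curve_num_points, channel].
print(features["autoregressive_input"].shape)
print(features["conditioning_stack"].shape)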
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training and evaluating AstroWaveNet models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os.path
from absl import flags
import tensorflow as tf
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astrowavenet import astrowavenet_model
from astrowavenet import configurations
from astrowavenet.data import kepler_light_curves
from astrowavenet.data import synthetic_transits
from astrowavenet.util import estimator_util
FLAGS = flags.FLAGS
flags.DEFINE_enum("dataset", None,
["synthetic_transits", "kepler_light_curves"],
"Dataset for training and/or evaluation.")
flags.DEFINE_string("model_dir", None, "Base output directory.")
flags.DEFINE_string(
"train_files", None,
"Comma-separated list of file patterns matching the TFRecord files in the "
"training dataset.")
flags.DEFINE_string(
"eval_files", None,
"Comma-separated list of file patterns matching the TFRecord files in the "
"evaluation dataset.")
flags.DEFINE_string("config_name", "base",
"Name of the AstroWaveNet configuration.")
flags.DEFINE_string(
"config_overrides", "{}",
"JSON string or JSON file containing overrides to the base configuration.")
flags.DEFINE_enum("schedule", None,
["train", "train_and_eval", "continuous_eval"],
"Schedule for running the model.")
flags.DEFINE_string("eval_name", "val", "Name of the evaluation task.")
flags.DEFINE_integer("train_steps", None, "Total number of steps for training.")
flags.DEFINE_integer("eval_steps", None, "Number of steps for each evaluation.")
flags.DEFINE_integer(
"local_eval_frequency", 1000,
"The number of training steps in between evaluation runs. Only applies "
"when schedule == 'train_and_eval'.")
flags.DEFINE_integer("save_summary_steps", None,
"The frequency at which to save model summaries.")
flags.DEFINE_integer("save_checkpoints_steps", None,
"The frequency at which to save model checkpoints.")
flags.DEFINE_integer("save_checkpoints_secs", None,
"The frequency at which to save model checkpoints.")
flags.DEFINE_integer("keep_checkpoint_max", 1,
"The maximum number of model checkpoints to keep.")
# ------------------------------------------------------------------------------
# TPU-only flags
# ------------------------------------------------------------------------------
flags.DEFINE_boolean("use_tpu", False, "Whether to execute on TPU.")
flags.DEFINE_string("master", None, "Address of the TensorFlow TPU master.")
flags.DEFINE_integer("tpu_num_shards", 8, "Number of TPU shards.")
flags.DEFINE_integer("tpu_iterations_per_loop", 1000,
"Number of iterations per TPU training loop.")
flags.DEFINE_integer(
"eval_batch_size", None,
"Batch size for TPU evaluation. Defaults to the training batch size.")
def _create_run_config():
"""Creates a TPU RunConfig if FLAGS.use_tpu is True, else a RunConfig."""
session_config = tf.ConfigProto(allow_soft_placement=True)
run_config_kwargs = {
"save_summary_steps": FLAGS.save_summary_steps,
"save_checkpoints_steps": FLAGS.save_checkpoints_steps,
"save_checkpoints_secs": FLAGS.save_checkpoints_secs,
"session_config": session_config,
"keep_checkpoint_max": FLAGS.keep_checkpoint_max
}
if FLAGS.use_tpu:
if not FLAGS.master:
raise ValueError("FLAGS.master must be set for TPUEstimator.")
tpu_config = tf.contrib.tpu.TPUConfig(
iterations_per_loop=FLAGS.tpu_iterations_per_loop,
num_shards=FLAGS.tpu_num_shards,
per_host_input_for_training=(FLAGS.tpu_num_shards <= 8))
run_config = tf.contrib.tpu.RunConfig(
tpu_config=tpu_config, master=FLAGS.master, **run_config_kwargs)
else:
if FLAGS.master:
raise ValueError("FLAGS.master should only be set for TPUEstimator.")
run_config = tf.estimator.RunConfig(**run_config_kwargs)
return run_config
def _get_file_pattern(mode):
"""Gets the value of the file pattern flag for the specified mode."""
flag_name = ("train_files"
if mode == tf.estimator.ModeKeys.TRAIN else "eval_files")
file_pattern = FLAGS[flag_name].value
if file_pattern is None:
raise ValueError("--{} is required for mode '{}'".format(flag_name, mode))
return file_pattern
def _create_dataset_builder(mode, config_overrides=None):
"""Creates a dataset builder for the input pipeline."""
if FLAGS.dataset == "synthetic_transits":
return synthetic_transits.SyntheticTransits(config_overrides)
file_pattern = _get_file_pattern(mode)
if FLAGS.dataset == "kepler_light_curves":
builder_class = kepler_light_curves.KeplerLightCurves
else:
raise ValueError("Unsupported dataset: {}".format(FLAGS.dataset))
return builder_class(
file_pattern,
mode,
config_overrides=config_overrides,
use_tpu=FLAGS.use_tpu)
def _create_input_fn(mode, config_overrides=None):
"""Creates an Estimator input_fn."""
builder = _create_dataset_builder(mode, config_overrides)
tf.logging.info("Dataset config for mode '%s': %s", mode,
config_util.to_json(builder.config))
return estimator_util.create_input_fn(builder)
def _create_eval_args(config_overrides=None):
"""Builds eval_args for estimator_runner.evaluate()."""
if FLAGS.dataset == "synthetic_transits" and not FLAGS.eval_steps:
raise ValueError("Dataset '{}' requires --eval_steps for evaluation".format(
FLAGS.dataset))
input_fn = _create_input_fn(tf.estimator.ModeKeys.EVAL, config_overrides)
return {FLAGS.eval_name: (input_fn, FLAGS.eval_steps)}
def main(argv):
del argv # Unused.
config = configdict.ConfigDict(configurations.get_config(FLAGS.config_name))
config_overrides = json.loads(FLAGS.config_overrides)
for key in config_overrides:
if key not in ["dataset", "hparams"]:
raise ValueError("Unrecognized config override: {}".format(key))
config.hparams.update(config_overrides.get("hparams", {}))
# Log configs.
configs_json = [
("config_overrides", config_util.to_json(config_overrides)),
("config", config_util.to_json(config)),
]
for config_name, config_json in configs_json:
tf.logging.info("%s: %s", config_name, config_json)
# Create the estimator.
run_config = _create_run_config()
estimator = estimator_util.create_estimator(
astrowavenet_model.AstroWaveNet, config.hparams, run_config,
FLAGS.model_dir, FLAGS.eval_batch_size)
if FLAGS.schedule in ["train", "train_and_eval"]:
# Save configs.
tf.gfile.MakeDirs(FLAGS.model_dir)
for config_name, config_json in configs_json:
filename = os.path.join(FLAGS.model_dir, "{}.json".format(config_name))
with tf.gfile.Open(filename, "w") as f:
f.write(config_json)
train_input_fn = _create_input_fn(tf.estimator.ModeKeys.TRAIN,
config_overrides.get("dataset"))
train_hooks = []
if FLAGS.schedule == "train":
estimator.train(
train_input_fn, hooks=train_hooks, max_steps=FLAGS.train_steps)
else:
assert FLAGS.schedule == "train_and_eval"
eval_args = _create_eval_args(config_overrides.get("dataset"))
for _ in estimator_runner.continuous_train_and_eval(
estimator=estimator,
train_input_fn=train_input_fn,
eval_args=eval_args,
local_eval_frequency=FLAGS.local_eval_frequency,
train_hooks=train_hooks,
train_steps=FLAGS.train_steps):
# continuous_train_and_eval() yields evaluation metrics after each
# FLAGS.local_eval_frequency. It also saves and logs them, so we don't
# do anything here.
pass
else:
assert FLAGS.schedule == "continuous_eval"
eval_args = _create_eval_args(config_overrides.get("dataset"))
for _ in estimator_runner.continuous_eval(
estimator=estimator, eval_args=eval_args,
train_steps=FLAGS.train_steps):
# continuous_eval() yields evaluation metrics after each new
# checkpoint. It also saves and logs them, so we don't do anything here.
pass
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
flags.mark_flags_as_required(["dataset", "model_dir", "schedule"])
def _validate_schedule(flag_values):
"""Validates the --schedule flag and the flags it interacts with."""
schedule = flag_values["schedule"]
save_checkpoints_steps = flag_values["save_checkpoints_steps"]
save_checkpoints_secs = flag_values["save_checkpoints_secs"]
if schedule in ["train", "train_and_eval"]:
if not (save_checkpoints_steps or save_checkpoints_secs):
raise flags.ValidationError(
"--schedule='%s' requires --save_checkpoints_steps or "
"--save_checkpoints_secs." % schedule)
return True
flags.register_multi_flags_validator(
["schedule", "save_checkpoints_steps", "save_checkpoints_secs"],
_validate_schedule)
tf.app.run()
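For reference, a run of this script might pass flags such as --dataset=synthetic_transits --model_dir=/tmp/astrowavenet --schedule=train --train_steps=10000 --save_checkpoints_steps=1000 (a hypothetical invocation; --dataset, --model_dir, and --schedule are marked required above, and the schedule validator additionally requires one of the checkpoint-saving flags for the training schedules).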
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "estimator_util",
srcs = ["estimator_util.py"],
srcs_version = "PY2AND3",
deps = ["//astronet/ops:training"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for creating a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
from astronet.ops import training
class _InputFn(object):
"""Class that acts as a callable input function for Estimator train / eval."""
def __init__(self, dataset_builder):
"""Initializes the input function.
Args:
dataset_builder: Instance of DatasetBuilder.
"""
self._builder = dataset_builder
def __call__(self, params):
"""Builds the input pipeline."""
return self._builder.build(batch_size=params["batch_size"])
def create_input_fn(dataset_builder):
"""Creates an input_fn that that builds an input pipeline.
Args:
dataset_builder: Instance of DatasetBuilder.
Returns:
A callable that builds an input pipeline and returns a tf.data.Dataset
object.
"""
return _InputFn(dataset_builder)
class _ModelFn(object):
"""Class that acts as a callable model function for Estimator train / eval."""
def __init__(self, model_class, hparams, use_tpu=False):
"""Initializes the model function.
Args:
model_class: Model class.
hparams: A HParams object containing hyperparameters for building and
training the model.
use_tpu: If True, a TPUEstimator will be returned. Otherwise an Estimator
will be returned.
"""
self._model_class = model_class
self._base_hparams = hparams
self._use_tpu = use_tpu
def __call__(self, features, mode, params):
"""Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
hparams = copy.deepcopy(self._base_hparams)
if "batch_size" in params:
hparams.batch_size = params["batch_size"]
model = self._model_class(features, hparams, mode)
model.build()
# Possibly create train_op.
use_tpu = self._use_tpu
train_op = None
if mode == tf.estimator.ModeKeys.TRAIN:
learning_rate = training.create_learning_rate(hparams, model.global_step)
optimizer = training.create_optimizer(hparams, learning_rate, use_tpu)
train_op = training.create_train_op(model, optimizer)
if use_tpu:
estimator = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode, loss=model.total_loss, train_op=train_op)
else:
estimator = tf.estimator.EstimatorSpec(
mode=mode, loss=model.total_loss, train_op=train_op)
return estimator
def create_model_fn(model_class, hparams, use_tpu=False):
"""Wraps model_class as an Estimator or TPUEstimator model_fn.
Args:
model_class: AstroModel or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
Estimator model_fn is returned.
Returns:
model_fn: A callable that constructs the model and returns a
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
"""
return _ModelFn(model_class, hparams, use_tpu)
def create_estimator(model_class,
hparams,
run_config=None,
model_dir=None,
eval_batch_size=None):
"""Wraps model_class as an Estimator or TPUEstimator.
If run_config is None or a tf.estimator.RunConfig, an Estimator is returned.
If run_config is a tf.contrib.tpu.RunConfig, a TPUEstimator is returned.
Args:
model_class: AstroWaveNet or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
model_dir: Optional directory for saving the model. If not passed
explicitly, it must be specified in run_config.
eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
Returns:
An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
TPUEstimator object if run_config is a tf.contrib.tpu.RunConfig.
Raises:
ValueError:
If model_dir is not passed explicitly or in run_config.model_dir, or if
eval_batch_size is specified and run_config is not a
tf.contrib.tpu.RunConfig.
"""
if run_config is None:
run_config = tf.estimator.RunConfig()
else:
run_config = copy.deepcopy(run_config)
if not model_dir and not run_config.model_dir:
raise ValueError(
"model_dir must be passed explicitly or specified in run_config")
use_tpu = isinstance(run_config, tf.contrib.tpu.RunConfig)
model_fn = create_model_fn(model_class, hparams, use_tpu)
if use_tpu:
eval_batch_size = eval_batch_size or hparams.batch_size
estimator = tf.contrib.tpu.TPUEstimator(
model_fn=model_fn,
model_dir=model_dir,
config=run_config,
train_batch_size=hparams.batch_size,
eval_batch_size=eval_batch_size)
else:
if eval_batch_size is not None:
raise ValueError("eval_batch_size can only be specified for TPU.")
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=model_dir,
config=run_config,
params={"batch_size": hparams.batch_size})
return estimator
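To show how these helpers compose, here is a minimal sketch (illustrative only, not part of the commit). It assumes the ConfigDict returned by configurations.get_config("base") carries the hyperparameters that create_estimator and astronet.ops.training expect (for example batch_size and the optimizer settings), and that SyntheticTransits accepts None for config_overrides:

from astronet.util import configdict
from astrowavenet import astrowavenet_model
from astrowavenet import configurations
from astrowavenet.data import synthetic_transits
from astrowavenet.util import estimator_util

config = configdict.ConfigDict(configurations.get_config("base"))
estimator = estimator_util.create_estimator(
    astrowavenet_model.AstroWaveNet, config.hparams,
    model_dir="/tmp/astrowavenet")  # Hypothetical output directory.
input_fn = estimator_util.create_input_fn(
    synthetic_transits.SyntheticTransits(None))
estimator.train(input_fn, max_steps=10)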
@@ -6,6 +6,7 @@ py_library(
name = "kepler_io",
srcs = ["kepler_io.py"],
srcs_version = "PY2AND3",
deps = [":util"],
)
py_test(
......
@@ -44,5 +44,5 @@ class MedianFilterTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected)
-if __name__ == '__main__':
+if __name__ == "__main__":
absltest.main()
@@ -66,5 +66,5 @@ class PhaseFoldAndSortLightCurveTest(absltest.TestCase):
np.testing.assert_almost_equal(folded_flux, expected_flux)
-if __name__ == '__main__':
+if __name__ == "__main__":
absltest.main()
@@ -24,7 +24,8 @@ def ValueErrorOnFalse(ok, *output_args):
"""Raises ValueError if not ok, otherwise returns the output arguments."""
n_outputs = len(output_args)
if n_outputs < 2:
-raise ValueError("Expected 2 or more output_args. Got: %d" % n_outputs)
+raise ValueError(
+"Expected 2 or more output_args. Got: {}".format(n_outputs))
if not ok:
error = output_args[-1]
......
@@ -76,5 +76,5 @@ class ViewGeneratorTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected)
-if __name__ == '__main__':
+if __name__ == "__main__":
absltest.main()
@@ -23,10 +23,9 @@ import os.path
from astropy.io import fits
import numpy as np
from light_curve_util import util
from tensorflow import gfile
LONG_CADENCE_TIME_DELTA_DAYS = 0.02043422 # Approximately 29.4 minutes.
# Quarter index to filename prefix for long cadence Kepler data.
# Reference: https://archive.stsci.edu/kepler/software/get_kepler.py
LONG_CADENCE_QUARTER_PREFIXES = {
@@ -73,6 +72,14 @@ SHORT_CADENCE_QUARTER_PREFIXES = {
17: ["2013121191144", "2013131215648"]
}
# Quarter order for different scrambling procedures.
# Page 9: https://ntrs.nasa.gov/archive/nasa/casi.ntrs.nasa.gov/20170009549.pdf.
SIMULATED_DATA_SCRAMBLE_ORDERS = {
"SCR1": [0, 13, 14, 15, 16, 9, 10, 11, 12, 5, 6, 7, 8, 1, 2, 3, 4, 17],
"SCR2": [0, 1, 2, 3, 4, 13, 14, 15, 16, 9, 10, 11, 12, 5, 6, 7, 8, 17],
"SCR3": [0, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 17],
}
def kepler_filenames(base_dir,
kep_id,
@@ -98,21 +105,21 @@ def kepler_filenames(base_dir,
Args:
base_dir: Base directory containing Kepler data.
kep_id: Id of the Kepler target star. May be an int or a possibly zero-
padded string.
long_cadence: Whether to read a long cadence (~29.4 min / measurement) light
curve as opposed to a short cadence (~1 min / measurement) light curve.
quarters: Optional list of integers in [0, 17]; the quarters of the Kepler
mission to return.
injected_group: Optional string indicating injected light curves. One of
"inj1", "inj2", "inj3".
check_existence: If True, only return filenames corresponding to files that
exist (not all stars have data for all quarters).
Returns:
A list of filenames.
"""
# Pad the Kepler id with zeros to length 9.
kep_id = "%.9d" % int(kep_id)
kep_id = "{:09d}".format(int(kep_id))
quarter_prefixes, cadence_suffix = ((LONG_CADENCE_QUARTER_PREFIXES, "llc")
if long_cadence else
@@ -128,12 +135,11 @@ def kepler_filenames(base_dir,
for quarter in quarters:
for quarter_prefix in quarter_prefixes[quarter]:
if injected_group:
base_name = "kplr%s-%s_INJECTED-%s_%s.fits" % (kep_id, quarter_prefix,
injected_group,
cadence_suffix)
base_name = "kplr{}-{}_INJECTED-{}_{}.fits".format(
kep_id, quarter_prefix, injected_group, cadence_suffix)
else:
base_name = "kplr%s-%s_%s.fits" % (kep_id, quarter_prefix,
cadence_suffix)
base_name = "kplr{}-{}_{}.fits".format(kep_id, quarter_prefix,
cadence_suffix)
filename = os.path.join(base_dir, base_name)
# Not all stars have data for all quarters.
if not check_existence or gfile.Exists(filename):
@@ -142,40 +148,86 @@
return filenames
def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type):
"""Scrambles a light curve according to a given scrambling procedure.
Args:
all_time: List holding arrays of time values, each containing a quarter of
time data.
all_flux: List holding arrays of flux values, each containing a quarter of
flux data.
all_quarters: List of integers specifying which quarters are present in
the light curve (max is 18: Q0...Q17).
scramble_type: String specifying the scramble order, one of {'SCR1', 'SCR2',
'SCR3'}.
Returns:
scr_time: Time values, re-partitioned to match sizes of the scr_flux lists.
scr_flux: Scrambled flux values; the same list as the input flux in another
order.
"""
order = SIMULATED_DATA_SCRAMBLE_ORDERS[scramble_type]
scr_flux = []
for quarter in order:
# Ignore missing quarters in the scramble order.
if quarter in all_quarters:
scr_flux.append(all_flux[all_quarters.index(quarter)])
scr_time = util.reshard_arrays(all_time, scr_flux)
return scr_time, scr_flux
def read_kepler_light_curve(filenames,
light_curve_extension="LIGHTCURVE",
-invert=False):
+scramble_type=None,
+interpolate_missing_time=False):
"""Reads time and flux measurements for a Kepler target star.
Args:
filenames: A list of .fits files containing time and flux measurements.
light_curve_extension: Name of the HDU 1 extension containing light curves.
-invert: Whether to invert the flux measurements by multiplying by -1.
+scramble_type: What scrambling procedure to use: 'SCR1', 'SCR2', or 'SCR3'
+(pg 9: https://exoplanetarchive.ipac.caltech.edu/docs/KSCI-19114-002.pdf).
+interpolate_missing_time: Whether to interpolate missing (NaN) time values.
+This should only affect the output if scramble_type is specified (NaN time
+values typically come with NaN flux values, which are removed anyway, but
+scrambling decouples NaN time values from NaN flux values).
Returns:
all_time: A list of numpy arrays; the time values of the light curve.
-all_flux: A list of numpy arrays corresponding to the time arrays in
-all_time.
+all_flux: A list of numpy arrays; the flux values of the light curve.
"""
all_time = []
all_flux = []
+all_quarters = []
for filename in filenames:
with fits.open(gfile.Open(filename, "rb")) as hdu_list:
+quarter = hdu_list["PRIMARY"].header["QUARTER"]
light_curve = hdu_list[light_curve_extension].data
-time = light_curve.TIME
-flux = light_curve.PDCSAP_FLUX
-# Remove NaN flux values.
-valid_indices = np.where(np.isfinite(flux))
-time = time[valid_indices]
-flux = flux[valid_indices]
+time = light_curve.TIME
+flux = light_curve.PDCSAP_FLUX
+if not time.size:
+continue  # No data.
+# Possibly interpolate missing time values.
+if interpolate_missing_time:
+time = util.interpolate_missing_time(time, light_curve.CADENCENO)
+all_time.append(time)
+all_flux.append(flux)
+all_quarters.append(quarter)
-if invert:
-flux *= -1
+if scramble_type:
+all_time, all_flux = scramble_light_curve(all_time, all_flux, all_quarters,
+scramble_type)
-if time.size:
-all_time.append(time)
-all_flux.append(flux)
+# Remove timestamps with NaN time or flux values.
+for i, (time, flux) in enumerate(zip(all_time, all_flux)):
+flux_and_time_finite = np.logical_and(np.isfinite(flux), np.isfinite(time))
+all_time[i] = time[flux_and_time_finite]
+all_flux[i] = flux[flux_and_time_finite]
return all_time, all_flux
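For orientation, the two functions above might be combined roughly as follows (a sketch, not part of the diff; KEPLER_DATA_DIR is a placeholder for a local directory of Kepler FITS files):

KEPLER_DATA_DIR = "/path/to/kepler/fits"  # Placeholder path.
filenames = kepler_filenames(KEPLER_DATA_DIR, 11442793, long_cadence=True)
all_time, all_flux = read_kepler_light_curve(
    filenames, scramble_type="SCR2", interpolate_missing_time=True)
# Each all_time[i] / all_flux[i] pair holds the finite values of one quarter,
# reordered according to the SCR2 scramble.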
@@ -19,8 +19,10 @@ from __future__ import division
from __future__ import print_function
import os.path
from absl import flags
from absl.testing import absltest
import numpy as np
from light_curve_util import kepler_io
@@ -34,6 +36,26 @@ class KeplerIoTest(absltest.TestCase):
def setUp(self):
self.data_dir = os.path.join(FLAGS.test_srcdir, _DATA_DIR)
def testScrambleLightCurve(self):
all_flux = [[11, 12], [21], [np.nan, np.nan, 33], [41, 42]]
all_time = [[101, 102], [201], [301, 302, 303], [401, 402]]
all_quarters = [3, 4, 7, 14]
scramble_type = "SCR1" # New quarters order will be [14,7,3,4].
scr_time, scr_flux = kepler_io.scramble_light_curve(
all_time, all_flux, all_quarters, scramble_type)
# NaNs are not removed in this function.
gold_flux = [[41, 42], [np.nan, np.nan, 33], [11, 12], [21]]
gold_time = [[101, 102], [201, 301, 302], [303, 401], [402]]
self.assertEqual(len(gold_flux), len(scr_flux))
self.assertEqual(len(gold_time), len(scr_time))
for i in range(len(gold_flux)):
np.testing.assert_array_equal(gold_flux[i], scr_flux[i])
np.testing.assert_array_equal(gold_time[i], scr_time[i])
def testKeplerFilenames(self):
# All quarters.
filenames = kepler_io.kepler_filenames(
@@ -100,15 +122,17 @@ class KeplerIoTest(absltest.TestCase):
filenames = kepler_io.kepler_filenames(
self.data_dir, 11442793, check_existence=True)
expected_filenames = [
os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
% q for q in ["2009350155506", "2010009091648", "2010174085026"]
os.path.join(self.data_dir,
"0114/011442793/kplr011442793-{}_llc.fits".format(q))
for q in ["2009350155506", "2010009091648", "2010174085026"]
]
self.assertItemsEqual(expected_filenames, filenames)
def testReadKeplerLightCurve(self):
filenames = [
os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
% q for q in ["2009350155506", "2010009091648", "2010174085026"]
os.path.join(self.data_dir,
"0114/011442793/kplr011442793-{}_llc.fits".format(q))
for q in ["2009350155506", "2010009091648", "2010174085026"]
]
all_time, all_flux = kepler_io.read_kepler_light_curve(filenames)
self.assertLen(all_time, 3)
@@ -120,6 +144,55 @@ class KeplerIoTest(absltest.TestCase):
self.assertLen(all_time[2], 4486)
self.assertLen(all_flux[2], 4486)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
def testReadKeplerLightCurveScrambled(self):
filenames = [
os.path.join(self.data_dir,
"0114/011442793/kplr011442793-{}_llc.fits".format(q))
for q in ["2009350155506", "2010009091648", "2010174085026"]
]
all_time, all_flux = kepler_io.read_kepler_light_curve(
filenames, scramble_type="SCR1")
self.assertLen(all_time, 3)
self.assertLen(all_flux, 3)
# Arrays are shorter than above due to separation of time and flux NaNs.
self.assertLen(all_time[0], 4344)
self.assertLen(all_flux[0], 4344)
self.assertLen(all_time[1], 4041)
self.assertLen(all_flux[1], 4041)
self.assertLen(all_time[2], 1008)
self.assertLen(all_flux[2], 1008)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
def testReadKeplerLightCurveScrambledInterpolateMissingTime(self):
filenames = [
os.path.join(self.data_dir,
"0114/011442793/kplr011442793-{}_llc.fits".format(q))
for q in ["2009350155506", "2010009091648", "2010174085026"]
]
all_time, all_flux = kepler_io.read_kepler_light_curve(
filenames, scramble_type="SCR1", interpolate_missing_time=True)
self.assertLen(all_time, 3)
self.assertLen(all_flux, 3)
self.assertLen(all_time[0], 4486)
self.assertLen(all_flux[0], 4486)
self.assertLen(all_time[1], 4134)
self.assertLen(all_flux[1], 4134)
self.assertLen(all_time[2], 1008)
self.assertLen(all_flux[2], 1008)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
if __name__ == "__main__":
FLAGS.test_srcdir = ""
......
@@ -32,16 +32,16 @@ def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None):
Args:
x: 1D array of x-coordinates sorted in ascending order. Must have at least 2
elements, and all elements cannot be the same value.
y: 1D array of y-coordinates with the same size as x.
num_bins: The number of intervals to divide the x-axis into. Must be at
least 2.
bin_width: The width of each bin on the x-axis. Must be positive, and less
than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
x_min: The inclusive leftmost value to consider on the x-axis. Must be less
than or equal to the largest value of x. Defaults to min(x).
x_max: The exclusive rightmost value to consider on the x-axis. Must be
greater than x_min. Defaults to max(x).
Returns:
1D NumPy array of size num_bins containing the median y-values of uniformly
@@ -51,35 +51,35 @@
ValueError: If an argument has an inappropriate value.
"""
if num_bins < 2:
raise ValueError("num_bins must be at least 2. Got: %d" % num_bins)
raise ValueError("num_bins must be at least 2. Got: {}".format(num_bins))
# Validate the lengths of x and y.
x_len = len(x)
if x_len < 2:
raise ValueError("len(x) must be at least 2. Got: %s" % x_len)
raise ValueError("len(x) must be at least 2. Got: {}".format(x_len))
if x_len != len(y):
raise ValueError("len(x) (got: %d) must equal len(y) (got: %d)" % (x_len,
len(y)))
raise ValueError("len(x) (got: {}) must equal len(y) (got: {})".format(
x_len, len(y)))
# Validate x_min and x_max.
x_min = x_min if x_min is not None else x[0]
x_max = x_max if x_max is not None else x[-1]
if x_min >= x_max:
raise ValueError("x_min (got: %d) must be less than x_max (got: %d)" %
(x_min, x_max))
raise ValueError("x_min (got: {}) must be less than x_max (got: {})".format(
x_min, x_max))
if x_min > x[-1]:
raise ValueError(
"x_min (got: %d) must be less than or equal to the largest value of x "
"(got: %d)" % (x_min, x[-1]))
"x_min (got: {}) must be less than or equal to the largest value of x "
"(got: {})".format(x_min, x[-1]))
# Validate bin_width.
bin_width = bin_width if bin_width is not None else (x_max - x_min) / num_bins
if bin_width <= 0:
raise ValueError("bin_width must be positive. Got: %d" % bin_width)
raise ValueError("bin_width must be positive. Got: {}".format(bin_width))
if bin_width >= x_max - x_min:
raise ValueError(
"bin_width (got: %d) must be less than x_max - x_min (got: %d)" %
(bin_width, x_max - x_min))
"bin_width (got: {}) must be less than x_max - x_min (got: {})".format(
bin_width, x_max - x_min))
bin_spacing = (x_max - x_min - bin_width) / (num_bins - 1)
......
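As a concrete check of the bin geometry described above (illustrative only), the default bin_width makes the bins exactly contiguous:

x_min, x_max, num_bins = 0.0, 10.0, 5
bin_width = (x_max - x_min) / num_bins                       # 2.0 by default.
bin_spacing = (x_max - x_min - bin_width) / (num_bins - 1)   # Also 2.0.
# Bin i covers [x_min + i * bin_spacing, x_min + i * bin_spacing + bin_width),
# i.e. [0, 2), [2, 4), [4, 6), [6, 8), [8, 10).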
@@ -124,5 +124,5 @@ class MedianFilterTest(absltest.TestCase):
np.testing.assert_array_equal([7, 1, 5, 2, 3], result)
-if __name__ == '__main__':
+if __name__ == "__main__":
absltest.main()
@@ -62,7 +62,7 @@ class Event(object):
other_event: An Event.
period_rtol: Relative tolerance in matching the periods.
t0_durations: Tolerance in matching the t0 values, in units of the other
Event's duration.
Returns:
True if this Event is the same as other_event, within the given tolerance.
......
@@ -17,7 +17,7 @@ def robust_mean(y, cut):
Args:
y: 1D numpy array. Assumed to be normally distributed with outliers.
cut: Points more than this number of standard deviations from the median are
ignored.
Returns:
mean: A robust estimate of the mean of y.
......
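Only the docstring of robust_mean appears in this diff. As a rough sketch of the idea it describes (an assumption about the intent, not the repository's actual implementation), a cut-based robust mean could look like:

import numpy as np

def robust_mean_sketch(y, cut):
  """Illustrative sketch only; not the implementation in this repository."""
  deviation = np.abs(y - np.median(y))
  # Robust scale estimate via the median absolute deviation (MAD); the factor
  # 1.4826 rescales the MAD to approximate the standard deviation for normally
  # distributed data.
  sigma = 1.4826 * np.median(deviation)
  keep = deviation <= cut * sigma
  return np.mean(y[keep])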