Unverified Commit bf8c050a authored by Chris Shallue's avatar Chris Shallue Committed by GitHub
Browse files

Merge pull request #4614 from cshallue/master

 Modularize light curve and TCE preprocessing functions for easier reuse
parents f3407671 ddee8c74
......@@ -8,10 +8,11 @@ py_binary(
deps = [":preprocess"],
)
py_binary(
py_library(
name = "preprocess",
srcs = ["preprocess.py"],
deps = [
"//astronet/util:example_util",
"//light_curve_util:kepler_io",
"//light_curve_util:median_filter",
"//light_curve_util:util",
......
......@@ -131,25 +131,6 @@ _LABEL_COLUMN = "av_training_set"
_ALLOWED_LABELS = {"PC", "AFP", "NTP"}
def _set_float_feature(ex, name, value):
"""Sets the value of a float feature in a tensorflow.train.Example proto."""
assert name not in ex.features.feature, "Duplicate feature: %s" % name
ex.features.feature[name].float_list.value.extend([float(v) for v in value])
def _set_bytes_feature(ex, name, value):
"""Sets the value of a bytes feature in a tensorflow.train.Example proto."""
assert name not in ex.features.feature, "Duplicate feature: %s" % name
ex.features.feature[name].bytes_list.value.extend([
str(v).encode("latin-1") for v in value])
def _set_int64_feature(ex, name, value):
"""Sets the value of an int64 feature in a tensorflow.train.Example proto."""
assert name not in ex.features.feature, "Duplicate feature: %s" % name
ex.features.feature[name].int64_list.value.extend([int(v) for v in value])
def _process_tce(tce):
"""Processes the light curve for a Kepler TCE and returns an Example proto.
......@@ -158,39 +139,11 @@ def _process_tce(tce):
Returns:
A tensorflow.train.Example proto containing TCE features.
Raises:
IOError: If the light curve files for this Kepler ID cannot be found.
"""
# Read and process the light curve.
time, flux = preprocess.read_and_process_light_curve(tce.kepid,
all_time, all_flux = preprocess.read_light_curve(tce.kepid,
FLAGS.kepler_data_dir)
time, flux = preprocess.phase_fold_and_sort_light_curve(
time, flux, tce.tce_period, tce.tce_time0bk)
# Generate the local and global views.
global_view = preprocess.global_view(time, flux, tce.tce_period)
local_view = preprocess.local_view(time, flux, tce.tce_period,
tce.tce_duration)
# Make output proto.
ex = tf.train.Example()
# Set time series features.
_set_float_feature(ex, "global_view", global_view)
_set_float_feature(ex, "local_view", local_view)
# Set other columns.
for col_name, value in tce.items():
if np.issubdtype(type(value), np.integer):
_set_int64_feature(ex, col_name, [value])
else:
try:
_set_float_feature(ex, col_name, [float(value)])
except ValueError:
_set_bytes_feature(ex, col_name, [value])
return ex
time, flux = preprocess.process_light_curve(all_time, all_flux)
return preprocess.generate_example_for_tce(time, flux, tce)
def _process_file_shard(tce_table, file_name):
......
......@@ -21,29 +21,28 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import example_util
from light_curve_util import kepler_io
from light_curve_util import median_filter
from light_curve_util import util
from third_party.kepler_spline import kepler_spline
def read_and_process_light_curve(kepid, kepler_data_dir, max_gap_width=0.75):
"""Reads a light curve, fits a B-spline and divides the curve by the spline.
def read_light_curve(kepid, kepler_data_dir):
"""Reads a Kepler light curve.
Args:
kepid: Kepler id of the target star.
kepler_data_dir: Base directory containing Kepler data. See
kepler_io.kepler_filenames().
max_gap_width: Gap size (in days) above which the light curve is split for
the fitting of B-splines.
Returns:
time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve.
all_time: A list of numpy arrays; the time values of the raw light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in
all_time.
Raises:
IOError: If the light curve files for this Kepler ID cannot be found.
ValueError: If the spline could not be fit.
"""
# Read the Kepler light curve.
file_names = kepler_io.kepler_filenames(kepler_data_dir, kepid)
......@@ -51,21 +50,26 @@ def read_and_process_light_curve(kepid, kepler_data_dir, max_gap_width=0.75):
raise IOError("Failed to find .fits files in %s for Kepler ID %s" %
(kepler_data_dir, kepid))
all_time, all_flux = kepler_io.read_kepler_light_curve(file_names)
return kepler_io.read_kepler_light_curve(file_names)
# Split on gaps.
all_time, all_flux = util.split(all_time, all_flux, gap_width=max_gap_width)
# Logarithmically sample candidate break point spacings between 0.5 and 20
# days.
bkspaces = np.logspace(np.log10(0.5), np.log10(20), num=20)
def process_light_curve(all_time, all_flux):
"""Removes low-frequency variability from a light curve.
# Generate spline.
spline = kepler_spline.choose_kepler_spline(
all_time, all_flux, bkspaces, penalty_coeff=1.0, verbose=False)[0]
Args:
all_time: A list of numpy arrays; the time values of the raw light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in
all_time.
Returns:
time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve.
"""
# Split on gaps.
all_time, all_flux = util.split(all_time, all_flux, gap_width=0.75)
if spline is None:
raise ValueError("Failed to fit spline with Kepler ID %s", kepid)
# Fit a piecewise-cubic spline with default arguments.
spline = kepler_spline.fit_kepler_spline(all_time, all_flux, verbose=False)[0]
# Concatenate the piecewise light curve and spline.
time = np.concatenate(all_time)
......@@ -77,7 +81,6 @@ def read_and_process_light_curve(kepid, kepler_data_dir, max_gap_width=0.75):
# there. Instead we just remove them.
finite_i = np.isfinite(spline)
if not np.all(finite_i):
tf.logging.warn("Incomplete spline with Kepler ID %s", kepid)
time = time[finite_i]
flux = flux[finite_i]
spline = spline[finite_i]
......@@ -202,3 +205,38 @@ def local_view(time,
bin_width=duration * bin_width_factor,
t_min=max(-period / 2, -duration * num_durations),
t_max=min(period / 2, duration * num_durations))
def generate_example_for_tce(time, flux, tce):
"""Generates a tf.train.Example representing an input TCE.
Args:
time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve.
tce: Dict-like object containing at least 'tce_period', 'tce_duration', and
'tce_time0bk'. Additional items are included as features in the output.
Returns:
A tf.train.Example containing features 'global_view', 'local_view', and all
values present in `tce`.
"""
period = tce["tce_period"]
duration = tce["tce_duration"]
t0 = tce["tce_time0bk"]
time, flux = phase_fold_and_sort_light_curve(time, flux, period, t0)
# Make output proto.
ex = tf.train.Example()
# Set time series features.
example_util.set_float_feature(ex, "global_view",
global_view(time, flux, period))
example_util.set_float_feature(ex, "local_view",
local_view(time, flux, period, duration))
# Set other features in `tce`.
for name, value in tce.items():
example_util.set_feature(ex, name, [value])
return ex
......@@ -42,3 +42,17 @@ py_library(
"//astronet/ops:training",
],
)
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for getting and setting values in tf.Example protocol buffers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def get_feature(ex, name, kind=None, strict=True):
"""Gets a feature value from a tf.train.Example.
Args:
ex: A tf.train.Example.
name: Name of the feature to look up.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
strict: Whether to raise a KeyError if there is no such feature.
Returns:
A numpy array containing to the values of the specified feature.
Raises:
KeyError: If there is no feature with the specified name.
TypeError: If the feature has a different type to that specified.
"""
if name not in ex.features.feature:
if strict:
raise KeyError(name)
return np.array([])
inferred_kind = ex.features.feature[name].WhichOneof("kind")
if not inferred_kind:
return np.array([]) # Feature exists, but it's empty.
if kind and kind != inferred_kind:
raise TypeError("Requested %s, but Feature has %s" % (kind, inferred_kind))
return np.array(getattr(ex.features.feature[name], inferred_kind).value)
def get_bytes_feature(ex, name, strict=True):
"""Gets the value of a bytes feature from a tf.train.Example."""
return get_feature(ex, name, "bytes_list", strict)
def get_float_feature(ex, name, strict=True):
"""Gets the value of a float feature from a tf.train.Example."""
return get_feature(ex, name, "float_list", strict)
def get_int64_feature(ex, name, strict=True):
"""Gets the value of an int64 feature from a tf.train.Example."""
return get_feature(ex, name, "int64_list", strict)
def _infer_kind(value):
"""Infers the tf.train.Feature kind from a value."""
if np.issubdtype(type(value[0]), np.integer):
return "int64_list"
try:
float(value[0])
return "float_list"
except ValueError:
return "bytes_list"
def set_feature(ex, name, value, kind=None, allow_overwrite=False):
"""Sets a feature value in a tf.train.Example.
Args:
ex: A tf.train.Example.
name: Name of the feature to set.
value: Feature value to set. Must be a sequence.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
allow_overwrite: Whether to overwrite the existing value of the feature.
Raises:
ValueError: If `allow_overwrite` is False and the feature already exists, or
if `kind` is unrecognized.
"""
if name in ex.features.feature:
if allow_overwrite:
del ex.features.feature[name]
else:
raise ValueError(
"Attempting to set duplicate feature with name: %s" % name)
if not kind:
kind = _infer_kind(value)
if kind == "bytes_list":
value = [str(v).encode("latin-1") for v in value]
elif kind == "float_list":
value = [float(v) for v in value]
elif kind == "int64_list":
value = [int(v) for v in value]
else:
raise ValueError("Unrecognized kind: %s" % kind)
getattr(ex.features.feature[name], kind).value.extend(value)
def set_float_feature(ex, name, value, allow_overwrite=False):
"""Sets the value of a float feature in a tf.train.Example."""
set_feature(ex, name, value, "float_list", allow_overwrite)
def set_bytes_feature(ex, name, value, allow_overwrite=False):
"""Sets the value of a bytes feature in a tf.train.Example."""
set_feature(ex, name, value, "bytes_list", allow_overwrite)
def set_int64_feature(ex, name, value, allow_overwrite=False):
"""Sets the value of an int64 feature in a tf.train.Example."""
set_feature(ex, name, value, "int64_list", allow_overwrite)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for example_util.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import example_util
class ExampleUtilTest(tf.test.TestCase):
def test_get_feature(self):
# Create Example.
bytes_list = tf.train.BytesList(
value=[v.encode("latin-1") for v in ["a", "b", "c"]])
float_list = tf.train.FloatList(value=[1.0, 2.0, 3.0])
int64_list = tf.train.Int64List(value=[11, 22, 33])
ex = tf.train.Example(
features=tf.train.Features(
feature={
"a_bytes": tf.train.Feature(bytes_list=bytes_list),
"b_float": tf.train.Feature(float_list=float_list),
"c_int64": tf.train.Feature(int64_list=int64_list),
"d_empty": tf.train.Feature(),
}))
# Get bytes feature.
np.testing.assert_array_equal(
example_util.get_feature(ex, "a_bytes").astype(str), ["a", "b", "c"])
np.testing.assert_array_equal(
example_util.get_feature(ex, "a_bytes", "bytes_list").astype(str),
["a", "b", "c"])
np.testing.assert_array_equal(
example_util.get_bytes_feature(ex, "a_bytes").astype(str),
["a", "b", "c"])
with self.assertRaises(TypeError):
example_util.get_feature(ex, "a_bytes", "float_list")
with self.assertRaises(TypeError):
example_util.get_float_feature(ex, "a_bytes")
with self.assertRaises(TypeError):
example_util.get_int64_feature(ex, "a_bytes")
# Get float feature.
np.testing.assert_array_almost_equal(
example_util.get_feature(ex, "b_float"), [1.0, 2.0, 3.0])
np.testing.assert_array_almost_equal(
example_util.get_feature(ex, "b_float", "float_list"), [1.0, 2.0, 3.0])
np.testing.assert_array_almost_equal(
example_util.get_float_feature(ex, "b_float"), [1.0, 2.0, 3.0])
with self.assertRaises(TypeError):
example_util.get_feature(ex, "b_float", "int64_list")
with self.assertRaises(TypeError):
example_util.get_bytes_feature(ex, "b_float")
with self.assertRaises(TypeError):
example_util.get_int64_feature(ex, "b_float")
# Get int64 feature.
np.testing.assert_array_equal(
example_util.get_feature(ex, "c_int64"), [11, 22, 33])
np.testing.assert_array_equal(
example_util.get_feature(ex, "c_int64", "int64_list"), [11, 22, 33])
np.testing.assert_array_equal(
example_util.get_int64_feature(ex, "c_int64"), [11, 22, 33])
with self.assertRaises(TypeError):
example_util.get_feature(ex, "c_int64", "bytes_list")
with self.assertRaises(TypeError):
example_util.get_bytes_feature(ex, "c_int64")
with self.assertRaises(TypeError):
example_util.get_float_feature(ex, "c_int64")
# Get empty feature.
np.testing.assert_array_equal(example_util.get_feature(ex, "d_empty"), [])
np.testing.assert_array_equal(
example_util.get_feature(ex, "d_empty", "float_list"), [])
np.testing.assert_array_equal(
example_util.get_bytes_feature(ex, "d_empty"), [])
np.testing.assert_array_equal(
example_util.get_float_feature(ex, "d_empty"), [])
np.testing.assert_array_equal(
example_util.get_int64_feature(ex, "d_empty"), [])
# Get nonexistent feature.
with self.assertRaises(KeyError):
example_util.get_feature(ex, "nonexistent")
with self.assertRaises(KeyError):
example_util.get_feature(ex, "nonexistent", "bytes_list")
with self.assertRaises(KeyError):
example_util.get_bytes_feature(ex, "nonexistent")
with self.assertRaises(KeyError):
example_util.get_float_feature(ex, "nonexistent")
with self.assertRaises(KeyError):
example_util.get_int64_feature(ex, "nonexistent")
np.testing.assert_array_equal(
example_util.get_feature(ex, "nonexistent", strict=False), [])
np.testing.assert_array_equal(
example_util.get_bytes_feature(ex, "nonexistent", strict=False), [])
np.testing.assert_array_equal(
example_util.get_float_feature(ex, "nonexistent", strict=False), [])
np.testing.assert_array_equal(
example_util.get_int64_feature(ex, "nonexistent", strict=False), [])
def test_set_feature(self):
ex = tf.train.Example()
# Set bytes features.
example_util.set_feature(ex, "a1_bytes", ["a", "b"])
example_util.set_feature(ex, "a2_bytes", ["A", "B"], kind="bytes_list")
example_util.set_bytes_feature(ex, "a3_bytes", ["x", "y"])
np.testing.assert_array_equal(
np.array(ex.features.feature["a1_bytes"].bytes_list.value).astype(str),
["a", "b"])
np.testing.assert_array_equal(
np.array(ex.features.feature["a2_bytes"].bytes_list.value).astype(str),
["A", "B"])
np.testing.assert_array_equal(
np.array(ex.features.feature["a3_bytes"].bytes_list.value).astype(str),
["x", "y"])
with self.assertRaises(ValueError):
example_util.set_feature(ex, "a3_bytes", ["xxx"]) # Duplicate.
# Set float features.
example_util.set_feature(ex, "b1_float", [1.0, 2.0])
example_util.set_feature(ex, "b2_float", [10.0, 20.0], kind="float_list")
example_util.set_float_feature(ex, "b3_float", [88.0, 99.0])
np.testing.assert_array_almost_equal(
ex.features.feature["b1_float"].float_list.value, [1.0, 2.0])
np.testing.assert_array_almost_equal(
ex.features.feature["b2_float"].float_list.value, [10.0, 20.0])
np.testing.assert_array_almost_equal(
ex.features.feature["b3_float"].float_list.value, [88.0, 99.0])
with self.assertRaises(ValueError):
example_util.set_feature(ex, "b3_float", [1234.0]) # Duplicate.
# Set int64 features.
example_util.set_feature(ex, "c1_int64", [1, 2, 3])
example_util.set_feature(ex, "c2_int64", [11, 22, 33], kind="int64_list")
example_util.set_int64_feature(ex, "c3_int64", [88, 99])
np.testing.assert_array_equal(
ex.features.feature["c1_int64"].int64_list.value, [1, 2, 3])
np.testing.assert_array_equal(
ex.features.feature["c2_int64"].int64_list.value, [11, 22, 33])
np.testing.assert_array_equal(
ex.features.feature["c3_int64"].int64_list.value, [88, 99])
with self.assertRaises(ValueError):
example_util.set_feature(ex, "c3_int64", [1234]) # Duplicate.
# Overwrite features.
example_util.set_feature(ex, "a3_bytes", ["xxx"], allow_overwrite=True)
np.testing.assert_array_equal(
np.array(ex.features.feature["a3_bytes"].bytes_list.value).astype(str),
["xxx"])
example_util.set_feature(ex, "b3_float", [1234.0], allow_overwrite=True)
np.testing.assert_array_almost_equal(
ex.features.feature["b3_float"].float_list.value, [1234.0])
example_util.set_feature(ex, "c3_int64", [1234], allow_overwrite=True)
np.testing.assert_array_equal(
ex.features.feature["c3_int64"].int64_list.value, [1234])
if __name__ == "__main__":
tf.test.main()
......@@ -23,6 +23,7 @@ import os.path
from astropy.io import fits
import numpy as np
from tensorflow import gfile
LONG_CADENCE_TIME_DELTA_DAYS = 0.02043422 # Approximately 29.4 minutes.
......@@ -135,7 +136,7 @@ def kepler_filenames(base_dir,
cadence_suffix)
filename = os.path.join(base_dir, base_name)
# Not all stars have data for all quarters.
if not check_existence or os.path.isfile(filename):
if not check_existence or gfile.Exists(filename):
filenames.append(filename)
return filenames
......@@ -160,7 +161,7 @@ def read_kepler_light_curve(filenames,
all_flux = []
for filename in filenames:
with fits.open(open(filename, "rb")) as hdu_list:
with fits.open(gfile.Open(filename, "rb")) as hdu_list:
light_curve = hdu_list[light_curve_extension].data
time = light_curve.TIME
flux = light_curve.PDCSAP_FLUX
......
......@@ -78,7 +78,11 @@ def split(all_time, all_flux, gap_width=0.75):
return out_time, out_flux
def remove_events(all_time, all_flux, events, width_factor=1.0):
def remove_events(all_time,
all_flux,
events,
width_factor=1.0,
include_empty_segments=True):
"""Removes events from a light curve.
This function accepts either a single-segment or piecewise-defined light
......@@ -91,6 +95,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
flux values of the corresponding time array.
events: List of Event objects to remove.
width_factor: Fractional multiplier of the duration of each event to remove.
include_empty_segments: Whether to include empty segments in the output.
Returns:
output_time: Numpy array or list of numpy arrays; the time arrays with
......@@ -118,7 +123,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
if single_segment:
output_time = time[mask]
output_flux = flux[mask]
else:
elif include_empty_segments or np.any(mask):
output_time.append(time[mask])
output_flux.append(flux[mask])
......
......@@ -152,6 +152,30 @@ class LightCurveUtilTest(absltest.TestCase):
self.assertSequenceAlmostEqual([10, 17, 18], output_time[1])
self.assertSequenceAlmostEqual([100, 170, 180], output_flux[1])
# One segment totally removed with include_empty_segments = True.
time = [np.arange(5, dtype=np.float), np.arange(10, 20, dtype=np.float)]
flux = [10 * t for t in time]
events = [periodic_event.Event(period=10, duration=2, t0=2.5)]
output_time, output_flux = util.remove_events(
time, flux, events, width_factor=3, include_empty_segments=True)
self.assertLen(output_time, 2)
self.assertLen(output_flux, 2)
self.assertSequenceEqual([], output_time[0])
self.assertSequenceEqual([], output_flux[0])
self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[1])
self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[1])
# One segment totally removed with include_empty_segments = False.
time = [np.arange(5, dtype=np.float), np.arange(10, 20, dtype=np.float)]
flux = [10 * t for t in time]
events = [periodic_event.Event(period=10, duration=2, t0=2.5)]
output_time, output_flux = util.remove_events(
time, flux, events, width_factor=3, include_empty_segments=False)
self.assertLen(output_time, 1)
self.assertLen(output_flux, 1)
self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[0])
self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[0])
def testInterpolateMaskedSpline(self):
all_time = [
np.arange(0, 10, dtype=np.float),
......
......@@ -159,8 +159,7 @@ def choose_kepler_spline(all_time,
Args:
all_time: List of 1D numpy arrays; the time values of the light curve.
all_flux: List of 1D numpy arrays; the flux (brightness) values of the light
curve.
all_flux: List of 1D numpy arrays; the flux values of the light curve.
bkspaces: List of break-point spacings to try.
maxiter: Maximum number of attempts to fit each spline after removing badly
fit points.
......@@ -187,7 +186,8 @@ def choose_kepler_spline(all_time,
# model and sigma is the constant standard deviation for all flux values.
# Moreover, we assume that s[i] ~= s[i+1]. Therefore,
# (f[i+1] - f[i]) / sqrt(2) ~ N(0, sigma^2).
scaled_diffs = np.concatenate([np.diff(f) / np.sqrt(2) for f in all_flux])
scaled_diffs = [np.diff(f) / np.sqrt(2) for f in all_flux]
scaled_diffs = np.concatenate(scaled_diffs) if scaled_diffs else np.array([])
if not scaled_diffs.size:
best_spline = [np.array([np.nan] * len(f)) for f in all_flux]
metadata.light_curve_mask = [
......@@ -275,3 +275,48 @@ def choose_kepler_spline(all_time,
]
return best_spline, metadata
def fit_kepler_spline(all_time,
all_flux,
bkspace_min=0.5,
bkspace_max=20,
bkspace_num=20,
maxiter=5,
penalty_coeff=1.0,
verbose=True):
"""Fits a Kepler spline with logarithmically-sampled breakpoint spacings.
Args:
all_time: List of 1D numpy arrays; the time values of the light curve.
all_flux: List of 1D numpy arrays; the flux values of the light curve.
bkspace_min: Minimum breakpoint spacing to try.
bkspace_max: Maximum breakpoint spacing to try.
bkspace_num: Number of breakpoint spacings to try.
maxiter: Maximum number of attempts to fit each spline after removing badly
fit points.
penalty_coeff: Coefficient of the penalty term for using more parameters in
the Bayesian Information Criterion. Decreasing this value will allow
more parameters to be used (i.e. smaller break-point spacing), and
vice-versa.
verbose: Whether to log individual spline errors. Note that if bkspaces
contains many values (particularly small ones) then this may cause
logging pollution if calling this function for many light curves.
Returns:
spline: List of numpy arrays; values of the best-fit spline corresponding to
to the input flux arrays.
metadata: Object containing metadata about the spline fit.
"""
# Logarithmically sample bkspace_num candidate break point spacings between
# bkspace_min and bkspace_max.
bkspaces = np.logspace(
np.log10(bkspace_min), np.log10(bkspace_max), num=bkspace_num)
return choose_kepler_spline(
all_time,
all_flux,
bkspaces,
maxiter=maxiter,
penalty_coeff=penalty_coeff,
verbose=verbose)
......@@ -79,6 +79,19 @@ class KeplerSplineTest(absltest.TestCase):
class ChooseKeplerSplineTest(absltest.TestCase):
def testEmptyInput(self):
# Logarithmically sample candidate break point spacings.
bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20)
spline, metadata = kepler_spline.choose_kepler_spline(
all_time=[],
all_flux=[],
bkspaces=bkspaces,
penalty_coeff=1.0,
verbose=False)
np.testing.assert_array_equal(spline, [])
np.testing.assert_array_equal(metadata.light_curve_mask, [])
def testNoPoints(self):
all_time = [np.array([])]
all_flux = [np.array([])]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment