Commit 259b7d1d authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Add the option to interpolate missing time values when reading Kepler light curves.

This applies mainly to scrambled data (NaN time values typically come with NaN flux values, which are removed anyway, but scrambing decouples NaN time values from NaN flux values).

PiperOrigin-RevId: 209029696
parent fa7743b7
...@@ -6,6 +6,7 @@ py_library( ...@@ -6,6 +6,7 @@ py_library(
name = "kepler_io", name = "kepler_io",
srcs = ["kepler_io.py"], srcs = ["kepler_io.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [":util"],
) )
py_test( py_test(
......
...@@ -23,6 +23,7 @@ import os.path ...@@ -23,6 +23,7 @@ import os.path
from astropy.io import fits from astropy.io import fits
import numpy as np import numpy as np
from light_curve_util import util
from tensorflow import gfile from tensorflow import gfile
LONG_CADENCE_TIME_DELTA_DAYS = 0.02043422 # Approximately 29.4 minutes. LONG_CADENCE_TIME_DELTA_DAYS = 0.02043422 # Approximately 29.4 minutes.
...@@ -155,17 +156,18 @@ def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type): ...@@ -155,17 +156,18 @@ def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type):
Args: Args:
all_time: List holding lists of time values (each interior list holds a all_time: List holding lists of time values (each interior list holds a
quarter of time data). quarter of time data).
all_flux: List holding lists of flux values (each interior list holds a all_flux: List holding lists of flux values (each interior list holds a
quarter of flux data). quarter of flux data).
all_quarters: List of integers specifying which quarters were present in all_quarters: List of integers specifying which quarters were present in
the light curve (max is 18: Q0...Q17). the light curve (max is 18: Q0...Q17).
scramble_type: String specifying the scramble order, one of {'SCR1', 'SCR2', scramble_type: String specifying the scramble order, one of {'SCR1', 'SCR2',
'SCR3'}. 'SCR3'}.
Returns: Returns:
scr_flux: scrambled flux values, the same list of lists in another order scr_flux: Scrambled flux values; the same list as the input flux in another
scr_time: time values, re-partitioned to match sizes of the scr_flux lists order.
scr_time: Time values, re-partitioned to match sizes of the scr_flux lists.
""" """
order = SIMULATED_DATA_SCRAMBLE_ORDERS[scramble_type] order = SIMULATED_DATA_SCRAMBLE_ORDERS[scramble_type]
scr_flux = [] scr_flux = []
...@@ -186,7 +188,8 @@ def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type): ...@@ -186,7 +188,8 @@ def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type):
def read_kepler_light_curve(filenames, def read_kepler_light_curve(filenames,
light_curve_extension="LIGHTCURVE", light_curve_extension="LIGHTCURVE",
scramble_type=None): scramble_type=None,
interpolate_missing_time=False):
"""Reads time and flux measurements for a Kepler target star. """Reads time and flux measurements for a Kepler target star.
Args: Args:
...@@ -194,6 +197,10 @@ def read_kepler_light_curve(filenames, ...@@ -194,6 +197,10 @@ def read_kepler_light_curve(filenames,
light_curve_extension: Name of the HDU 1 extension containing light curves. light_curve_extension: Name of the HDU 1 extension containing light curves.
scramble_type: What scrambling procedure to use: 'SCR1', 'SCR2', or 'SCR3' scramble_type: What scrambling procedure to use: 'SCR1', 'SCR2', or 'SCR3'
(pg 9: https://exoplanetarchive.ipac.caltech.edu/docs/KSCI-19114-002.pdf). (pg 9: https://exoplanetarchive.ipac.caltech.edu/docs/KSCI-19114-002.pdf).
interpolate_missing_time: Whether to interpolate missing (NaN) time values.
This should only affect the output if scramble_type is specified (NaN time
values typically come with NaN flux values, which are removed anyway,
but scrambing decouples NaN time values from NaN flux values).
Returns: Returns:
all_time: A list of numpy arrays; the time values of the light curve. all_time: A list of numpy arrays; the time values of the light curve.
...@@ -206,17 +213,22 @@ def read_kepler_light_curve(filenames, ...@@ -206,17 +213,22 @@ def read_kepler_light_curve(filenames,
for filename in filenames: for filename in filenames:
with fits.open(gfile.Open(filename, "rb")) as hdu_list: with fits.open(gfile.Open(filename, "rb")) as hdu_list:
quarter = hdu_list["PRIMARY"].header["QUARTER"]
light_curve = hdu_list[light_curve_extension].data light_curve = hdu_list[light_curve_extension].data
time = light_curve.TIME
flux = light_curve.PDCSAP_FLUX
# Index into primary HDU header and get quarter. time = light_curve.TIME
quarter = hdu_list[0].header["QUARTER"] flux = light_curve.PDCSAP_FLUX
if not time.size:
continue # No data.
if time.size: # Possibly interpolate missing time values.
all_time.append(time) if interpolate_missing_time:
all_flux.append(flux) cadences = light_curve.CADENCENO
all_quarters.append(quarter) time = util.interpolate_missing_time(time, cadences)
all_time.append(time)
all_flux.append(flux)
all_quarters.append(quarter)
if scramble_type: if scramble_type:
all_time, all_flux = scramble_light_curve(all_time, all_flux, all_quarters, all_time, all_flux = scramble_light_curve(all_time, all_flux, all_quarters,
...@@ -225,8 +237,7 @@ def read_kepler_light_curve(filenames, ...@@ -225,8 +237,7 @@ def read_kepler_light_curve(filenames,
# Remove timestamps with NaN time or flux values. # Remove timestamps with NaN time or flux values.
for i, (time, flux) in enumerate(zip(all_time, all_flux)): for i, (time, flux) in enumerate(zip(all_time, all_flux)):
flux_and_time_finite = np.logical_and(np.isfinite(flux), np.isfinite(time)) flux_and_time_finite = np.logical_and(np.isfinite(flux), np.isfinite(time))
valid_indices = np.where(flux_and_time_finite) all_time[i] = time[flux_and_time_finite]
all_time[i] = time[valid_indices] all_flux[i] = flux[flux_and_time_finite]
all_flux[i] = flux[valid_indices]
return all_time, all_flux return all_time, all_flux
...@@ -168,6 +168,27 @@ class KeplerIoTest(absltest.TestCase): ...@@ -168,6 +168,27 @@ class KeplerIoTest(absltest.TestCase):
self.assertTrue(np.isfinite(time).all()) self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all()) self.assertTrue(np.isfinite(flux).all())
def testReadKeplerLightCurveScrambledInterpolateMissingTime(self):
filenames = [
os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
% q for q in ["2009350155506", "2010009091648", "2010174085026"]
]
all_time, all_flux = kepler_io.read_kepler_light_curve(
filenames, scramble_type="SCR1", interpolate_missing_time=True)
self.assertLen(all_time, 3)
self.assertLen(all_flux, 3)
self.assertLen(all_time[0], 4486)
self.assertLen(all_flux[0], 4486)
self.assertLen(all_time[1], 4134)
self.assertLen(all_flux[1], 4134)
self.assertLen(all_time[2], 1008)
self.assertLen(all_flux[2], 1008)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
if __name__ == "__main__": if __name__ == "__main__":
FLAGS.test_srcdir = "" FLAGS.test_srcdir = ""
......
...@@ -19,6 +19,7 @@ from __future__ import division ...@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import scipy.interpolate
from six.moves import range # pylint:disable=redefined-builtin from six.moves import range # pylint:disable=redefined-builtin
...@@ -130,6 +131,46 @@ def remove_events(all_time, ...@@ -130,6 +131,46 @@ def remove_events(all_time,
return output_time, output_flux return output_time, output_flux
def interpolate_missing_time(time, cadences=None, fill_value="extrapolate"):
"""Interpolates missing (NaN or Inf) time values.
Args:
time: A numpy array of monotonically increasing values, with missing values
denoted by NaN or Inf.
cadences: Optional numpy array of cadence indices corresponding to the time
values. If not provided, missing time values are assumed to be evenly
spaced between present time values.
fill_value: Specifies how missing time values should be treated at the
beginning and end of the array. See scipy.interpolate.interp1d.
Returns:
A numpy array of the same length as the input time array, with NaN/Inf
values replaced with interpolated values.
Raises:
ValueError: If fewer than 2 values of time are finite.
"""
if cadences is None:
cadences = np.arange(len(time))
is_finite = np.isfinite(time)
num_finite = np.sum(is_finite)
if num_finite < 2:
raise ValueError(
"Cannot interpolate time with fewer than 2 finite values. Got "
"len(time) = {} with {} finite values.".format(len(time), num_finite))
interpolate_fn = scipy.interpolate.interp1d(
cadences[is_finite],
time[is_finite],
copy=False,
bounds_error=False,
fill_value=fill_value,
assume_sorted=True)
return interpolate_fn(cadences)
def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline): def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
"""Linearly interpolates spline values across masked points. """Linearly interpolates spline values across masked points.
......
...@@ -176,6 +176,61 @@ class LightCurveUtilTest(absltest.TestCase): ...@@ -176,6 +176,61 @@ class LightCurveUtilTest(absltest.TestCase):
self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[0]) self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[0])
self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[0]) self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[0])
def testInterpolateMissingTime(self):
# Fewer than 2 finite values.
with self.assertRaises(ValueError):
util.interpolate_missing_time(np.array([]))
with self.assertRaises(ValueError):
util.interpolate_missing_time(np.array([5.0]))
with self.assertRaises(ValueError):
util.interpolate_missing_time(np.array([5.0, np.nan]))
with self.assertRaises(ValueError):
util.interpolate_missing_time(np.array([np.nan, np.nan, np.nan]))
# Small time arrays.
self.assertSequenceAlmostEqual([0.5, 0.6],
util.interpolate_missing_time(
np.array([0.5, 0.6])))
self.assertSequenceAlmostEqual([0.5, 0.6, 0.7],
util.interpolate_missing_time(
np.array([0.5, np.nan, 0.7])))
# Time array of length 20 with some values NaN.
time = np.array([
np.nan, 0.5, 1.0, 1.5, 2.0, 2.5, np.nan, 3.5, 4.0, 4.5, 5.0, np.nan,
np.nan, np.nan, np.nan, 7.5, 8.0, 8.5, np.nan, np.nan
])
interp_time = util.interpolate_missing_time(time)
self.assertSequenceAlmostEqual([
0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5,
7.0, 7.5, 8.0, 8.5, 9.0, 9.5
], interp_time)
# Fill with 0.0 for missing values at the beginning and end.
interp_time = util.interpolate_missing_time(time, fill_value=0.0)
self.assertSequenceAlmostEqual([
0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5,
7.0, 7.5, 8.0, 8.5, 0.0, 0.0
], interp_time)
# Interpolate with cadences.
cadences = np.array([
100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113,
114, 115, 116, 117, 118, 119
])
interp_time = util.interpolate_missing_time(time, cadences)
self.assertSequenceAlmostEqual([
0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5,
7.0, 7.5, 8.0, 8.5, 9.0, 9.5
], interp_time)
# Interpolate with missing cadences.
time = np.array([0.6, 0.7, np.nan, np.nan, np.nan, 1.3, 1.4, 1.5])
cadences = np.array([106, 107, 108, 109, 110, 113, 114, 115])
interp_time = util.interpolate_missing_time(time, cadences)
self.assertSequenceAlmostEqual([0.6, 0.7, 0.8, 0.9, 1.0, 1.3, 1.4, 1.5],
interp_time)
def testInterpolateMaskedSpline(self): def testInterpolateMaskedSpline(self):
all_time = [ all_time = [
np.arange(0, 10, dtype=np.float), np.arange(0, 10, dtype=np.float),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment