Commit f88dd68a authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Add two helper functions to light_curve_util/util.py.

1. reshard_arrays(xs, ys): Reshards arrays in xs to match the lengths of arrays in ys.
2. uniform_cadence_light_curve(): Combines data into a single light curve with uniform cadence numbers.

PiperOrigin-RevId: 211724321
parent 6153a373
......@@ -153,18 +153,18 @@ def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type):
"""Scrambles a light curve according to a given scrambling procedure.
Args:
all_time: List holding lists of time values (each interior list holds a
quarter of time data).
all_flux: List holding lists of flux values (each interior list holds a
quarter of flux data).
all_quarters: List of integers specifying which quarters were present in
the light curve (max is 18: Q0...Q17).
all_time: List holding arrays of time values, each containing a quarter of
time data.
all_flux: List holding arrays of flux values, each containing a quarter of
flux data.
all_quarters: List of integers specifying which quarters are present in
the light curve (max is 18: Q0...Q17).
scramble_type: String specifying the scramble order, one of {'SCR1', 'SCR2',
'SCR3'}.
'SCR3'}.
Returns:
scr_flux: Scrambled flux values; the same list as the input flux in another
order.
order.
scr_time: Time values, re-partitioned to match sizes of the scr_flux lists.
"""
order = SIMULATED_DATA_SCRAMBLE_ORDERS[scramble_type]
......@@ -174,12 +174,7 @@ def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type):
if quarter in all_quarters:
scr_flux.append(all_flux[all_quarters.index(quarter)])
# Reapportion time lists to match sizes of respective flux lists.
concat_time = np.concatenate(all_time)
scr_time = []
for flux in scr_flux:
time, concat_time = np.split(concat_time, [len(flux)])
scr_time.append(time)
scr_time = util.reshard_arrays(all_time, scr_flux)
return scr_time, scr_flux
......@@ -197,13 +192,12 @@ def read_kepler_light_curve(filenames,
(pg 9: https://exoplanetarchive.ipac.caltech.edu/docs/KSCI-19114-002.pdf).
interpolate_missing_time: Whether to interpolate missing (NaN) time values.
This should only affect the output if scramble_type is specified (NaN time
values typically come with NaN flux values, which are removed anyway,
but scrambing decouples NaN time values from NaN flux values).
values typically come with NaN flux values, which are removed anyway, but
scrambing decouples NaN time values from NaN flux values).
Returns:
all_time: A list of numpy arrays; the time values of the light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in
all_time.
all_flux: A list of numpy arrays; the flux values of the light curve.
"""
all_time = []
all_flux = []
......@@ -221,8 +215,7 @@ def read_kepler_light_curve(filenames,
# Possibly interpolate missing time values.
if interpolate_missing_time:
cadences = light_curve.CADENCENO
time = util.interpolate_missing_time(time, cadences)
time = util.interpolate_missing_time(time, light_curve.CADENCENO)
all_time.append(time)
all_flux.append(flux)
......
......@@ -131,15 +131,15 @@ def remove_events(all_time,
return output_time, output_flux
def interpolate_missing_time(time, cadences=None, fill_value="extrapolate"):
def interpolate_missing_time(time, cadence_no=None, fill_value="extrapolate"):
"""Interpolates missing (NaN or Inf) time values.
Args:
time: A numpy array of monotonically increasing values, with missing values
denoted by NaN or Inf.
cadences: Optional numpy array of cadence indices corresponding to the time
values. If not provided, missing time values are assumed to be evenly
spaced between present time values.
cadence_no: Optional numpy array of cadence numbers corresponding to the
time values. If not provided, missing time values are assumed to be
evenly spaced between present time values.
fill_value: Specifies how missing time values should be treated at the
beginning and end of the array. See scipy.interpolate.interp1d.
......@@ -150,8 +150,8 @@ def interpolate_missing_time(time, cadences=None, fill_value="extrapolate"):
Raises:
ValueError: If fewer than 2 values of time are finite.
"""
if cadences is None:
cadences = np.arange(len(time))
if cadence_no is None:
cadence_no = np.arange(len(time))
is_finite = np.isfinite(time)
num_finite = np.sum(is_finite)
......@@ -161,14 +161,14 @@ def interpolate_missing_time(time, cadences=None, fill_value="extrapolate"):
"len(time) = {} with {} finite values.".format(len(time), num_finite))
interpolate_fn = scipy.interpolate.interp1d(
cadences[is_finite],
cadence_no[is_finite],
time[is_finite],
copy=False,
bounds_error=False,
fill_value=fill_value,
assume_sorted=True)
return interpolate_fn(cadences)
return interpolate_fn(cadence_no)
def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
......@@ -195,6 +195,78 @@ def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
return interp_spline
def reshard_arrays(xs, ys):
"""Reshards arrays in xs to match the lengths of arrays in ys.
Args:
xs: List of 1d numpy arrays with the same total length as ys.
ys: List of 1d numpy arrays with the same total length as xs.
Returns:
A list of numpy arrays containing the same elements as xs, in the same
order, but with array lengths matching the pairwise array in ys.
Raises:
ValueError: If xs and ys do not have the same total length.
"""
# Compute indices of boundaries between segments of ys, plus the end boundary.
boundaries = np.cumsum([len(y) for y in ys])
concat_x = np.concatenate(xs)
if len(concat_x) != boundaries[-1]:
raise ValueError(
"xs and ys do not have the same total length ({} vs. {}).".format(
len(concat_x), boundaries[-1]))
boundaries = boundaries[:-1] # Remove exclusive end boundary.
return np.split(concat_x, boundaries)
def uniform_cadence_light_curve(all_cadence_no, all_time, all_flux):
"""Combines data into a single light curve with uniform cadence numbers.
Args:
all_cadence_no: A list of numpy arrays; the cadence numbers of the light
curve.
all_time: A list of numpy arrays; the time values of the light curve.
all_flux: A list of numpy arrays; the flux values of the light curve.
Returns:
cadence_no: numpy array; the cadence numbers of the light curve with no
gaps. It starts and ends at the minimum and maximum cadence numbers in the
input light curve, respectively.
time: numpy array; the time values of the light curve. Missing data points
have value zero and correspond to a False value in the mask.
flux: numpy array; the time values of the light curve. Missing data points
have value zero and correspond to a False value in the mask.
mask: Boolean numpy array; False indicates missing data points, where
missing data points are those that have no corresponding cadence number in
the input or those where at least one of the cadence number, time value,
or flux value is NaN/Inf.
Raises:
ValueError: If there are duplicate cadence numbers in the input.
"""
min_cadence_no = np.min([np.min(c) for c in all_cadence_no])
max_cadence_no = np.max([np.max(c) for c in all_cadence_no])
out_cadence_no = np.arange(
min_cadence_no, max_cadence_no + 1, dtype=all_cadence_no[0].dtype)
out_time = np.zeros_like(out_cadence_no, dtype=all_time[0].dtype)
out_flux = np.zeros_like(out_cadence_no, dtype=all_flux[0].dtype)
out_mask = np.zeros_like(out_cadence_no, dtype=np.bool)
for cadence_no, time, flux in zip(all_cadence_no, all_time, all_flux):
for c, t, f in zip(cadence_no, time, flux):
if np.isfinite(c) and np.isfinite(t) and np.isfinite(f):
i = int(c - min_cadence_no)
if out_mask[i]:
raise ValueError("Duplicate cadence number: {}".format(c))
out_time[i] = t
out_flux[i] = f
out_mask[i] = True
return out_cadence_no, out_time, out_flux, out_mask
def count_transit_points(time, event):
"""Computes the number of points in each transit of a given event.
......
......@@ -253,6 +253,65 @@ class LightCurveUtilTest(absltest.TestCase):
[120, 122, 124, 126, 128, 130, 132, 132, 132, 132], interp_spline[1])
self.assertTrue(np.all(np.isnan(interp_spline[2])))
def testReshardArrays(self):
xs = [
np.array([1, 2, 3]),
np.array([4]),
np.array([5, 6, 7, 8, 9]),
np.array([]),
]
ys = [
np.array([]),
np.array([10, 20]),
np.array([30, 40, 50, 60]),
np.array([70]),
np.array([80, 90]),
]
reshard_xs = util.reshard_arrays(xs, ys)
self.assertEqual(5, len(reshard_xs))
np.testing.assert_array_equal([], reshard_xs[0])
np.testing.assert_array_equal([1, 2], reshard_xs[1])
np.testing.assert_array_equal([3, 4, 5, 6], reshard_xs[2])
np.testing.assert_array_equal([7], reshard_xs[3])
np.testing.assert_array_equal([8, 9], reshard_xs[4])
with self.assertRaisesRegexp(ValueError,
"xs and ys do not have the same total length"):
util.reshard_arrays(xs, [np.array([10, 20, 30]), np.array([40, 50])])
def testUniformCadenceLightCurve(self):
all_cadence_no = [
np.array([13]),
np.array([4, 5, 6]),
np.array([8, 9, 11, 12]),
]
all_time = [
np.array([130]),
np.array([40, 50, 60]),
np.array([80, 90, 110, 120]),
]
all_flux = [
np.array([1300]),
np.array([400, 500, 600]),
np.array([800, np.nan, 1100, 1200]),
]
cadence_no, time, flux, mask = util.uniform_cadence_light_curve(
all_cadence_no, all_time, all_flux)
np.testing.assert_array_equal([4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
cadence_no)
np.testing.assert_array_equal([40, 50, 60, 0, 80, 0, 0, 110, 120, 130],
time)
np.testing.assert_array_equal(
[400, 500, 600, 0, 800, 0, 0, 1100, 1200, 1300], flux)
np.testing.assert_array_equal([1, 1, 1, 0, 1, 0, 0, 1, 1, 1], mask)
# Add duplicate cadence number.
all_cadence_no.append(np.array([13, 14, 15]))
all_time.append(np.array([130, 140, 150]))
all_flux.append(np.array([1300, 1400, 1500]))
with self.assertRaisesRegexp(ValueError, "Duplicate cadence number"):
util.uniform_cadence_light_curve(all_cadence_no, all_time, all_flux)
def testCountTransitPoints(self):
time = np.concatenate([
np.arange(0, 10, 0.1, dtype=np.float),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment