Commit 120bc915 authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Handle single segments and empty arrays more elegantly.

PiperOrigin-RevId: 199862472
parent 7f6313ce
......@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from six.moves import range # pylint:disable=redefined-builtin
......@@ -51,21 +49,21 @@ def split(all_time, all_flux, gap_width=0.75):
piecewise defined (e.g. split by quarter breaks or gaps in the in the data).
Args:
all_time: Numpy array or list of numpy arrays; each is a sequence of time
values.
all_flux: Numpy array or list of numpy arrays; each is a sequence of flux
values of the corresponding time array.
all_time: Numpy array or sequence of numpy arrays; each is a sequence of
time values.
all_flux: Numpy array or sequence of numpy arrays; each is a sequence of
flux values of the corresponding time array.
gap_width: Minimum gap size (in time units) for a split.
Returns:
out_time: List of numpy arrays; the split time arrays.
out_flux: List of numpy arrays; the split flux arrays.
"""
all_time = np.array(all_time)
all_flux = np.array(all_flux)
# Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion
# to bool fails if all_time is a numpy array, and all_time.size is not defined
# if all_time is a list of numpy arrays.
if len(all_time) > 0 and not isinstance(all_time[0], collections.Iterable): # pylint:disable=g-explicit-length-test
if all_time.ndim == 1:
all_time = [all_time]
all_flux = [all_flux]
......@@ -90,10 +88,10 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
curve (e.g. one that is split by quarter breaks or gaps in the in the data).
Args:
all_time: Numpy array or list of numpy arrays; each is a sequence of time
values.
all_flux: Numpy array or list of numpy arrays; each is a sequence of flux
values of the corresponding time array.
all_time: Numpy array or sequence of numpy arrays; each is a sequence of
time values.
all_flux: Numpy array or sequence of numpy arrays; each is a sequence of
flux values of the corresponding time array.
events: List of Event objects to remove.
width_factor: Fractional multiplier of the duration of each event to remove.
......@@ -103,11 +101,11 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
output_flux: Numpy array or list of numpy arrays; the flux arrays with
events removed.
"""
all_time = np.array(all_time)
all_flux = np.array(all_flux)
# Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion
# to bool fails if all_time is a numpy array and all_time.size is not defined
# if all_time is a list of numpy arrays.
if len(all_time) > 0 and not isinstance(all_time[0], collections.Iterable): # pylint:disable=g-explicit-length-test
if all_time.ndim == 1:
all_time = [all_time]
all_flux = [all_flux]
single_segment = True
......@@ -150,10 +148,10 @@ def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
interp_spline = []
for time, masked_time, masked_spline in zip(
all_time, all_masked_time, all_masked_spline):
if len(masked_time) > 0: # pylint:disable=g-explicit-length-test
if masked_time.size:
interp_spline.append(np.interp(time, masked_time, masked_spline))
else:
interp_spline.append(np.full_like(time, np.nan))
interp_spline.append(np.array([np.nan] * len(time)))
return interp_spline
......
......@@ -136,20 +136,23 @@ class LightCurveUtilTest(absltest.TestCase):
all_time = [
np.arange(0, 10, dtype=np.float),
np.arange(10, 20, dtype=np.float),
np.arange(20, 30, dtype=np.float),
]
all_masked_time = [
np.array([0, 1, 2, 3, 8, 9], dtype=np.float), # No 4, 5, 6, 7
np.array([10, 11, 12, 13, 14, 15, 16], dtype=np.float), # No 17, 18, 19
np.array([], dtype=np.float)
]
all_masked_spline = [2 * t + 100 for t in all_masked_time]
interp_spline = util.interpolate_masked_spline(all_time, all_masked_time,
all_masked_spline)
self.assertLen(interp_spline, 2)
self.assertLen(interp_spline, 3)
self.assertSequenceAlmostEqual(
[100, 102, 104, 106, 108, 110, 112, 114, 116, 118], interp_spline[0])
self.assertSequenceAlmostEqual(
[120, 122, 124, 126, 128, 130, 132, 132, 132, 132], interp_spline[1])
self.assertTrue(np.all(np.isnan(interp_spline[2])))
def testCountTransitPoints(self):
time = np.concatenate([
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment