"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "1279f2cd0f2639dcad2fb3437fceda7539455ae2"
Commit 120bc915 authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Handle single segments and empty arrays more elegantly.

PiperOrigin-RevId: 199862472
parent 7f6313ce
...@@ -18,8 +18,6 @@ from __future__ import absolute_import ...@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import numpy as np import numpy as np
from six.moves import range # pylint:disable=redefined-builtin from six.moves import range # pylint:disable=redefined-builtin
...@@ -51,21 +49,21 @@ def split(all_time, all_flux, gap_width=0.75): ...@@ -51,21 +49,21 @@ def split(all_time, all_flux, gap_width=0.75):
piecewise defined (e.g. split by quarter breaks or gaps in the in the data). piecewise defined (e.g. split by quarter breaks or gaps in the in the data).
Args: Args:
all_time: Numpy array or list of numpy arrays; each is a sequence of time all_time: Numpy array or sequence of numpy arrays; each is a sequence of
values. time values.
all_flux: Numpy array or list of numpy arrays; each is a sequence of flux all_flux: Numpy array or sequence of numpy arrays; each is a sequence of
values of the corresponding time array. flux values of the corresponding time array.
gap_width: Minimum gap size (in time units) for a split. gap_width: Minimum gap size (in time units) for a split.
Returns: Returns:
out_time: List of numpy arrays; the split time arrays. out_time: List of numpy arrays; the split time arrays.
out_flux: List of numpy arrays; the split flux arrays. out_flux: List of numpy arrays; the split flux arrays.
""" """
all_time = np.array(all_time)
all_flux = np.array(all_flux)
# Handle single-segment inputs. # Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion if all_time.ndim == 1:
# to bool fails if all_time is a numpy array, and all_time.size is not defined
# if all_time is a list of numpy arrays.
if len(all_time) > 0 and not isinstance(all_time[0], collections.Iterable): # pylint:disable=g-explicit-length-test
all_time = [all_time] all_time = [all_time]
all_flux = [all_flux] all_flux = [all_flux]
...@@ -90,10 +88,10 @@ def remove_events(all_time, all_flux, events, width_factor=1.0): ...@@ -90,10 +88,10 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
curve (e.g. one that is split by quarter breaks or gaps in the in the data). curve (e.g. one that is split by quarter breaks or gaps in the in the data).
Args: Args:
all_time: Numpy array or list of numpy arrays; each is a sequence of time all_time: Numpy array or sequence of numpy arrays; each is a sequence of
values. time values.
all_flux: Numpy array or list of numpy arrays; each is a sequence of flux all_flux: Numpy array or sequence of numpy arrays; each is a sequence of
values of the corresponding time array. flux values of the corresponding time array.
events: List of Event objects to remove. events: List of Event objects to remove.
width_factor: Fractional multiplier of the duration of each event to remove. width_factor: Fractional multiplier of the duration of each event to remove.
...@@ -103,11 +101,11 @@ def remove_events(all_time, all_flux, events, width_factor=1.0): ...@@ -103,11 +101,11 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
output_flux: Numpy array or list of numpy arrays; the flux arrays with output_flux: Numpy array or list of numpy arrays; the flux arrays with
events removed. events removed.
""" """
all_time = np.array(all_time)
all_flux = np.array(all_flux)
# Handle single-segment inputs. # Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion if all_time.ndim == 1:
# to bool fails if all_time is a numpy array and all_time.size is not defined
# if all_time is a list of numpy arrays.
if len(all_time) > 0 and not isinstance(all_time[0], collections.Iterable): # pylint:disable=g-explicit-length-test
all_time = [all_time] all_time = [all_time]
all_flux = [all_flux] all_flux = [all_flux]
single_segment = True single_segment = True
...@@ -150,10 +148,10 @@ def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline): ...@@ -150,10 +148,10 @@ def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
interp_spline = [] interp_spline = []
for time, masked_time, masked_spline in zip( for time, masked_time, masked_spline in zip(
all_time, all_masked_time, all_masked_spline): all_time, all_masked_time, all_masked_spline):
if len(masked_time) > 0: # pylint:disable=g-explicit-length-test if masked_time.size:
interp_spline.append(np.interp(time, masked_time, masked_spline)) interp_spline.append(np.interp(time, masked_time, masked_spline))
else: else:
interp_spline.append(np.full_like(time, np.nan)) interp_spline.append(np.array([np.nan] * len(time)))
return interp_spline return interp_spline
......
...@@ -136,20 +136,23 @@ class LightCurveUtilTest(absltest.TestCase): ...@@ -136,20 +136,23 @@ class LightCurveUtilTest(absltest.TestCase):
all_time = [ all_time = [
np.arange(0, 10, dtype=np.float), np.arange(0, 10, dtype=np.float),
np.arange(10, 20, dtype=np.float), np.arange(10, 20, dtype=np.float),
np.arange(20, 30, dtype=np.float),
] ]
all_masked_time = [ all_masked_time = [
np.array([0, 1, 2, 3, 8, 9], dtype=np.float), # No 4, 5, 6, 7 np.array([0, 1, 2, 3, 8, 9], dtype=np.float), # No 4, 5, 6, 7
np.array([10, 11, 12, 13, 14, 15, 16], dtype=np.float), # No 17, 18, 19 np.array([10, 11, 12, 13, 14, 15, 16], dtype=np.float), # No 17, 18, 19
np.array([], dtype=np.float)
] ]
all_masked_spline = [2 * t + 100 for t in all_masked_time] all_masked_spline = [2 * t + 100 for t in all_masked_time]
interp_spline = util.interpolate_masked_spline(all_time, all_masked_time, interp_spline = util.interpolate_masked_spline(all_time, all_masked_time,
all_masked_spline) all_masked_spline)
self.assertLen(interp_spline, 2) self.assertLen(interp_spline, 3)
self.assertSequenceAlmostEqual( self.assertSequenceAlmostEqual(
[100, 102, 104, 106, 108, 110, 112, 114, 116, 118], interp_spline[0]) [100, 102, 104, 106, 108, 110, 112, 114, 116, 118], interp_spline[0])
self.assertSequenceAlmostEqual( self.assertSequenceAlmostEqual(
[120, 122, 124, 126, 128, 130, 132, 132, 132, 132], interp_spline[1]) [120, 122, 124, 126, 128, 130, 132, 132, 132, 132], interp_spline[1])
self.assertTrue(np.all(np.isnan(interp_spline[2])))
def testCountTransitPoints(self): def testCountTransitPoints(self):
time = np.concatenate([ time = np.concatenate([
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment