"vscode:/vscode.git/clone" did not exist on "346d2022970959674a6c4296ed64a78bd0367d7e"
Commit 5dc14555 authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Actually handle single-segment / multi-segment inputs correctly.

PiperOrigin-RevId: 199887920
parent 120bc915
......@@ -59,11 +59,8 @@ def split(all_time, all_flux, gap_width=0.75):
out_time: List of numpy arrays; the split time arrays.
out_flux: List of numpy arrays; the split flux arrays.
"""
all_time = np.array(all_time)
all_flux = np.array(all_flux)
# Handle single-segment inputs.
if all_time.ndim == 1:
if isinstance(all_time, np.ndarray) and all_time.ndim == 1:
all_time = [all_time]
all_flux = [all_flux]
......@@ -101,11 +98,8 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
output_flux: Numpy array or list of numpy arrays; the flux arrays with
events removed.
"""
all_time = np.array(all_time)
all_flux = np.array(all_flux)
# Handle single-segment inputs.
if all_time.ndim == 1:
if isinstance(all_time, np.ndarray) and all_time.ndim == 1:
all_time = [all_time]
all_flux = [all_flux]
single_segment = True
......
......@@ -65,43 +65,63 @@ class LightCurveUtilTest(absltest.TestCase):
self.assertSequenceAlmostEqual(expected, tfold)
def testSplit(self):
# Single segment.
all_time = np.concatenate([np.arange(0, 1, 0.1), np.arange(1.5, 2, 0.1)])
all_flux = np.ones(15)
# Gap width 0.5.
split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
self.assertLen(split_time, 2)
self.assertLen(split_flux, 2)
self.assertSequenceAlmostEqual(np.arange(0, 1, 0.1), split_time[0])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[0])
self.assertSequenceAlmostEqual(np.arange(1.5, 2, 0.1), split_time[1])
self.assertSequenceAlmostEqual(np.ones(5), split_flux[1])
# Multi segment.
all_time = [
np.concatenate([
np.arange(0, 1, 0.1),
np.arange(1.5, 2, 0.1),
np.arange(3, 4, 0.1)
])
]),
np.arange(4, 5, 0.1)
]
all_flux = [np.array([1] * 25)]
all_flux = [np.ones(25), np.ones(10)]
self.assertEqual(len(all_time), 2)
self.assertEqual(len(all_time[0]), 25)
self.assertEqual(len(all_time[1]), 10)
self.assertEqual(len(all_flux), 2)
self.assertEqual(len(all_flux[0]), 25)
self.assertEqual(len(all_flux[1]), 10)
# Gap width 0.5.
split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
self.assertLen(split_time, 3)
self.assertLen(split_flux, 3)
self.assertSequenceAlmostEqual(
[0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], split_time[0])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
split_flux[0])
self.assertSequenceAlmostEqual([1.5, 1.6, 1.7, 1.8, 1.9], split_time[1])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1], split_flux[1])
self.assertSequenceAlmostEqual(
[3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9], split_time[2])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
split_flux[2])
self.assertLen(split_time, 4)
self.assertLen(split_flux, 4)
self.assertSequenceAlmostEqual(np.arange(0, 1, 0.1), split_time[0])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[0])
self.assertSequenceAlmostEqual(np.arange(1.5, 2, 0.1), split_time[1])
self.assertSequenceAlmostEqual(np.ones(5), split_flux[1])
self.assertSequenceAlmostEqual(np.arange(3, 4, 0.1), split_time[2])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[2])
self.assertSequenceAlmostEqual(np.arange(4, 5, 0.1), split_time[3])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[3])
# Gap width 1.0.
split_time, split_flux = util.split(all_time, all_flux, gap_width=1)
self.assertLen(split_time, 2)
self.assertLen(split_flux, 2)
self.assertLen(split_time, 3)
self.assertLen(split_flux, 3)
self.assertSequenceAlmostEqual([
0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.5, 1.6, 1.7, 1.8, 1.9
], split_time[0])
self.assertSequenceAlmostEqual(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], split_flux[0])
self.assertSequenceAlmostEqual(
[3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9], split_time[1])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
split_flux[1])
self.assertSequenceAlmostEqual(np.ones(15), split_flux[0])
self.assertSequenceAlmostEqual(np.arange(3, 4, 0.1), split_time[1])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[1])
self.assertSequenceAlmostEqual(np.arange(4, 5, 0.1), split_time[2])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[2])
def testRemoveEvents(self):
time = np.arange(20, dtype=np.float)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment