Commit 9546b04c authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Add an option to exclude empty segments from the output in util.remove_events().

PiperOrigin-RevId: 201007073
parent f3407671
......@@ -23,6 +23,7 @@ import os.path
from astropy.io import fits
import numpy as np
from tensorflow import gfile
LONG_CADENCE_TIME_DELTA_DAYS = 0.02043422 # Approximately 29.4 minutes.
......@@ -135,7 +136,7 @@ def kepler_filenames(base_dir,
cadence_suffix)
filename = os.path.join(base_dir, base_name)
# Not all stars have data for all quarters.
if not check_existence or os.path.isfile(filename):
if not check_existence or gfile.Exists(filename):
filenames.append(filename)
return filenames
......@@ -160,7 +161,7 @@ def read_kepler_light_curve(filenames,
all_flux = []
for filename in filenames:
with fits.open(open(filename, "rb")) as hdu_list:
with fits.open(gfile.Open(filename, "rb")) as hdu_list:
light_curve = hdu_list[light_curve_extension].data
time = light_curve.TIME
flux = light_curve.PDCSAP_FLUX
......
......@@ -78,7 +78,11 @@ def split(all_time, all_flux, gap_width=0.75):
return out_time, out_flux
def remove_events(all_time, all_flux, events, width_factor=1.0):
def remove_events(all_time,
all_flux,
events,
width_factor=1.0,
include_empty_segments=True):
"""Removes events from a light curve.
This function accepts either a single-segment or piecewise-defined light
......@@ -91,6 +95,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
flux values of the corresponding time array.
events: List of Event objects to remove.
width_factor: Fractional multiplier of the duration of each event to remove.
include_empty_segments: Whether to include empty segments in the output.
Returns:
output_time: Numpy array or list of numpy arrays; the time arrays with
......@@ -118,7 +123,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
if single_segment:
output_time = time[mask]
output_flux = flux[mask]
else:
elif include_empty_segments or np.any(mask):
output_time.append(time[mask])
output_flux.append(flux[mask])
......
......@@ -152,6 +152,30 @@ class LightCurveUtilTest(absltest.TestCase):
self.assertSequenceAlmostEqual([10, 17, 18], output_time[1])
self.assertSequenceAlmostEqual([100, 170, 180], output_flux[1])
# One segment totally removed with include_empty_segments = True.
time = [np.arange(5, dtype=np.float), np.arange(10, 20, dtype=np.float)]
flux = [10 * t for t in time]
events = [periodic_event.Event(period=10, duration=2, t0=2.5)]
output_time, output_flux = util.remove_events(
time, flux, events, width_factor=3, include_empty_segments=True)
self.assertLen(output_time, 2)
self.assertLen(output_flux, 2)
self.assertSequenceEqual([], output_time[0])
self.assertSequenceEqual([], output_flux[0])
self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[1])
self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[1])
# One segment totally removed with include_empty_segments = False.
time = [np.arange(5, dtype=np.float), np.arange(10, 20, dtype=np.float)]
flux = [10 * t for t in time]
events = [periodic_event.Event(period=10, duration=2, t0=2.5)]
output_time, output_flux = util.remove_events(
time, flux, events, width_factor=3, include_empty_segments=False)
self.assertLen(output_time, 1)
self.assertLen(output_flux, 1)
self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[0])
self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[0])
def testInterpolateMaskedSpline(self):
all_time = [
np.arange(0, 10, dtype=np.float),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment