"git@developer.sourcefind.cn:norm/vllm.git" did not exist on "01a5d18a537b65a156cfa1a77706693a24c869c1"
Commit 9546b04c authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Add an option to exclude empty segments from the output in util.remove_events().

PiperOrigin-RevId: 201007073
parent f3407671
...@@ -23,6 +23,7 @@ import os.path ...@@ -23,6 +23,7 @@ import os.path
from astropy.io import fits from astropy.io import fits
import numpy as np import numpy as np
from tensorflow import gfile
LONG_CADENCE_TIME_DELTA_DAYS = 0.02043422 # Approximately 29.4 minutes. LONG_CADENCE_TIME_DELTA_DAYS = 0.02043422 # Approximately 29.4 minutes.
...@@ -135,7 +136,7 @@ def kepler_filenames(base_dir, ...@@ -135,7 +136,7 @@ def kepler_filenames(base_dir,
cadence_suffix) cadence_suffix)
filename = os.path.join(base_dir, base_name) filename = os.path.join(base_dir, base_name)
# Not all stars have data for all quarters. # Not all stars have data for all quarters.
if not check_existence or os.path.isfile(filename): if not check_existence or gfile.Exists(filename):
filenames.append(filename) filenames.append(filename)
return filenames return filenames
...@@ -160,7 +161,7 @@ def read_kepler_light_curve(filenames, ...@@ -160,7 +161,7 @@ def read_kepler_light_curve(filenames,
all_flux = [] all_flux = []
for filename in filenames: for filename in filenames:
with fits.open(open(filename, "rb")) as hdu_list: with fits.open(gfile.Open(filename, "rb")) as hdu_list:
light_curve = hdu_list[light_curve_extension].data light_curve = hdu_list[light_curve_extension].data
time = light_curve.TIME time = light_curve.TIME
flux = light_curve.PDCSAP_FLUX flux = light_curve.PDCSAP_FLUX
......
...@@ -78,7 +78,11 @@ def split(all_time, all_flux, gap_width=0.75): ...@@ -78,7 +78,11 @@ def split(all_time, all_flux, gap_width=0.75):
return out_time, out_flux return out_time, out_flux
def remove_events(all_time, all_flux, events, width_factor=1.0): def remove_events(all_time,
all_flux,
events,
width_factor=1.0,
include_empty_segments=True):
"""Removes events from a light curve. """Removes events from a light curve.
This function accepts either a single-segment or piecewise-defined light This function accepts either a single-segment or piecewise-defined light
...@@ -91,6 +95,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0): ...@@ -91,6 +95,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
flux values of the corresponding time array. flux values of the corresponding time array.
events: List of Event objects to remove. events: List of Event objects to remove.
width_factor: Fractional multiplier of the duration of each event to remove. width_factor: Fractional multiplier of the duration of each event to remove.
include_empty_segments: Whether to include empty segments in the output.
Returns: Returns:
output_time: Numpy array or list of numpy arrays; the time arrays with output_time: Numpy array or list of numpy arrays; the time arrays with
...@@ -118,7 +123,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0): ...@@ -118,7 +123,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
if single_segment: if single_segment:
output_time = time[mask] output_time = time[mask]
output_flux = flux[mask] output_flux = flux[mask]
else: elif include_empty_segments or np.any(mask):
output_time.append(time[mask]) output_time.append(time[mask])
output_flux.append(flux[mask]) output_flux.append(flux[mask])
......
...@@ -152,6 +152,30 @@ class LightCurveUtilTest(absltest.TestCase): ...@@ -152,6 +152,30 @@ class LightCurveUtilTest(absltest.TestCase):
self.assertSequenceAlmostEqual([10, 17, 18], output_time[1]) self.assertSequenceAlmostEqual([10, 17, 18], output_time[1])
self.assertSequenceAlmostEqual([100, 170, 180], output_flux[1]) self.assertSequenceAlmostEqual([100, 170, 180], output_flux[1])
# One segment totally removed with include_empty_segments = True.
time = [np.arange(5, dtype=np.float), np.arange(10, 20, dtype=np.float)]
flux = [10 * t for t in time]
events = [periodic_event.Event(period=10, duration=2, t0=2.5)]
output_time, output_flux = util.remove_events(
time, flux, events, width_factor=3, include_empty_segments=True)
self.assertLen(output_time, 2)
self.assertLen(output_flux, 2)
self.assertSequenceEqual([], output_time[0])
self.assertSequenceEqual([], output_flux[0])
self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[1])
self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[1])
# One segment totally removed with include_empty_segments = False.
time = [np.arange(5, dtype=np.float), np.arange(10, 20, dtype=np.float)]
flux = [10 * t for t in time]
events = [periodic_event.Event(period=10, duration=2, t0=2.5)]
output_time, output_flux = util.remove_events(
time, flux, events, width_factor=3, include_empty_segments=False)
self.assertLen(output_time, 1)
self.assertLen(output_flux, 1)
self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[0])
self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[0])
def testInterpolateMaskedSpline(self): def testInterpolateMaskedSpline(self):
all_time = [ all_time = [
np.arange(0, 10, dtype=np.float), np.arange(0, 10, dtype=np.float),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment