"llm/vscode:/vscode.git/clone" did not exist on "7afb2e125a6047a42909ef41c1c61720766d64b6"
Commit 87820577 authored by Alex Tamkin's avatar Alex Tamkin Committed by Christopher Shallue
Browse files

Ensure consistent argument order, change NaN processing to remove NaN times as...

Ensure consistent argument order, change NaN processing to remove NaN times as well, and do so before scrambling

PiperOrigin-RevId: 207804986
parent e8742a3d
......@@ -150,14 +150,14 @@ def kepler_filenames(base_dir,
return filenames
def scramble_light_curve(all_flux, all_time, all_quarters, scramble_type):
def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type):
"""Scrambles a light curve according to a given scrambling procedure.
Args:
all_flux: List holding lists of flux values (each interior list holds a
quarter of flux data).
all_time: List holding lists of time values (each interior list holds a
quarter of time data).
all_flux: List holding lists of flux values (each interior list holds a
quarter of flux data).
all_quarters: List of integers specifying which quarters were present in
the light curve (max is 18: Q0...Q17).
scramble_type: String specifying the scramble order, one of {'SCR1', 'SCR2',
......@@ -178,11 +178,10 @@ def scramble_light_curve(all_flux, all_time, all_quarters, scramble_type):
concat_time = np.concatenate(all_time)
scr_time = []
for flux in scr_flux:
same_len_time_list = list(concat_time[:len(flux)])
scr_time.append(same_len_time_list)
concat_time = concat_time[len(flux):]
time, concat_time = np.split(concat_time, [len(flux)])
scr_time.append(time)
return scr_flux, scr_time
return scr_time, scr_flux
def read_kepler_light_curve(filenames,
......@@ -212,20 +211,20 @@ def read_kepler_light_curve(filenames,
flux = light_curve.PDCSAP_FLUX
# Index into primary HDU header and get quarter.
all_quarters.append(hdu_list[0].header["QUARTER"])
quarter = hdu_list[0].header["QUARTER"]
# Remove NaN flux values.
valid_indices = np.where(np.isfinite(flux))
time = time[valid_indices]
flux = flux[valid_indices]
if time.size:
all_time.append(time)
all_flux.append(flux)
all_quarters.append(quarter)
if scramble_type:
all_flux, all_time = scramble_light_curve(all_flux, all_time, all_quarters,
all_time, all_flux = scramble_light_curve(all_time, all_flux, all_quarters,
scramble_type)
# Remove NaN flux values after potential scrambling.
for i, (flux, time) in enumerate(zip(all_flux, all_time)):
valid_indices = np.where(np.isfinite(flux))
all_time[i] = time[valid_indices]
all_flux[i] = flux[valid_indices]
return all_time, all_flux
......@@ -19,8 +19,10 @@ from __future__ import division
from __future__ import print_function
import os.path
from absl import flags
from absl.testing import absltest
import numpy as np
from light_curve_util import kepler_io
......@@ -35,21 +37,24 @@ class KeplerIoTest(absltest.TestCase):
self.data_dir = os.path.join(FLAGS.test_srcdir, _DATA_DIR)
def testScrambleLightCurve(self):
nan = float("nan")
all_flux = [[11, 12], [21], [nan, nan, 33], [41, 42]]
all_flux = [[11, 12], [21], [np.nan, np.nan, 33], [41, 42]]
all_time = [[101, 102], [201], [301, 302, 303], [401, 402]]
all_quarters = [3, 4, 7, 14]
scramble_type = "SCR1" # New quarters order will be [14,7,3,4].
scr_flux, scr_time = kepler_io.scramble_light_curve(
all_flux, all_time, all_quarters, scramble_type)
scr_time, scr_flux = kepler_io.scramble_light_curve(
all_time, all_flux, all_quarters, scramble_type)
# NaNs are not removed in this function.
gold_flux = [[41, 42], [nan, nan, 33], [11, 12], [21]]
gold_flux = [[41, 42], [np.nan, np.nan, 33], [11, 12], [21]]
gold_time = [[101, 102], [201, 301, 302], [303, 401], [402]]
self.assertEqual(gold_flux, scr_flux)
self.assertEqual(gold_time, scr_time)
self.assertEqual(len(gold_flux), len(scr_flux))
self.assertEqual(len(gold_time), len(scr_time))
for i in range(len(gold_flux)):
np.testing.assert_array_equal(gold_flux[i], scr_flux[i])
np.testing.assert_array_equal(gold_time[i], scr_time[i])
def testKeplerFilenames(self):
# All quarters.
......@@ -137,6 +142,30 @@ class KeplerIoTest(absltest.TestCase):
self.assertLen(all_time[2], 4486)
self.assertLen(all_flux[2], 4486)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
def testReadKeplerLightCurveScrambled(self):
filenames = [
os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
% q for q in ["2009350155506", "2010009091648", "2010174085026"]
]
all_time, all_flux = kepler_io.read_kepler_light_curve(
filenames, scramble_type="SCR1")
self.assertLen(all_time, 3)
self.assertLen(all_flux, 3)
self.assertLen(all_time[0], 4486)
self.assertLen(all_flux[0], 4486)
self.assertLen(all_time[1], 4134)
self.assertLen(all_flux[1], 4134)
self.assertLen(all_time[2], 1008)
self.assertLen(all_flux[2], 1008)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
if __name__ == "__main__":
FLAGS.test_srcdir = ""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment