"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "60b4db63890d84e1f6711e037455f76b2228fc54"
Commit 87820577 authored by Alex Tamkin's avatar Alex Tamkin Committed by Christopher Shallue
Browse files

Ensure consistent argument order, change NaN processing to remove NaN times as...

Ensure consistent argument order, change NaN processing to remove NaN times as well, and do so before scrambling

PiperOrigin-RevId: 207804986
parent e8742a3d
...@@ -150,14 +150,14 @@ def kepler_filenames(base_dir, ...@@ -150,14 +150,14 @@ def kepler_filenames(base_dir,
return filenames return filenames
def scramble_light_curve(all_flux, all_time, all_quarters, scramble_type): def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type):
"""Scrambles a light curve according to a given scrambling procedure. """Scrambles a light curve according to a given scrambling procedure.
Args: Args:
all_flux: List holding lists of flux values (each interior list holds a
quarter of flux data).
all_time: List holding lists of time values (each interior list holds a all_time: List holding lists of time values (each interior list holds a
quarter of time data). quarter of time data).
all_flux: List holding lists of flux values (each interior list holds a
quarter of flux data).
all_quarters: List of integers specifying which quarters were present in all_quarters: List of integers specifying which quarters were present in
the light curve (max is 18: Q0...Q17). the light curve (max is 18: Q0...Q17).
scramble_type: String specifying the scramble order, one of {'SCR1', 'SCR2', scramble_type: String specifying the scramble order, one of {'SCR1', 'SCR2',
...@@ -178,11 +178,10 @@ def scramble_light_curve(all_flux, all_time, all_quarters, scramble_type): ...@@ -178,11 +178,10 @@ def scramble_light_curve(all_flux, all_time, all_quarters, scramble_type):
concat_time = np.concatenate(all_time) concat_time = np.concatenate(all_time)
scr_time = [] scr_time = []
for flux in scr_flux: for flux in scr_flux:
same_len_time_list = list(concat_time[:len(flux)]) time, concat_time = np.split(concat_time, [len(flux)])
scr_time.append(same_len_time_list) scr_time.append(time)
concat_time = concat_time[len(flux):]
return scr_flux, scr_time return scr_time, scr_flux
def read_kepler_light_curve(filenames, def read_kepler_light_curve(filenames,
...@@ -212,20 +211,20 @@ def read_kepler_light_curve(filenames, ...@@ -212,20 +211,20 @@ def read_kepler_light_curve(filenames,
flux = light_curve.PDCSAP_FLUX flux = light_curve.PDCSAP_FLUX
# Index into primary HDU header and get quarter. # Index into primary HDU header and get quarter.
all_quarters.append(hdu_list[0].header["QUARTER"]) quarter = hdu_list[0].header["QUARTER"]
# Remove NaN flux values.
valid_indices = np.where(np.isfinite(flux))
time = time[valid_indices]
flux = flux[valid_indices]
if time.size: if time.size:
all_time.append(time) all_time.append(time)
all_flux.append(flux) all_flux.append(flux)
all_quarters.append(quarter)
if scramble_type: if scramble_type:
all_flux, all_time = scramble_light_curve(all_flux, all_time, all_quarters, all_time, all_flux = scramble_light_curve(all_time, all_flux, all_quarters,
scramble_type) scramble_type)
# Remove NaN flux values after potential scrambling.
for i, (flux, time) in enumerate(zip(all_flux, all_time)):
valid_indices = np.where(np.isfinite(flux))
all_time[i] = time[valid_indices]
all_flux[i] = flux[valid_indices]
return all_time, all_flux return all_time, all_flux
...@@ -19,8 +19,10 @@ from __future__ import division ...@@ -19,8 +19,10 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os.path import os.path
from absl import flags from absl import flags
from absl.testing import absltest from absl.testing import absltest
import numpy as np
from light_curve_util import kepler_io from light_curve_util import kepler_io
...@@ -35,21 +37,24 @@ class KeplerIoTest(absltest.TestCase): ...@@ -35,21 +37,24 @@ class KeplerIoTest(absltest.TestCase):
self.data_dir = os.path.join(FLAGS.test_srcdir, _DATA_DIR) self.data_dir = os.path.join(FLAGS.test_srcdir, _DATA_DIR)
def testScrambleLightCurve(self): def testScrambleLightCurve(self):
nan = float("nan") all_flux = [[11, 12], [21], [np.nan, np.nan, 33], [41, 42]]
all_flux = [[11, 12], [21], [nan, nan, 33], [41, 42]]
all_time = [[101, 102], [201], [301, 302, 303], [401, 402]] all_time = [[101, 102], [201], [301, 302, 303], [401, 402]]
all_quarters = [3, 4, 7, 14] all_quarters = [3, 4, 7, 14]
scramble_type = "SCR1" # New quarters order will be [14,7,3,4]. scramble_type = "SCR1" # New quarters order will be [14,7,3,4].
scr_flux, scr_time = kepler_io.scramble_light_curve( scr_time, scr_flux = kepler_io.scramble_light_curve(
all_flux, all_time, all_quarters, scramble_type) all_time, all_flux, all_quarters, scramble_type)
# NaNs are not removed in this function. # NaNs are not removed in this function.
gold_flux = [[41, 42], [nan, nan, 33], [11, 12], [21]] gold_flux = [[41, 42], [np.nan, np.nan, 33], [11, 12], [21]]
gold_time = [[101, 102], [201, 301, 302], [303, 401], [402]] gold_time = [[101, 102], [201, 301, 302], [303, 401], [402]]
self.assertEqual(gold_flux, scr_flux) self.assertEqual(len(gold_flux), len(scr_flux))
self.assertEqual(gold_time, scr_time) self.assertEqual(len(gold_time), len(scr_time))
for i in range(len(gold_flux)):
np.testing.assert_array_equal(gold_flux[i], scr_flux[i])
np.testing.assert_array_equal(gold_time[i], scr_time[i])
def testKeplerFilenames(self): def testKeplerFilenames(self):
# All quarters. # All quarters.
...@@ -137,6 +142,30 @@ class KeplerIoTest(absltest.TestCase): ...@@ -137,6 +142,30 @@ class KeplerIoTest(absltest.TestCase):
self.assertLen(all_time[2], 4486) self.assertLen(all_time[2], 4486)
self.assertLen(all_flux[2], 4486) self.assertLen(all_flux[2], 4486)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
def testReadKeplerLightCurveScrambled(self):
filenames = [
os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
% q for q in ["2009350155506", "2010009091648", "2010174085026"]
]
all_time, all_flux = kepler_io.read_kepler_light_curve(
filenames, scramble_type="SCR1")
self.assertLen(all_time, 3)
self.assertLen(all_flux, 3)
self.assertLen(all_time[0], 4486)
self.assertLen(all_flux[0], 4486)
self.assertLen(all_time[1], 4134)
self.assertLen(all_flux[1], 4134)
self.assertLen(all_time[2], 1008)
self.assertLen(all_flux[2], 1008)
for time, flux in zip(all_time, all_flux):
self.assertTrue(np.isfinite(time).all())
self.assertTrue(np.isfinite(flux).all())
if __name__ == "__main__": if __name__ == "__main__":
FLAGS.test_srcdir = "" FLAGS.test_srcdir = ""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment