Commit 65da497f authored by Shining Sun's avatar Shining Sun
Browse files

Merge branch 'master' of https://github.com/tensorflow/models into cifar_keras

parents 93e0022d 7d032ea3
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for reading Kepler data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
from astropy.io import fits
import numpy as np
from light_curve_util import util
from tensorflow import gfile
# Quarter index to filename prefix for long cadence Kepler data.
# Reference: https://archive.stsci.edu/kepler/software/get_kepler.py
# Each quarter maps to a *list* of prefixes because a quarter may have more
# than one file prefix (e.g. quarter 4 below).
LONG_CADENCE_QUARTER_PREFIXES = {
    0: ["2009131105131"],
    1: ["2009166043257"],
    2: ["2009259160929"],
    3: ["2009350155506"],
    4: ["2010078095331", "2010009091648"],
    5: ["2010174085026"],
    6: ["2010265121752"],
    7: ["2010355172524"],
    8: ["2011073133259"],
    9: ["2011177032512"],
    10: ["2011271113734"],
    11: ["2012004120508"],
    12: ["2012088054726"],
    13: ["2012179063303"],
    14: ["2012277125453"],
    15: ["2013011073258"],
    16: ["2013098041711"],
    17: ["2013131215648"]
}
# Quarter index to filename prefix for short cadence Kepler data.
# Reference: https://archive.stsci.edu/kepler/software/get_kepler.py
# Short cadence quarters typically have several file prefixes each (one per
# month-long data segment, judging by the prefix counts below — TODO confirm).
SHORT_CADENCE_QUARTER_PREFIXES = {
    0: ["2009131110544"],
    1: ["2009166044711"],
    2: ["2009201121230", "2009231120729", "2009259162342"],
    3: ["2009291181958", "2009322144938", "2009350160919"],
    4: ["2010009094841", "2010019161129", "2010049094358", "2010078100744"],
    5: ["2010111051353", "2010140023957", "2010174090439"],
    6: ["2010203174610", "2010234115140", "2010265121752"],
    7: ["2010296114515", "2010326094124", "2010355172524"],
    8: ["2011024051157", "2011053090032", "2011073133259"],
    9: ["2011116030358", "2011145075126", "2011177032512"],
    10: ["2011208035123", "2011240104155", "2011271113734"],
    11: ["2011303113607", "2011334093404", "2012004120508"],
    12: ["2012032013838", "2012060035710", "2012088054726"],
    13: ["2012121044856", "2012151031540", "2012179063303"],
    14: ["2012211050319", "2012242122129", "2012277125453"],
    15: ["2012310112549", "2012341132017", "2013011073258"],
    16: ["2013017113907", "2013065031647", "2013098041711"],
    17: ["2013121191144", "2013131215648"]
}
# Quarter order for different scrambling procedures.
# Page 9: https://ntrs.nasa.gov/archive/nasa/casi.ntrs.nasa.gov/20170009549.pdf.
# Each value is a permutation of the quarter indices 0..17; quarter 0 stays
# first and quarter 17 stays last in all three orders.
SIMULATED_DATA_SCRAMBLE_ORDERS = {
    "SCR1": [0, 13, 14, 15, 16, 9, 10, 11, 12, 5, 6, 7, 8, 1, 2, 3, 4, 17],
    "SCR2": [0, 1, 2, 3, 4, 13, 14, 15, 16, 9, 10, 11, 12, 5, 6, 7, 8, 17],
    "SCR3": [0, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 17],
}
def kepler_filenames(base_dir,
                     kep_id,
                     long_cadence=True,
                     quarters=None,
                     injected_group=None,
                     check_existence=True):
  """Returns the light curve filenames for a Kepler target star.

  This function assumes the directory structure of the Mikulski Archive for
  Space Telescopes (http://archive.stsci.edu/pub/kepler/lightcurves).
  Specifically, the filenames for a particular Kepler target star have the
  following format:

    ${kep_id:0:4}/${kep_id}/kplr${kep_id}-${quarter_prefix}_${type}.fits,

  where:
    kep_id is the Kepler id left-padded with zeros to length 9;
    quarter_prefix is the filename quarter prefix;
    type is one of "llc" (long cadence light curve) or "slc" (short cadence
        light curve).

  Args:
    base_dir: Base directory containing Kepler data.
    kep_id: Id of the Kepler target star. May be an int or a possibly zero-
        padded string.
    long_cadence: Whether to read a long cadence (~29.4 min / measurement)
        light curve as opposed to a short cadence (~1 min / measurement) light
        curve.
    quarters: Optional list of integers in [0, 17]; the quarters of the Kepler
        mission to return.
    injected_group: Optional string indicating injected light curves. One of
        "inj1", "inj2", "inj3".
    check_existence: If True, only return filenames corresponding to files that
        exist (not all stars have data for all quarters).

  Returns:
    A list of filenames.
  """
  # Zero-pad the Kepler id to its canonical 9-character form.
  kep_id = "{:09d}".format(int(kep_id))

  if long_cadence:
    quarter_prefixes = LONG_CADENCE_QUARTER_PREFIXES
    cadence_suffix = "llc"
  else:
    quarter_prefixes = SHORT_CADENCE_QUARTER_PREFIXES
    cadence_suffix = "slc"

  # Default to all quarters, and always process chronologically.
  quarters = sorted(quarter_prefixes.keys() if quarters is None else quarters)

  target_dir = os.path.join(base_dir, kep_id[0:4], kep_id)
  filenames = []
  for quarter in quarters:
    for quarter_prefix in quarter_prefixes[quarter]:
      if injected_group:
        base_name = "kplr{}-{}_INJECTED-{}_{}.fits".format(
            kep_id, quarter_prefix, injected_group, cadence_suffix)
      else:
        base_name = "kplr{}-{}_{}.fits".format(kep_id, quarter_prefix,
                                               cadence_suffix)
      filename = os.path.join(target_dir, base_name)
      # Not all stars have data for all quarters; optionally drop missing
      # files.
      if not check_existence or gfile.Exists(filename):
        filenames.append(filename)
  return filenames
def scramble_light_curve(all_time, all_flux, all_quarters, scramble_type):
  """Scrambles a light curve according to a given scrambling procedure.

  Args:
    all_time: List holding arrays of time values, each containing a quarter of
        time data.
    all_flux: List holding arrays of flux values, each containing a quarter of
        flux data.
    all_quarters: List of integers specifying which quarters are present in
        the light curve (max is 18: Q0...Q17).
    scramble_type: String specifying the scramble order, one of {'SCR1',
        'SCR2', 'SCR3'}.

  Returns:
    scr_time: Time values, re-partitioned to match sizes of the scr_flux lists.
    scr_flux: Scrambled flux values; the same list as the input flux in
        another order.
  """
  # Reorder the flux segments, silently skipping quarters that are absent
  # from this light curve.
  scr_flux = [
      all_flux[all_quarters.index(quarter)]
      for quarter in SIMULATED_DATA_SCRAMBLE_ORDERS[scramble_type]
      if quarter in all_quarters
  ]

  # Time stays in chronological order; re-partition it so each time segment
  # matches the size of the corresponding scrambled flux segment.
  scr_time = util.reshard_arrays(all_time, scr_flux)

  return scr_time, scr_flux
def read_kepler_light_curve(filenames,
                            light_curve_extension="LIGHTCURVE",
                            scramble_type=None,
                            interpolate_missing_time=False):
  """Reads time and flux measurements for a Kepler target star.

  Args:
    filenames: A list of .fits files containing time and flux measurements.
    light_curve_extension: Name of the HDU 1 extension containing light curves.
    scramble_type: What scrambling procedure to use: 'SCR1', 'SCR2', or 'SCR3'
        (pg 9: https://exoplanetarchive.ipac.caltech.edu/docs/KSCI-19114-002.pdf).
    interpolate_missing_time: Whether to interpolate missing (NaN) time values.
        This should only affect the output if scramble_type is specified (NaN
        time values typically come with NaN flux values, which are removed
        anyway, but scrambling decouples NaN time values from NaN flux values).

  Returns:
    all_time: A list of numpy arrays; the time values of the light curve.
    all_flux: A list of numpy arrays; the flux values of the light curve.
  """
  all_time = []
  all_flux = []
  all_quarters = []

  for filename in filenames:
    with fits.open(gfile.Open(filename, "rb")) as hdu_list:
      quarter = hdu_list["PRIMARY"].header["QUARTER"]
      light_curve = hdu_list[light_curve_extension].data
      time = light_curve.TIME
      flux = light_curve.PDCSAP_FLUX
      if not time.size:
        continue  # Skip files with no data.
      if interpolate_missing_time:
        # Fill NaN time values based on their cadence numbers.
        time = util.interpolate_missing_time(time, light_curve.CADENCENO)
      all_time.append(time)
      all_flux.append(flux)
      all_quarters.append(quarter)

  if scramble_type:
    all_time, all_flux = scramble_light_curve(all_time, all_flux, all_quarters,
                                              scramble_type)

  # Discard every cadence where either the time or the flux is NaN/Inf.
  for i, (time, flux) in enumerate(zip(all_time, all_flux)):
    keep = np.isfinite(time) & np.isfinite(flux)
    all_time[i] = time[keep]
    all_flux[i] = flux[keep]

  return all_time, all_flux
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kepler_io.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
from absl import flags
from absl.testing import absltest
import numpy as np
from light_curve_util import kepler_io
# absl flag registry; the tests read FLAGS.test_srcdir to locate test data.
FLAGS = flags.FLAGS

# Location of the .fits test data files, relative to FLAGS.test_srcdir.
_DATA_DIR = "light_curve_util/test_data/"
class KeplerIoTest(absltest.TestCase):
  """Tests for the kepler_io module."""

  def setUp(self):
    # All .fits test fixtures live under this directory.
    self.data_dir = os.path.join(FLAGS.test_srcdir, _DATA_DIR)

  def testScrambleLightCurve(self):
    # Quarters [3, 4, 7, 14] under SCR1 reorder to [14, 7, 3, 4]; flux
    # segments move as whole units, while time stays chronological and is
    # re-partitioned to match the new flux segment sizes.
    all_flux = [[11, 12], [21], [np.nan, np.nan, 33], [41, 42]]
    all_time = [[101, 102], [201], [301, 302, 303], [401, 402]]
    all_quarters = [3, 4, 7, 14]
    scramble_type = "SCR1"  # New quarters order will be [14,7,3,4].
    scr_time, scr_flux = kepler_io.scramble_light_curve(
        all_time, all_flux, all_quarters, scramble_type)

    # NaNs are not removed in this function.
    gold_flux = [[41, 42], [np.nan, np.nan, 33], [11, 12], [21]]
    gold_time = [[101, 102], [201, 301, 302], [303, 401], [402]]
    self.assertEqual(len(gold_flux), len(scr_flux))
    self.assertEqual(len(gold_time), len(scr_time))
    for i in range(len(gold_flux)):
      np.testing.assert_array_equal(gold_flux[i], scr_flux[i])
      np.testing.assert_array_equal(gold_time[i], scr_time[i])

  def testKeplerFilenames(self):
    # All quarters.
    filenames = kepler_io.kepler_filenames(
        "/my/dir/", 1234567, check_existence=False)
    self.assertItemsEqual([
        "/my/dir/0012/001234567/kplr001234567-2009131105131_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2009166043257_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2009259160929_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2009350155506_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2010078095331_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2010009091648_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2010174085026_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2010265121752_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2010355172524_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2011073133259_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2011177032512_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2011271113734_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2012004120508_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2012088054726_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2012179063303_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2012277125453_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2013011073258_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2013098041711_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2013131215648_llc.fits"
    ], filenames)

    # Subset of quarters.
    filenames = kepler_io.kepler_filenames(
        "/my/dir/", 1234567, quarters=[3, 4], check_existence=False)
    self.assertItemsEqual([
        "/my/dir/0012/001234567/kplr001234567-2009350155506_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2010078095331_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2010009091648_llc.fits"
    ], filenames)

    # Injected group.
    filenames = kepler_io.kepler_filenames(
        "/my/dir/",
        1234567,
        quarters=[3, 4],
        injected_group="inj1",
        check_existence=False)
    # pylint:disable=line-too-long
    self.assertItemsEqual([
        "/my/dir/0012/001234567/kplr001234567-2009350155506_INJECTED-inj1_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2010078095331_INJECTED-inj1_llc.fits",
        "/my/dir/0012/001234567/kplr001234567-2010009091648_INJECTED-inj1_llc.fits"
    ], filenames)
    # pylint:enable=line-too-long

    # Short cadence.
    filenames = kepler_io.kepler_filenames(
        "/my/dir/",
        1234567,
        long_cadence=False,
        quarters=[0, 1],
        check_existence=False)
    self.assertItemsEqual([
        "/my/dir/0012/001234567/kplr001234567-2009131110544_slc.fits",
        "/my/dir/0012/001234567/kplr001234567-2009166044711_slc.fits"
    ], filenames)

    # Check existence. Only the three quarters below exist as test fixtures.
    filenames = kepler_io.kepler_filenames(
        self.data_dir, 11442793, check_existence=True)
    expected_filenames = [
        os.path.join(self.data_dir,
                     "0114/011442793/kplr011442793-{}_llc.fits".format(q))
        for q in ["2009350155506", "2010009091648", "2010174085026"]
    ]
    self.assertItemsEqual(expected_filenames, filenames)

  def testReadKeplerLightCurve(self):
    filenames = [
        os.path.join(self.data_dir,
                     "0114/011442793/kplr011442793-{}_llc.fits".format(q))
        for q in ["2009350155506", "2010009091648", "2010174085026"]
    ]
    all_time, all_flux = kepler_io.read_kepler_light_curve(filenames)
    self.assertLen(all_time, 3)
    self.assertLen(all_flux, 3)
    self.assertLen(all_time[0], 4134)
    self.assertLen(all_flux[0], 4134)
    self.assertLen(all_time[1], 1008)
    self.assertLen(all_flux[1], 1008)
    self.assertLen(all_time[2], 4486)
    self.assertLen(all_flux[2], 4486)
    # NaN time/flux cadences must have been removed.
    for time, flux in zip(all_time, all_flux):
      self.assertTrue(np.isfinite(time).all())
      self.assertTrue(np.isfinite(flux).all())

  def testReadKeplerLightCurveScrambled(self):
    filenames = [
        os.path.join(self.data_dir,
                     "0114/011442793/kplr011442793-{}_llc.fits".format(q))
        for q in ["2009350155506", "2010009091648", "2010174085026"]
    ]
    all_time, all_flux = kepler_io.read_kepler_light_curve(
        filenames, scramble_type="SCR1")
    self.assertLen(all_time, 3)
    self.assertLen(all_flux, 3)

    # Arrays are shorter than above due to separation of time and flux NaNs.
    self.assertLen(all_time[0], 4344)
    self.assertLen(all_flux[0], 4344)
    self.assertLen(all_time[1], 4041)
    self.assertLen(all_flux[1], 4041)
    self.assertLen(all_time[2], 1008)
    self.assertLen(all_flux[2], 1008)
    for time, flux in zip(all_time, all_flux):
      self.assertTrue(np.isfinite(time).all())
      self.assertTrue(np.isfinite(flux).all())

  def testReadKeplerLightCurveScrambledInterpolateMissingTime(self):
    filenames = [
        os.path.join(self.data_dir,
                     "0114/011442793/kplr011442793-{}_llc.fits".format(q))
        for q in ["2009350155506", "2010009091648", "2010174085026"]
    ]
    all_time, all_flux = kepler_io.read_kepler_light_curve(
        filenames, scramble_type="SCR1", interpolate_missing_time=True)
    self.assertLen(all_time, 3)
    self.assertLen(all_flux, 3)
    # Interpolation recovers the NaN time values, so fewer cadences are
    # dropped than in the non-interpolated scrambled test above.
    self.assertLen(all_time[0], 4486)
    self.assertLen(all_flux[0], 4486)
    self.assertLen(all_time[1], 4134)
    self.assertLen(all_flux[1], 4134)
    self.assertLen(all_time[2], 1008)
    self.assertLen(all_flux[2], 1008)
    for time, flux in zip(all_time, all_flux):
      self.assertTrue(np.isfinite(time).all())
      self.assertTrue(np.isfinite(flux).all())
if __name__ == "__main__":
  # When executed directly (outside the absl test runner), resolve test data
  # paths relative to the current working directory.
  FLAGS.test_srcdir = ""
  absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility function for smoothing data using a median filter."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None):
  """Computes the median y-value in uniform intervals (bins) along the x-axis.

  The interval [x_min, x_max) is divided into num_bins uniformly spaced
  intervals of width bin_width. The value computed for each bin is the median
  of all y-values whose corresponding x-value is in the interval.

  NOTE: x must be sorted in ascending order or the results will be incorrect.

  Args:
    x: 1D array of x-coordinates sorted in ascending order. Must have at least
        2 elements, and all elements cannot be the same value.
    y: 1D array of y-coordinates with the same size as x.
    num_bins: The number of intervals to divide the x-axis into. Must be at
        least 2.
    bin_width: The width of each bin on the x-axis. Must be positive, and less
        than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
    x_min: The inclusive leftmost value to consider on the x-axis. Must be
        less than or equal to the largest value of x. Defaults to min(x).
    x_max: The exclusive rightmost value to consider on the x-axis. Must be
        greater than x_min. Defaults to max(x).

  Returns:
    1D NumPy array of size num_bins containing the median y-values of
    uniformly spaced bins on the x-axis.

  Raises:
    ValueError: If an argument has an inappropriate value.
  """
  if num_bins < 2:
    raise ValueError("num_bins must be at least 2. Got: {}".format(num_bins))

  # Validate the lengths of x and y.
  x_len = len(x)
  if x_len < 2:
    raise ValueError("len(x) must be at least 2. Got: {}".format(x_len))
  if x_len != len(y):
    raise ValueError("len(x) (got: {}) must equal len(y) (got: {})".format(
        x_len, len(y)))

  # Validate x_min and x_max.
  x_min = x[0] if x_min is None else x_min
  x_max = x[-1] if x_max is None else x_max
  if x_min >= x_max:
    raise ValueError("x_min (got: {}) must be less than x_max (got: {})".format(
        x_min, x_max))
  if x_min > x[-1]:
    raise ValueError(
        "x_min (got: {}) must be less than or equal to the largest value of x "
        "(got: {})".format(x_min, x[-1]))

  # Validate bin_width.
  if bin_width is None:
    bin_width = (x_max - x_min) / num_bins
  if bin_width <= 0:
    raise ValueError("bin_width must be positive. Got: {}".format(bin_width))
  if bin_width >= x_max - x_min:
    raise ValueError(
        "bin_width (got: {}) must be less than x_max - x_min (got: {})".format(
            bin_width, x_max - x_min))

  # Distance between the left edges of consecutive bins. Depending on
  # bin_width, bins may overlap or leave gaps between each other.
  bin_spacing = (x_max - x_min - bin_width) / (num_bins - 1)

  # Bins with no y-values fall back to the global median.
  result = np.repeat(np.median(y), num_bins)

  # The bin at index i is the median of all elements y[j] such that
  # bin_min <= x[j] < bin_max. Edges advance by accumulation (bin_min +=
  # bin_spacing) to keep float boundary behavior reproducible.
  bin_min = x_min  # Left endpoint of the current bin.
  bin_max = x_min + bin_width  # Right endpoint of the current bin.
  for i in range(num_bins):
    # Since x is sorted, the members of each bin form a contiguous slice;
    # find its endpoints with binary search.
    j_start = np.searchsorted(x, bin_min, side="left")
    j_end = np.searchsorted(x, bin_max, side="left")
    if j_end > j_start:
      result[i] = np.median(y[j_start:j_end])

    # Advance the bin.
    bin_min += bin_spacing
    bin_max += bin_spacing

  return result
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for median_filter.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve_util import median_filter
class MedianFilterTest(absltest.TestCase):
  """Tests for the median_filter function."""

  def testErrors(self):
    # Each invalid argument combination must raise ValueError.

    # x size less than 2.
    x = [1]
    y = [2]
    with self.assertRaises(ValueError):
      median_filter.median_filter(
          x, y, num_bins=2, bin_width=1, x_min=0, x_max=2)

    # x and y not the same size.
    x = [1, 2]
    y = [4, 5, 6]
    with self.assertRaises(ValueError):
      median_filter.median_filter(
          x, y, num_bins=2, bin_width=1, x_min=0, x_max=2)

    # x_min not less than x_max.
    x = [1, 2, 3]
    with self.assertRaises(ValueError):
      median_filter.median_filter(
          x, y, num_bins=2, bin_width=1, x_min=-1, x_max=-1)

    # x_min greater than the last element of x.
    with self.assertRaises(ValueError):
      median_filter.median_filter(
          x, y, num_bins=2, bin_width=0.25, x_min=3.5, x_max=4)

    # bin_width nonpositive.
    with self.assertRaises(ValueError):
      median_filter.median_filter(
          x, y, num_bins=2, bin_width=0, x_min=1, x_max=3)

    # bin_width greater than or equal to x_max - x_min.
    with self.assertRaises(ValueError):
      median_filter.median_filter(
          x, y, num_bins=2, bin_width=1, x_min=1.5, x_max=2.5)

    # num_bins less than 2.
    x = [1, 2, 3]
    with self.assertRaises(ValueError):
      median_filter.median_filter(
          x, y, num_bins=1, bin_width=1, x_min=0, x_max=2)

  def testBucketBoundaries(self):
    # Two x-values fall in each bin; each median is the mean of two y-values.
    x = np.array([-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
    y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
    result = median_filter.median_filter(
        x, y, num_bins=5, bin_width=2, x_min=-5, x_max=5)
    np.testing.assert_array_equal([2.5, 4.5, 6.5, 8.5, 10.5], result)

  def testMultiSizeBins(self):
    # Construct bins with size 0, 1, 2, 3, 4, 5, 10, respectively.
    x = np.array([
        1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6
    ])
    y = np.array([
        0, -1, 1, 4, 5, 6, 2, 2, 4, 4, 1, 1, 1, 1, -1, 1, 2, 3, 4, 5, 6, 7, 8,
        9, 10
    ])
    result = median_filter.median_filter(
        x, y, num_bins=7, bin_width=1, x_min=0, x_max=7)
    # The empty first bin falls back to the global median (3).
    np.testing.assert_array_equal([3, 0, 0, 5, 3, 1, 5.5], result)

  def testMedian(self):
    x = np.array([-4, -2, -2, 0, 0, 0, 2, 2, 2, 2, 3, 3, 3, 3, 3])
    y = np.array([0, -1, 1, 4, 5, 6, 2, 2, 4, 4, 1, 1, 1, 1, -1])
    result = median_filter.median_filter(
        x, y, num_bins=5, bin_width=2, x_min=-5, x_max=5)
    np.testing.assert_array_equal([0, 0, 5, 3, 1], result)

  def testWideBins(self):
    # Overlapping bins: bin_width (6) exceeds the bin spacing (2).
    x = np.array([-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
    y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
    result = median_filter.median_filter(
        x, y, num_bins=5, bin_width=6, x_min=-7, x_max=7)
    np.testing.assert_array_equal([3, 4.5, 6.5, 8.5, 10.5], result)

  def testNarrowBins(self):
    # Gapped bins: bin_width (1) is smaller than the bin spacing (2).
    x = np.array([-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
    y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
    result = median_filter.median_filter(
        x, y, num_bins=5, bin_width=1, x_min=-4.5, x_max=4.5)
    np.testing.assert_array_equal([3, 5, 7, 9, 11], result)

  def testEmptyBins(self):
    # Bins with no points fall back to the global median (2).
    x = np.array([-1, 0, 1])
    y = np.array([1, 2, 3])
    result = median_filter.median_filter(
        x, y, num_bins=5, bin_width=2, x_min=-5, x_max=5)
    np.testing.assert_array_equal([2, 2, 1.5, 3, 2], result)

  def testDefaultArgs(self):
    # x_min, x_max and bin_width default to x[0], x[-1] and their span / 5.
    x = np.array([-4, -2, -2, 0, 0, 0, 2, 2, 2, 2, 3, 3, 3, 3, 3])
    y = np.array([7, -1, 3, 4, 5, 6, 2, 2, 4, 4, 1, 1, 1, 1, -1])
    result = median_filter.median_filter(x, y, num_bins=5)
    np.testing.assert_array_equal([7, 1, 5, 2, 3], result)
if __name__ == "__main__":
  # Allow running this test file directly.
  absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Event class, which represents a periodic event in a light curve."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
class Event(object):
  """Represents a periodic event in a light curve."""

  def __init__(self, period, duration, t0):
    """Initializes the Event.

    Args:
      period: Period of the event, in days.
      duration: Duration of the event, in days.
      t0: Time of the first occurrence of the event, in days.
    """
    self._period = period
    self._duration = duration
    self._t0 = t0

  @property
  def period(self):
    return self._period

  @property
  def duration(self):
    return self._duration

  @property
  def t0(self):
    return self._t0

  def __str__(self):
    return "<period={}, duration={}, t0={}>".format(self.period, self.duration,
                                                    self.t0)

  def __repr__(self):
    return "Event({})".format(str(self))

  def equals(self, other_event, period_rtol=0.001, t0_durations=1):
    """Compares this Event to another Event, within the given tolerance.

    Args:
      other_event: An Event.
      period_rtol: Relative tolerance in matching the periods.
      t0_durations: Tolerance in matching the t0 values, in units of the other
          Event's duration.

    Returns:
      True if this Event is the same as other_event, within the given
      tolerance.
    """
    # The periods must agree within the relative tolerance.
    if not np.isclose(
        self.period, other_event.period, rtol=period_rtol, atol=1e-8):
      return False

    # self.t0 and other_event.t0 may be at different phases: comparing
    # mod(self.t0, period) to mod(other_event.t0, period) does not work
    # because two similar values could end up at opposite ends of
    # [0, period). Instead, compute the absolute t0 difference up to
    # multiples of the period, folded into [0, period / 2).
    t0_diff = np.mod(self.t0 - other_event.t0, other_event.period)
    if t0_diff > other_event.period / 2:
      t0_diff = other_event.period - t0_diff

    return t0_diff < t0_durations * other_event.duration
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for periodic_event.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
from light_curve_util.periodic_event import Event
class EventTest(absltest.TestCase):
  """Tests for the Event class."""

  def testStr(self):
    self.assertEqual(str(Event(1, 2, 3)), "<period=1, duration=2, t0=3>")

  def testRepr(self):
    self.assertEqual(
        repr(Event(1, 2, 3)), "Event(<period=1, duration=2, t0=3>)")

  def testEquals(self):
    event = Event(period=100, duration=5, t0=2)

    # Varying periods. Default tolerance is 0.1% relative.
    self.assertFalse(event.equals(Event(period=0, duration=5, t0=2)))
    self.assertFalse(event.equals(Event(period=50, duration=5, t0=2)))
    self.assertFalse(event.equals(Event(period=99.89, duration=5, t0=2)))
    self.assertTrue(event.equals(Event(period=99.91, duration=5, t0=2)))
    self.assertTrue(event.equals(Event(period=100, duration=5, t0=2)))
    self.assertTrue(event.equals(Event(period=100.01, duration=5, t0=2)))
    self.assertFalse(event.equals(Event(period=101, duration=5, t0=2)))

    # Different period tolerance.
    self.assertTrue(
        event.equals(Event(period=99.1, duration=5, t0=2), period_rtol=0.01))
    self.assertTrue(
        event.equals(Event(period=100.9, duration=5, t0=2), period_rtol=0.01))
    self.assertFalse(
        event.equals(Event(period=98.9, duration=5, t0=2), period_rtol=0.01))
    self.assertFalse(
        event.equals(Event(period=101.1, duration=5, t0=2), period_rtol=0.01))

    # Varying t0. Default t0 tolerance is one of the other event's durations.
    self.assertTrue(event.equals(Event(period=100, duration=5, t0=0)))
    self.assertTrue(event.equals(Event(period=100, duration=5, t0=2)))
    self.assertTrue(event.equals(Event(period=100, duration=5, t0=6.9)))
    self.assertFalse(event.equals(Event(period=100, duration=5, t0=7.1)))

    # t0 at the other end of [0, period).
    self.assertFalse(event.equals(Event(period=100, duration=5, t0=96.9)))
    self.assertTrue(event.equals(Event(period=100, duration=5, t0=97.1)))
    self.assertTrue(event.equals(Event(period=100, duration=5, t0=100)))
    self.assertTrue(event.equals(Event(period=100, duration=5, t0=102)))
    self.assertFalse(event.equals(Event(period=100, duration=5, t0=107.1)))

    # Varying duration.
    self.assertFalse(event.equals(Event(period=100, duration=5, t0=10)))
    self.assertFalse(event.equals(Event(period=100, duration=7, t0=10)))
    self.assertTrue(event.equals(Event(period=100, duration=9, t0=10)))

    # Different duration tolerance.
    self.assertFalse(
        event.equals(Event(period=100, duration=5, t0=10), t0_durations=1))
    self.assertTrue(
        event.equals(Event(period=100, duration=5, t0=10), t0_durations=2))
if __name__ == "__main__":
  # Allow running this test file directly.
  absltest.main()
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Light curve utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import scipy.interpolate
from six.moves import range # pylint:disable=redefined-builtin
def phase_fold_time(time, period, t0):
  """Creates a phase-folded time vector.

  result[i] is the unique number in [-period / 2, period / 2)
  such that result[i] = time[i] - t0 + k_i * period, for some integer k_i.

  Args:
    time: 1D numpy array of time values.
    period: A positive real scalar; the period to fold over.
    t0: The center of the resulting folded vector; this value is mapped to 0.

  Returns:
    A 1D numpy array.
  """
  # Shift so that t0 maps to half_period, wrap into [0, period), then
  # re-center on zero.
  half_period = period / 2
  return np.mod(time + (half_period - t0), period) - half_period
def split(all_time, all_flux, gap_width=0.75):
  """Splits a light curve on discontinuities (gaps).

  This function accepts a light curve that is either a single segment, or is
  piecewise defined (e.g. split by quarter breaks or gaps in the data).

  Args:
    all_time: Numpy array or sequence of numpy arrays; each is a sequence of
        time values.
    all_flux: Numpy array or sequence of numpy arrays; each is a sequence of
        flux values of the corresponding time array.
    gap_width: Minimum gap size (in time units) for a split.

  Returns:
    out_time: List of numpy arrays; the split time arrays.
    out_flux: List of numpy arrays; the split flux arrays.
  """
  # Promote a single-segment light curve to a one-element sequence.
  if isinstance(all_time, np.ndarray) and all_time.ndim == 1:
    all_time = [all_time]
    all_flux = [all_flux]

  out_time = []
  out_flux = []
  for time, flux in zip(all_time, all_flux):
    if not len(time):
      continue  # An empty input segment yields no output segments.
    # Segment boundaries: position 0, every index whose predecessor is more
    # than gap_width away, and the end of the array.
    after_gap = np.nonzero(np.diff(time) > gap_width)[0] + 1
    boundaries = [0] + after_gap.tolist() + [len(time)]
    for start, end in zip(boundaries[:-1], boundaries[1:]):
      out_time.append(time[start:end])
      out_flux.append(flux[start:end])
  return out_time, out_flux
def remove_events(all_time,
                  all_flux,
                  events,
                  width_factor=1.0,
                  include_empty_segments=True):
  """Removes events from a light curve.

  This function accepts either a single-segment or piecewise-defined light
  curve (e.g. one that is split by quarter breaks or gaps in the data).

  Args:
    all_time: Numpy array or sequence of numpy arrays; each is a sequence of
      time values.
    all_flux: Numpy array or sequence of numpy arrays; each is a sequence of
      flux values of the corresponding time array.
    events: List of Event objects to remove.
    width_factor: Fractional multiplier of the duration of each event to remove.
    include_empty_segments: Whether to include empty segments in the output.

  Returns:
    output_time: Numpy array or list of numpy arrays; the time arrays with
        events removed.
    output_flux: Numpy array or list of numpy arrays; the flux arrays with
        events removed.
  """
  # Handle single-segment inputs.
  if isinstance(all_time, np.ndarray) and all_time.ndim == 1:
    all_time = [all_time]
    all_flux = [all_flux]
    single_segment = True
  else:
    single_segment = False

  output_time = []
  output_flux = []
  for time, flux in zip(all_time, all_flux):
    # np.bool was a deprecated alias of the builtin bool (removed in
    # NumPy 1.24); use the builtin directly.
    mask = np.ones_like(time, dtype=bool)
    for event in events:
      # Distance of each point from the nearest transit midpoint of the event.
      transit_dist = np.abs(phase_fold_time(time, event.period, event.t0))
      mask = np.logical_and(mask,
                            transit_dist > 0.5 * width_factor * event.duration)

    if single_segment:
      output_time = time[mask]
      output_flux = flux[mask]
    elif include_empty_segments or np.any(mask):
      output_time.append(time[mask])
      output_flux.append(flux[mask])

  return output_time, output_flux
def interpolate_missing_time(time, cadence_no=None, fill_value="extrapolate"):
  """Interpolates missing (NaN or Inf) time values.

  Args:
    time: A numpy array of monotonically increasing values, with missing values
      denoted by NaN or Inf.
    cadence_no: Optional numpy array of cadence numbers corresponding to the
      time values. If not provided, missing time values are assumed to be
      evenly spaced between present time values.
    fill_value: Specifies how missing time values should be treated at the
      beginning and end of the array. See scipy.interpolate.interp1d.

  Returns:
    A numpy array of the same length as the input time array, with NaN/Inf
    values replaced with interpolated values.

  Raises:
    ValueError: If fewer than 2 values of time are finite.
  """
  if cadence_no is None:
    # Without explicit cadence numbers, assume consecutive cadences.
    cadence_no = np.arange(len(time))

  finite_mask = np.isfinite(time)
  num_finite = np.sum(finite_mask)
  if num_finite < 2:
    raise ValueError(
        "Cannot interpolate time with fewer than 2 finite values. Got "
        "len(time) = {} with {} finite values.".format(len(time), num_finite))

  # Linearly interpolate (and possibly extrapolate) the finite time values,
  # using cadence number as the independent variable.
  interpolator = scipy.interpolate.interp1d(
      cadence_no[finite_mask],
      time[finite_mask],
      copy=False,
      bounds_error=False,
      fill_value=fill_value,
      assume_sorted=True)
  return interpolator(cadence_no)
def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
  """Linearly interpolates spline values across masked points.

  Args:
    all_time: List of numpy arrays; each is a sequence of time values.
    all_masked_time: List of numpy arrays; each is a sequence of time values
      with some values missing (masked).
    all_masked_spline: List of numpy arrays; the masked spline values
      corresponding to all_masked_time.

  Returns:
    interp_spline: List of numpy arrays; each is the masked spline with missing
      points linearly interpolated.
  """
  interp_spline = []
  segments = zip(all_time, all_masked_time, all_masked_spline)
  for time, masked_time, masked_spline in segments:
    if masked_time.size:
      # np.interp fills in the masked time points from neighboring values.
      interp_spline.append(np.interp(time, masked_time, masked_spline))
    else:
      # Entire segment was masked out; no values to interpolate from.
      interp_spline.append(np.full(len(time), np.nan))
  return interp_spline
def reshard_arrays(xs, ys):
  """Reshards arrays in xs to match the lengths of arrays in ys.

  Args:
    xs: List of 1d numpy arrays with the same total length as ys.
    ys: List of 1d numpy arrays with the same total length as xs.

  Returns:
    A list of numpy arrays containing the same elements as xs, in the same
    order, but with array lengths matching the pairwise array in ys.

  Raises:
    ValueError: If xs and ys do not have the same total length.
  """
  flat_x = np.concatenate(xs)
  # Cumulative segment boundaries of ys; the last entry is the total length.
  boundaries = np.cumsum([len(y) for y in ys])
  if len(flat_x) != boundaries[-1]:
    raise ValueError(
        "xs and ys do not have the same total length ({} vs. {}).".format(
            len(flat_x), boundaries[-1]))
  # np.split takes interior boundaries only, so drop the final (exclusive) one.
  return np.split(flat_x, boundaries[:-1])
def uniform_cadence_light_curve(cadence_no, time, flux):
  """Combines data into a single light curve with uniform cadence numbers.

  Args:
    cadence_no: numpy array; the cadence numbers of the light curve.
    time: numpy array; the time values of the light curve.
    flux: numpy array; the flux values of the light curve.

  Returns:
    cadence_no: numpy array; the cadence numbers of the light curve with no
      gaps. It starts and ends at the minimum and maximum cadence numbers in
      the input light curve, respectively.
    time: numpy array; the time values of the light curve. Missing data points
      have value zero and correspond to a False value in the mask.
    flux: numpy array; the flux values of the light curve. Missing data points
      have value zero and correspond to a False value in the mask.
    mask: Boolean numpy array; False indicates missing data points, where
      missing data points are those that have no corresponding cadence number
      in the input or those where at least one of the cadence number, time
      value, or flux value is NaN/Inf.

  Raises:
    ValueError: If there are duplicate cadence numbers in the input.
  """
  min_cadence_no = np.min(cadence_no)
  max_cadence_no = np.max(cadence_no)

  # Output arrays cover every cadence in [min_cadence_no, max_cadence_no].
  out_cadence_no = np.arange(
      min_cadence_no, max_cadence_no + 1, dtype=cadence_no.dtype)
  out_time = np.zeros_like(out_cadence_no, dtype=time.dtype)
  out_flux = np.zeros_like(out_cadence_no, dtype=flux.dtype)
  # np.bool was a deprecated alias of the builtin bool (removed in NumPy 1.24);
  # use the builtin directly.
  out_mask = np.zeros_like(out_cadence_no, dtype=bool)

  for c, t, f in zip(cadence_no, time, flux):
    # Only fully-finite data points are copied into the output; everything
    # else is left as (0, 0, False).
    if np.isfinite(c) and np.isfinite(t) and np.isfinite(f):
      i = int(c - min_cadence_no)
      if out_mask[i]:
        raise ValueError("Duplicate cadence number: {}".format(c))
      out_time[i] = t
      out_flux[i] = f
      out_mask[i] = True

  return out_cadence_no, out_time, out_flux, out_mask
def count_transit_points(time, event):
  """Computes the number of points in each transit of a given event.

  Args:
    time: Sorted numpy array of time values.
    event: An Event object.

  Returns:
    A numpy array containing the number of time points "in transit" for each
    transit occurring between the first and last time values.

  Raises:
    ValueError: If there are more than 10**6 transits.
  """
  t_min = np.min(time)
  t_max = np.max(time)

  # Tiny periods or erroneous time values could make this loop take forever.
  if (t_max - t_min) / event.period > 10**6:
    raise ValueError(
        "Too many transits! Time range is [{:.4f}, {:.4f}] and period is "
        "{:.4e}.".format(t_min, t_max, event.period))

  # Make sure t0 is in [t_min, t_min + period).
  t0 = np.mod(event.t0 - t_min, event.period) + t_min

  # Prepare loop variables. i and j are two monotonically advancing cursors
  # into `time`: for each transit, time[i:j] are exactly the in-transit points.
  points_in_transit = []
  i, j = 0, 0

  for transit_midpoint in np.arange(t0, t_max, event.period):
    transit_begin = transit_midpoint - event.duration / 2
    transit_end = transit_midpoint + event.duration / 2

    # Move time[i] to the first point >= transit_begin.
    while time[i] < transit_begin:
      # transit_begin is guaranteed to be < np.max(t) (provided duration >= 0).
      # Therefore, i cannot go out of range.
      i += 1

    # Move time[j] to the first point > transit_end.
    while time[j] <= transit_end:
      j += 1
      # j went out of range. We're finished.
      if j >= len(time):
        break

    # The points in the current transit duration are precisely time[i:j].
    # Since j is an exclusive index, there are exactly j-i points in transit.
    # NOTE(review): if j reaches len(time) here (a transit extending past the
    # final time value) and another transit midpoint remains in the range,
    # the next iteration would index time[j] out of bounds -- confirm callers
    # only use duration < period so this cannot occur.
    points_in_transit.append(j - i)

  return np.array(points_in_transit)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for util.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve_util import periodic_event
from light_curve_util import util
class LightCurveUtilTest(absltest.TestCase):
  """Tests for the light curve utility functions in util.py.

  Note: np.float / assertRaisesRegexp were replaced with the builtin float and
  assertRaisesRegex. np.float was an alias of the builtin float and was removed
  in NumPy 1.24; assertRaisesRegexp is a deprecated alias removed in
  Python 3.12.
  """

  def testPhaseFoldTime(self):
    time = np.arange(0, 2, 0.1)

    # Simple.
    tfold = util.phase_fold_time(time, period=1, t0=0.45)
    expected = [
        -0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45,
        -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45
    ]
    self.assertSequenceAlmostEqual(expected, tfold)

    # Large t0.
    tfold = util.phase_fold_time(time, period=1, t0=1.25)
    expected = [
        -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35, -0.25,
        -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35
    ]
    self.assertSequenceAlmostEqual(expected, tfold)

    # Negative t0.
    tfold = util.phase_fold_time(time, period=1, t0=-1.65)
    expected = [
        -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35,
        -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45
    ]
    self.assertSequenceAlmostEqual(expected, tfold)

    # Negative time.
    time = np.arange(-3, -1, 0.1)
    tfold = util.phase_fold_time(time, period=1, t0=0.55)
    expected = [
        0.45, -0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45,
        -0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35
    ]
    self.assertSequenceAlmostEqual(expected, tfold)

  def testSplit(self):
    # Single segment.
    all_time = np.concatenate([np.arange(0, 1, 0.1), np.arange(1.5, 2, 0.1)])
    all_flux = np.ones(15)

    # Gap width 0.5.
    split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
    self.assertLen(split_time, 2)
    self.assertLen(split_flux, 2)
    self.assertSequenceAlmostEqual(np.arange(0, 1, 0.1), split_time[0])
    self.assertSequenceAlmostEqual(np.ones(10), split_flux[0])
    self.assertSequenceAlmostEqual(np.arange(1.5, 2, 0.1), split_time[1])
    self.assertSequenceAlmostEqual(np.ones(5), split_flux[1])

    # Multi segment.
    all_time = [
        np.concatenate([
            np.arange(0, 1, 0.1),
            np.arange(1.5, 2, 0.1),
            np.arange(3, 4, 0.1)
        ]),
        np.arange(4, 5, 0.1)
    ]
    all_flux = [np.ones(25), np.ones(10)]
    self.assertEqual(len(all_time), 2)
    self.assertEqual(len(all_time[0]), 25)
    self.assertEqual(len(all_time[1]), 10)
    self.assertEqual(len(all_flux), 2)
    self.assertEqual(len(all_flux[0]), 25)
    self.assertEqual(len(all_flux[1]), 10)

    # Gap width 0.5.
    split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
    self.assertLen(split_time, 4)
    self.assertLen(split_flux, 4)
    self.assertSequenceAlmostEqual(np.arange(0, 1, 0.1), split_time[0])
    self.assertSequenceAlmostEqual(np.ones(10), split_flux[0])
    self.assertSequenceAlmostEqual(np.arange(1.5, 2, 0.1), split_time[1])
    self.assertSequenceAlmostEqual(np.ones(5), split_flux[1])
    self.assertSequenceAlmostEqual(np.arange(3, 4, 0.1), split_time[2])
    self.assertSequenceAlmostEqual(np.ones(10), split_flux[2])
    self.assertSequenceAlmostEqual(np.arange(4, 5, 0.1), split_time[3])
    self.assertSequenceAlmostEqual(np.ones(10), split_flux[3])

    # Gap width 1.0.
    split_time, split_flux = util.split(all_time, all_flux, gap_width=1)
    self.assertLen(split_time, 3)
    self.assertLen(split_flux, 3)
    self.assertSequenceAlmostEqual([
        0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.5, 1.6, 1.7, 1.8,
        1.9
    ], split_time[0])
    self.assertSequenceAlmostEqual(np.ones(15), split_flux[0])
    self.assertSequenceAlmostEqual(np.arange(3, 4, 0.1), split_time[1])
    self.assertSequenceAlmostEqual(np.ones(10), split_flux[1])
    self.assertSequenceAlmostEqual(np.arange(4, 5, 0.1), split_time[2])
    self.assertSequenceAlmostEqual(np.ones(10), split_flux[2])

  def testRemoveEvents(self):
    # np.float was an alias of the builtin float; use the builtin directly.
    time = np.arange(20, dtype=float)
    flux = 10 * time

    # One event.
    events = [periodic_event.Event(period=4, duration=1.5, t0=3.5)]
    output_time, output_flux = util.remove_events(time, flux, events)
    self.assertSequenceAlmostEqual([1, 2, 5, 6, 9, 10, 13, 14, 17, 18],
                                   output_time)
    self.assertSequenceAlmostEqual(
        [10, 20, 50, 60, 90, 100, 130, 140, 170, 180], output_flux)

    # Two events.
    events.append(periodic_event.Event(period=7, duration=1.5, t0=6.5))
    output_time, output_flux = util.remove_events(time, flux, events)
    self.assertSequenceAlmostEqual([1, 2, 5, 9, 10, 17, 18], output_time)
    self.assertSequenceAlmostEqual([10, 20, 50, 90, 100, 170, 180], output_flux)

    # Multi segment light curve.
    time = [np.arange(10, dtype=float), np.arange(10, 20, dtype=float)]
    flux = [10 * t for t in time]
    output_time, output_flux = util.remove_events(time, flux, events)
    self.assertLen(output_time, 2)
    self.assertLen(output_flux, 2)
    self.assertSequenceAlmostEqual([1, 2, 5, 9], output_time[0])
    self.assertSequenceAlmostEqual([10, 20, 50, 90], output_flux[0])
    self.assertSequenceAlmostEqual([10, 17, 18], output_time[1])
    self.assertSequenceAlmostEqual([100, 170, 180], output_flux[1])

    # One segment totally removed with include_empty_segments = True.
    time = [np.arange(5, dtype=float), np.arange(10, 20, dtype=float)]
    flux = [10 * t for t in time]
    events = [periodic_event.Event(period=10, duration=2, t0=2.5)]
    output_time, output_flux = util.remove_events(
        time, flux, events, width_factor=3, include_empty_segments=True)
    self.assertLen(output_time, 2)
    self.assertLen(output_flux, 2)
    self.assertSequenceEqual([], output_time[0])
    self.assertSequenceEqual([], output_flux[0])
    self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[1])
    self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[1])

    # One segment totally removed with include_empty_segments = False.
    time = [np.arange(5, dtype=float), np.arange(10, 20, dtype=float)]
    flux = [10 * t for t in time]
    events = [periodic_event.Event(period=10, duration=2, t0=2.5)]
    output_time, output_flux = util.remove_events(
        time, flux, events, width_factor=3, include_empty_segments=False)
    self.assertLen(output_time, 1)
    self.assertLen(output_flux, 1)
    self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[0])
    self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[0])

  def testInterpolateMissingTime(self):
    # Fewer than 2 finite values.
    with self.assertRaises(ValueError):
      util.interpolate_missing_time(np.array([]))
    with self.assertRaises(ValueError):
      util.interpolate_missing_time(np.array([5.0]))
    with self.assertRaises(ValueError):
      util.interpolate_missing_time(np.array([5.0, np.nan]))
    with self.assertRaises(ValueError):
      util.interpolate_missing_time(np.array([np.nan, np.nan, np.nan]))

    # Small time arrays.
    self.assertSequenceAlmostEqual([0.5, 0.6],
                                   util.interpolate_missing_time(
                                       np.array([0.5, 0.6])))
    self.assertSequenceAlmostEqual([0.5, 0.6, 0.7],
                                   util.interpolate_missing_time(
                                       np.array([0.5, np.nan, 0.7])))

    # Time array of length 20 with some values NaN.
    time = np.array([
        np.nan, 0.5, 1.0, 1.5, 2.0, 2.5, np.nan, 3.5, 4.0, 4.5, 5.0, np.nan,
        np.nan, np.nan, np.nan, 7.5, 8.0, 8.5, np.nan, np.nan
    ])
    interp_time = util.interpolate_missing_time(time)
    self.assertSequenceAlmostEqual([
        0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5,
        7.0, 7.5, 8.0, 8.5, 9.0, 9.5
    ], interp_time)

    # Fill with 0.0 for missing values at the beginning and end.
    interp_time = util.interpolate_missing_time(time, fill_value=0.0)
    self.assertSequenceAlmostEqual([
        0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5,
        7.0, 7.5, 8.0, 8.5, 0.0, 0.0
    ], interp_time)

    # Interpolate with cadences.
    cadences = np.array([
        100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113,
        114, 115, 116, 117, 118, 119
    ])
    interp_time = util.interpolate_missing_time(time, cadences)
    self.assertSequenceAlmostEqual([
        0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5,
        7.0, 7.5, 8.0, 8.5, 9.0, 9.5
    ], interp_time)

    # Interpolate with missing cadences.
    time = np.array([0.6, 0.7, np.nan, np.nan, np.nan, 1.3, 1.4, 1.5])
    cadences = np.array([106, 107, 108, 109, 110, 113, 114, 115])
    interp_time = util.interpolate_missing_time(time, cadences)
    self.assertSequenceAlmostEqual([0.6, 0.7, 0.8, 0.9, 1.0, 1.3, 1.4, 1.5],
                                   interp_time)

  def testInterpolateMaskedSpline(self):
    all_time = [
        np.arange(0, 10, dtype=float),
        np.arange(10, 20, dtype=float),
        np.arange(20, 30, dtype=float),
    ]
    all_masked_time = [
        np.array([0, 1, 2, 3, 8, 9], dtype=float),  # No 4, 5, 6, 7
        np.array([10, 11, 12, 13, 14, 15, 16], dtype=float),  # No 17, 18, 19
        np.array([], dtype=float)
    ]
    all_masked_spline = [2 * t + 100 for t in all_masked_time]
    interp_spline = util.interpolate_masked_spline(all_time, all_masked_time,
                                                   all_masked_spline)
    self.assertLen(interp_spline, 3)
    self.assertSequenceAlmostEqual(
        [100, 102, 104, 106, 108, 110, 112, 114, 116, 118], interp_spline[0])
    self.assertSequenceAlmostEqual(
        [120, 122, 124, 126, 128, 130, 132, 132, 132, 132], interp_spline[1])
    self.assertTrue(np.all(np.isnan(interp_spline[2])))

  def testReshardArrays(self):
    xs = [
        np.array([1, 2, 3]),
        np.array([4]),
        np.array([5, 6, 7, 8, 9]),
        np.array([]),
    ]
    ys = [
        np.array([]),
        np.array([10, 20]),
        np.array([30, 40, 50, 60]),
        np.array([70]),
        np.array([80, 90]),
    ]
    reshard_xs = util.reshard_arrays(xs, ys)
    self.assertEqual(5, len(reshard_xs))
    np.testing.assert_array_equal([], reshard_xs[0])
    np.testing.assert_array_equal([1, 2], reshard_xs[1])
    np.testing.assert_array_equal([3, 4, 5, 6], reshard_xs[2])
    np.testing.assert_array_equal([7], reshard_xs[3])
    np.testing.assert_array_equal([8, 9], reshard_xs[4])

    # assertRaisesRegexp is deprecated (removed in Python 3.12); use the
    # assertRaisesRegex spelling.
    with self.assertRaisesRegex(ValueError,
                                "xs and ys do not have the same total length"):
      util.reshard_arrays(xs, [np.array([10, 20, 30]), np.array([40, 50])])

  def testUniformCadenceLightCurve(self):
    input_cadence_no = np.array([13, 4, 5, 6, 8, 9, 11, 12])
    input_time = np.array([130, 40, 50, 60, 80, 90, 110, 120])
    input_flux = np.array([1300, 400, 500, 600, 800, np.nan, 1100, 1200])
    cadence_no, time, flux, mask = util.uniform_cadence_light_curve(
        input_cadence_no, input_time, input_flux)
    np.testing.assert_array_equal([4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
                                  cadence_no)
    np.testing.assert_array_equal([40, 50, 60, 0, 80, 0, 0, 110, 120, 130],
                                  time)
    np.testing.assert_array_equal(
        [400, 500, 600, 0, 800, 0, 0, 1100, 1200, 1300], flux)
    np.testing.assert_array_equal([1, 1, 1, 0, 1, 0, 0, 1, 1, 1], mask)

    # Add duplicate cadence number.
    input_cadence_no = np.concatenate([input_cadence_no, np.array([13, 14])])
    input_time = np.concatenate([input_time, np.array([130, 140])])
    input_flux = np.concatenate([input_flux, np.array([1300, 1400])])
    with self.assertRaisesRegex(ValueError, "Duplicate cadence number"):
      util.uniform_cadence_light_curve(input_cadence_no, input_time, input_flux)

  def testCountTransitPoints(self):
    time = np.concatenate([
        np.arange(0, 10, 0.1, dtype=float),
        np.arange(15, 30, 0.1, dtype=float),
        np.arange(50, 100, 0.1, dtype=float)
    ])
    event = periodic_event.Event(period=10, duration=5, t0=9.95)

    points_in_transit = util.count_transit_points(time, event)
    np.testing.assert_array_equal([25, 50, 25, 0, 25, 50, 50, 50, 50],
                                  points_in_transit)
# Script entry point: run the test suite via absl's test runner.
if __name__ == "__main__":
  absltest.main()
# Bazel BUILD file for the third-party Kepler spline fitting library.

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # MIT

exports_files(["LICENSE"])

# Iterative B-spline fitting for normalizing Kepler light curves.
py_library(
    name = "kepler_spline",
    srcs = ["kepler_spline.py"],
    srcs_version = "PY2AND3",
    deps = ["//third_party/robust_mean"],
)

# Unit tests for kepler_spline.py.
py_test(
    name = "kepler_spline_test",
    size = "small",
    srcs = ["kepler_spline_test.py"],
    srcs_version = "PY2AND3",
    deps = [":kepler_spline"],
)
MIT License
Copyright (c) 2017 avanderburg
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
\ No newline at end of file
"""Functions for computing normalization splines for Kepler light curves."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
import numpy as np
from pydl.pydlutils import bspline
from third_party.robust_mean import robust_mean
class InsufficientPointsError(Exception):
  """Raised when too few points are available to fit a spline."""
class SplineError(Exception):
  """Raised when the underlying spline-fitting implementation fails."""
def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
  """Computes a best-fit spline curve for a light curve segment.

  The spline is fit using an iterative process to remove outliers that may
  cause the spline to be "pulled" by discrepant points. In each iteration the
  spline is fit, and if there are any points where the absolute deviation from
  the median residual is at least 3*sigma (where sigma is a robust estimate of
  the standard deviation of the residuals), those points are removed and the
  spline is re-fit.

  Args:
    time: Numpy array; the time values of the light curve.
    flux: Numpy array; the flux (brightness) values of the light curve.
    bkspace: Spline break point spacing in time units.
    maxiter: Maximum number of attempts to fit the spline after removing badly
      fit points.
    outlier_cut: The maximum number of standard deviations from the median
      spline residual before a point is considered an outlier.

  Returns:
    spline: The values of the fitted spline corresponding to the input time
        values.
    mask: Boolean mask indicating the points used to fit the final spline.

  Raises:
    InsufficientPointsError: If there were insufficient points (after removing
        outliers) for spline fitting.
    SplineError: If the spline could not be fit, for example if the breakpoint
        spacing is too small.
  """
  if len(time) < 4:
    raise InsufficientPointsError(
        "Cannot fit a spline on less than 4 points. Got {} points.".format(
            len(time)))

  # Rescale time into [0, 1].
  t_min = np.min(time)
  t_max = np.max(time)
  time = (time - t_min) / (t_max - t_min)
  bkspace /= (t_max - t_min)  # Rescale bucket spacing.

  # Values of the best fitting spline evaluated at the time points.
  spline = None

  # Mask indicating the points used to fit the spline.
  mask = None

  for _ in range(maxiter):
    if spline is None:
      # np.bool was a deprecated alias of the builtin bool (removed in
      # NumPy 1.24); use the builtin directly.
      mask = np.ones_like(time, dtype=bool)  # Try to fit all points.
    else:
      # Choose points where the absolute deviation from the median residual is
      # less than outlier_cut*sigma, where sigma is a robust estimate of the
      # standard deviation of the residuals from the previous spline.
      residuals = flux - spline
      new_mask = robust_mean.robust_mean(residuals, cut=outlier_cut)[2]

      if np.all(new_mask == mask):
        break  # Spline converged.

      mask = new_mask
      if np.sum(mask) < 4:
        # Fewer than 4 points after removing outliers. We could plausibly
        # return the spline from the previous iteration because it was fit with
        # at least 4 points. However, since the outliers were such a
        # significant fraction of the curve, the spline from the previous
        # iteration is probably junk, and we consider this a fatal error.
        raise InsufficientPointsError(
            "Cannot fit a spline on less than 4 points. After removing "
            "outliers, got {} points.".format(np.sum(mask)))

    try:
      with warnings.catch_warnings():
        # Suppress warning messages printed by pydlutils.bspline. Instead we
        # catch any exception and raise a more informative error.
        warnings.simplefilter("ignore")

        # Fit the spline on non-outlier points.
        curve = bspline.iterfit(time[mask], flux[mask], bkspace=bkspace)[0]

        # Evaluate spline at the time points.
        spline = curve.value(time)[0]
    except (IndexError, TypeError) as e:
      raise SplineError(
          "Fitting spline failed with error: '{}'. This might be caused by the "
          "breakpoint spacing being too small, and/or there being insufficient "
          "points to fit the spline in one of the intervals.".format(e))

  return spline, mask
class SplineMetadata(object):
  """Metadata about a spline fit.

  Attributes:
    light_curve_mask: List of boolean numpy arrays indicating which points in
      the light curve were used to fit the best-fit spline.
    bkspace: The break-point spacing used for the best-fit spline.
    bad_bkspaces: List of break-point spacing values that failed.
    likelihood_term: The likelihood term of the Bayesian Information Criterion;
      -2*ln(L), where L is the likelihood of the data given the model.
    penalty_term: The penalty term for the number of parameters in the Bayesian
      Information Criterion.
    bic: The value of the Bayesian Information Criterion; equal to
      likelihood_term + penalty_coeff * penalty_term.
  """

  def __init__(self):
    # All scalar fields start unset; they are filled in by the fitting routine.
    for attr in ("light_curve_mask", "bkspace", "likelihood_term",
                 "penalty_term", "bic"):
      setattr(self, attr, None)
    # Fresh list per instance so failed spacings are never shared.
    self.bad_bkspaces = []
def choose_kepler_spline(all_time,
                         all_flux,
                         bkspaces,
                         maxiter=5,
                         penalty_coeff=1.0,
                         verbose=True):
  """Computes the best-fit Kepler spline across a range of break-point spacings.

  Some Kepler light curves have low-frequency variability, while others have
  very high-frequency variability (e.g. due to rapid rotation). Therefore, it
  is suboptimal to use the same break-point spacing for every star. This
  function computes the best-fit spline by fitting splines with different
  break-point spacings, calculating the Bayesian Information Criterion (BIC)
  for each spline, and choosing the break-point spacing that minimizes the BIC.

  This function assumes a piecewise light curve, that is, a light curve that is
  divided into different segments (e.g. split by quarter breaks or gaps in the
  data). A separate spline is fit for each segment.

  Args:
    all_time: List of 1D numpy arrays; the time values of the light curve.
    all_flux: List of 1D numpy arrays; the flux values of the light curve.
    bkspaces: List of break-point spacings to try.
    maxiter: Maximum number of attempts to fit each spline after removing badly
      fit points.
    penalty_coeff: Coefficient of the penalty term for using more parameters in
      the Bayesian Information Criterion. Decreasing this value will allow more
      parameters to be used (i.e. smaller break-point spacing), and vice-versa.
    verbose: Whether to log individual spline errors. Note that if bkspaces
      contains many values (particularly small ones) then this may cause
      logging pollution if calling this function for many light curves.

  Returns:
    spline: List of numpy arrays; values of the best-fit spline corresponding
        to the input flux arrays.
    metadata: Object containing metadata about the spline fit.
  """
  # Initialize outputs.
  best_spline = None
  metadata = SplineMetadata()

  # Compute the assumed standard deviation of Gaussian white noise about the
  # spline model. We assume that each flux value f[i] is a Gaussian random
  # variable f[i] ~ N(s[i], sigma^2), where s is the value of the true spline
  # model and sigma is the constant standard deviation for all flux values.
  # Moreover, we assume that s[i] ~= s[i+1]. Therefore,
  # (f[i+1] - f[i]) / sqrt(2) ~ N(0, sigma^2).
  scaled_diffs = [np.diff(f) / np.sqrt(2) for f in all_flux]
  scaled_diffs = np.concatenate(scaled_diffs) if scaled_diffs else np.array([])
  if not scaled_diffs.size:
    # No data at all; return all-NaN splines and all-False masks.
    best_spline = [np.array([np.nan] * len(f)) for f in all_flux]
    # np.bool was a deprecated alias of the builtin bool (removed in
    # NumPy 1.24); use the builtin directly.
    metadata.light_curve_mask = [
        np.zeros_like(f, dtype=bool) for f in all_flux
    ]
    return best_spline, metadata

  # Compute the median absolute deviation as a robust estimate of sigma. The
  # conversion factor of 1.48 takes the median absolute deviation to the
  # standard deviation of a normal distribution. See, e.g.
  # https://www.mathworks.com/help/stats/mad.html.
  sigma = np.median(np.abs(scaled_diffs)) * 1.48

  for bkspace in bkspaces:
    nparams = 0  # Total number of free parameters in the piecewise spline.
    npoints = 0  # Total number of data points used to fit the piecewise spline.
    ssr = 0  # Sum of squared residuals between the model and the spline.

    spline = []
    light_curve_mask = []
    bad_bkspace = False  # Indicates that the current bkspace should be skipped.
    for time, flux in zip(all_time, all_flux):
      # Fit B-spline to this light-curve segment.
      try:
        spline_piece, mask = kepler_spline(
            time, flux, bkspace=bkspace, maxiter=maxiter)
      except InsufficientPointsError as e:
        # It's expected to occasionally see intervals with insufficient points,
        # especially if periodic signals have been removed from the light
        # curve. Skip this interval, but continue fitting the spline.
        if verbose:
          warnings.warn(str(e))
        spline.append(np.array([np.nan] * len(flux)))
        light_curve_mask.append(np.zeros_like(flux, dtype=bool))
        continue
      except SplineError as e:
        # It's expected to get a SplineError occasionally for small values of
        # bkspace. Skip this bkspace.
        if verbose:
          warnings.warn("Bad bkspace {}: {}".format(bkspace, e))
        metadata.bad_bkspaces.append(bkspace)
        bad_bkspace = True
        break

      spline.append(spline_piece)
      light_curve_mask.append(mask)

      # Accumulate the number of free parameters.
      total_time = np.max(time) - np.min(time)
      nknots = int(total_time / bkspace) + 1  # From the bspline implementation.
      nparams += nknots + 3 - 1  # number of knots + degree of spline - 1

      # Accumulate the number of points and the squared residuals.
      npoints += np.sum(mask)
      ssr += np.sum((flux[mask] - spline_piece[mask])**2)

    if bad_bkspace or not npoints:
      continue

    # The following term is -2*ln(L), where L is the likelihood of the data
    # given the model, under the assumption that the model errors are iid
    # Gaussian with mean 0 and standard deviation sigma.
    likelihood_term = npoints * np.log(2 * np.pi * sigma**2) + ssr / sigma**2

    # Penalty term for the number of parameters used to fit the model.
    penalty_term = nparams * np.log(npoints)

    # Bayesian information criterion.
    bic = likelihood_term + penalty_coeff * penalty_term

    if best_spline is None or bic < metadata.bic:
      best_spline = spline
      metadata.light_curve_mask = light_curve_mask
      metadata.bkspace = bkspace
      metadata.likelihood_term = likelihood_term
      metadata.penalty_term = penalty_term
      metadata.bic = bic

  if best_spline is None:
    # All bkspaces resulted in a SplineError, or all light curve intervals had
    # insufficient points.
    best_spline = [np.array([np.nan] * len(f)) for f in all_flux]
    metadata.light_curve_mask = [
        np.zeros_like(f, dtype=bool) for f in all_flux
    ]

  return best_spline, metadata
def fit_kepler_spline(all_time,
                      all_flux,
                      bkspace_min=0.5,
                      bkspace_max=20,
                      bkspace_num=20,
                      maxiter=5,
                      penalty_coeff=1.0,
                      verbose=True):
  """Fits a Kepler spline with logarithmically-sampled breakpoint spacings.

  Args:
    all_time: List of 1D numpy arrays; the time values of the light curve.
    all_flux: List of 1D numpy arrays; the flux values of the light curve.
    bkspace_min: Minimum breakpoint spacing to try.
    bkspace_max: Maximum breakpoint spacing to try.
    bkspace_num: Number of breakpoint spacings to try.
    maxiter: Maximum number of attempts to fit each spline after removing badly
      fit points.
    penalty_coeff: Coefficient of the penalty term for using more parameters in
      the Bayesian Information Criterion. Decreasing this value will allow more
      parameters to be used (i.e. smaller break-point spacing), and vice-versa.
    verbose: Whether to log individual spline errors. Note that if bkspaces
      contains many values (particularly small ones) then this may cause
      logging pollution if calling this function for many light curves.

  Returns:
    spline: List of numpy arrays; values of the best-fit spline corresponding
      to the input flux arrays.
    metadata: Object containing metadata about the spline fit.
  """
  # Candidate break-point spacings: bkspace_num values logarithmically spaced
  # over [bkspace_min, bkspace_max].
  log_min = np.log10(bkspace_min)
  log_max = np.log10(bkspace_max)
  bkspaces = np.logspace(log_min, log_max, num=bkspace_num)

  return choose_kepler_spline(
      all_time,
      all_flux,
      bkspaces,
      maxiter=maxiter,
      penalty_coeff=penalty_coeff,
      verbose=verbose)
"""Tests for kepler_spline.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from third_party.kepler_spline import kepler_spline
class KeplerSplineTest(absltest.TestCase):
  """Tests for the single-segment kepler_spline fitting function."""

  def _rmse(self, flux, spline, mask):
    """Root-mean-square error of the spline over the unmasked points."""
    return np.sqrt(np.mean((flux[mask] - spline[mask])**2))

  def testFitSine(self):
    """A sine wave is fit closely, with outliers detected and masked."""
    time = np.arange(0, 10, 0.1)
    flux = np.sin(time)

    # With no outliers, expect a very close fit that keeps every point.
    spline, mask = kepler_spline.kepler_spline(time, flux, bkspace=0.5)
    self.assertLess(self._rmse(flux, spline, mask), 1e-4)
    self.assertTrue(np.all(mask))

    # Inject three outliers.
    outliers = ((35, 10), (77, -3), (95, 2.9))
    for index, value in outliers:
      flux[index] = value

    # Expect a close fit with exactly the outliers removed.
    spline, mask = kepler_spline.kepler_spline(time, flux, bkspace=0.5)
    self.assertLess(self._rmse(flux, spline, mask), 1e-4)
    self.assertEqual(np.sum(mask), 97)
    for index, _ in outliers:
      self.assertFalse(mask[index])

    # A larger breakpoint spacing gives a slightly looser (but still good) fit.
    spline, mask = kepler_spline.kepler_spline(time, flux, bkspace=1)
    self.assertLess(self._rmse(flux, spline, mask), 2e-3)
    self.assertEqual(np.sum(mask), 97)
    for index, _ in outliers:
      self.assertFalse(mask[index])

  def testFitCubic(self):
    """A cubic spline reproduces a cubic polynomial essentially exactly."""
    time = np.arange(0, 10, 0.1)
    flux = (time - 5)**3 + 2 * (time - 5)**2 + 10

    # We choose maxiter=1 because a cubic spline fits a cubic polynomial
    # ~exactly, so the residual scatter is ~0 and further outlier-rejection
    # iterations would discard well-fit points.
    spline, mask = kepler_spline.kepler_spline(
        time, flux, bkspace=0.5, maxiter=1)
    self.assertLess(self._rmse(flux, spline, mask), 1e-12)
    self.assertTrue(np.all(mask))

  def testInsufficientPointsError(self):
    """Light curves with too few points raise InsufficientPointsError."""
    # An empty light curve and a 3-point light curve both fail.
    for time in (np.array([]), np.array([0.1, 0.2, 0.3])):
      flux = np.sin(time)
      with self.assertRaises(kepler_spline.InsufficientPointsError):
        kepler_spline.kepler_spline(time, flux, bkspace=0.5)
class ChooseKeplerSplineTest(absltest.TestCase):
  """Tests for choose_kepler_spline, which picks a breakpoint spacing by BIC.

  NOTE(review): the exact likelihood/penalty/BIC constants below look like
  regression values pinned from a previous run — confirm against the
  implementation before changing any of them.
  """

  def testEmptyInput(self):
    """An empty light curve yields an empty spline and empty mask."""
    # Logarithmically sample candidate break point spacings.
    bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20)

    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time=[],
        all_flux=[],
        bkspaces=bkspaces,
        penalty_coeff=1.0,
        verbose=False)

    np.testing.assert_array_equal(spline, [])
    np.testing.assert_array_equal(metadata.light_curve_mask, [])

  def testNoPoints(self):
    """A single empty segment yields a single empty spline segment and mask."""
    all_time = [np.array([])]
    all_flux = [np.array([])]

    # Logarithmically sample candidate break point spacings.
    bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20)

    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time, all_flux, bkspaces, penalty_coeff=1.0, verbose=False)

    np.testing.assert_array_equal(spline, [[]])
    np.testing.assert_array_equal(metadata.light_curve_mask, [[]])

  def testTooFewPoints(self):
    """Segments too short to fit come back all-NaN with all-False masks."""
    # Sine wave with segments of 1, 2, 3 points.
    all_time = [
        np.array([0.1]),
        np.array([0.2, 0.3]),
        np.array([0.4, 0.5, 0.6])
    ]
    all_flux = [np.sin(t) for t in all_time]

    # Logarithmically sample candidate break point spacings.
    bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20)

    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time, all_flux, bkspaces, penalty_coeff=1.0, verbose=False)

    # All segments are NaN: no segment was long enough to fit, so no spacing
    # was chosen and all fit metadata stays unset.
    self.assertTrue(np.all(np.isnan(np.concatenate(spline))))
    self.assertFalse(np.any(np.concatenate(metadata.light_curve_mask)))
    self.assertIsNone(metadata.bkspace)
    self.assertEmpty(metadata.bad_bkspaces)
    self.assertIsNone(metadata.likelihood_term)
    self.assertIsNone(metadata.penalty_term)
    self.assertIsNone(metadata.bic)

    # Add a longer segment. This mutates all_time/all_flux in place; the
    # refit below depends on this ordering.
    all_time.append(np.arange(0.7, 2.0, 0.1))
    all_flux.append(np.sin(all_time[-1]))

    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time, all_flux, bkspaces, penalty_coeff=1.0, verbose=False)

    # First 3 segments are NaN (still too short to fit).
    for i in range(3):
      self.assertTrue(np.all(np.isnan(spline[i])))
      self.assertFalse(np.any(metadata.light_curve_mask[i]))

    # Final segment is a good fit, so fit metadata is now populated.
    self.assertTrue(np.all(np.isfinite(spline[3])))
    self.assertTrue(np.all(metadata.light_curve_mask[3]))
    self.assertEmpty(metadata.bad_bkspaces)
    self.assertAlmostEqual(metadata.likelihood_term, -58.0794069927957)
    self.assertAlmostEqual(metadata.penalty_term, 7.69484807238461)
    self.assertAlmostEqual(metadata.bic, -50.3845589204111)

  def testFitSine(self):
    """penalty_coeff trades off fit closeness against breakpoint spacing."""
    # High frequency sine wave.
    all_time = [np.arange(0, 100, 0.1), np.arange(100, 200, 0.1)]
    all_flux = [np.sin(t) for t in all_time]

    # Logarithmically sample candidate break point spacings.
    bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20)

    def _rmse(all_flux, all_spline):
      # Root-mean-square error over all segments concatenated.
      f = np.concatenate(all_flux)
      s = np.concatenate(all_spline)
      return np.sqrt(np.mean((f - s)**2))

    # Penalty coefficient 1.0: baseline fit.
    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time, all_flux, bkspaces, penalty_coeff=1.0)
    self.assertAlmostEqual(_rmse(all_flux, spline), 0.013013)
    self.assertTrue(np.all(metadata.light_curve_mask))
    self.assertAlmostEqual(metadata.bkspace, 1.67990914314)
    self.assertEmpty(metadata.bad_bkspaces)
    self.assertAlmostEqual(metadata.likelihood_term, -6685.64217856480)
    self.assertAlmostEqual(metadata.penalty_term, 942.51190498322)
    self.assertAlmostEqual(metadata.bic, -5743.13027358158)

    # Decrease penalty coefficient; allow smaller spacing for closer fit.
    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time, all_flux, bkspaces, penalty_coeff=0.1)
    self.assertAlmostEqual(_rmse(all_flux, spline), 0.0066376)
    self.assertTrue(np.all(metadata.light_curve_mask))
    self.assertAlmostEqual(metadata.bkspace, 1.48817572082)
    self.assertEmpty(metadata.bad_bkspaces)
    self.assertAlmostEqual(metadata.likelihood_term, -6731.59913975551)
    self.assertAlmostEqual(metadata.penalty_term, 1064.12634433589)
    self.assertAlmostEqual(metadata.bic, -6625.18650532192)

    # Increase penalty coefficient; require larger spacing at the cost of worse
    # fit.
    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time, all_flux, bkspaces, penalty_coeff=2)
    self.assertAlmostEqual(_rmse(all_flux, spline), 0.026215449)
    self.assertTrue(np.all(metadata.light_curve_mask))
    self.assertAlmostEqual(metadata.bkspace, 1.89634509537)
    self.assertEmpty(metadata.bad_bkspaces)
    self.assertAlmostEqual(metadata.likelihood_term, -6495.65564287904)
    self.assertAlmostEqual(metadata.penalty_term, 836.099270549629)
    self.assertAlmostEqual(metadata.bic, -4823.45710177978)
# Run all test cases in this module when executed as a script.
if __name__ == "__main__":
  absltest.main()
# Bazel build rules for the robust_mean package.

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # BSD

exports_files(["LICENSE"])

# Library providing the robust mean estimator (robust_mean.py).
py_library(
    name = "robust_mean",
    srcs = ["robust_mean.py"],
    srcs_version = "PY2AND3",
)

# Unit test for :robust_mean.
# NOTE(review): test_data/random_normal.py is listed in srcs, so it is
# importable by the test — presumably fixture data; verify against the test.
py_test(
    name = "robust_mean_test",
    size = "small",
    srcs = [
        "robust_mean_test.py",
        "test_data/random_normal.py",
    ],
    srcs_version = "PY2AND3",
    deps = [":robust_mean"],
)
Copyright (c) 2014, Wayne Landsman
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment