Unverified Commit d1b91ba9 authored by Chris Shallue's avatar Chris Shallue Committed by GitHub
Browse files

Merge pull request #4554 from cshallue/master

Merge internal changes
parents 1abcee90 8ae37506
...@@ -18,8 +18,6 @@ from __future__ import absolute_import ...@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import numpy as np import numpy as np
from six.moves import range # pylint:disable=redefined-builtin from six.moves import range # pylint:disable=redefined-builtin
...@@ -51,10 +49,10 @@ def split(all_time, all_flux, gap_width=0.75): ...@@ -51,10 +49,10 @@ def split(all_time, all_flux, gap_width=0.75):
piecewise defined (e.g. split by quarter breaks or gaps in the data). piecewise defined (e.g. split by quarter breaks or gaps in the data).
Args: Args:
all_time: Numpy array or list of numpy arrays; each is a sequence of time all_time: Numpy array or sequence of numpy arrays; each is a sequence of
values. time values.
all_flux: Numpy array or list of numpy arrays; each is a sequence of flux all_flux: Numpy array or sequence of numpy arrays; each is a sequence of
values of the corresponding time array. flux values of the corresponding time array.
gap_width: Minimum gap size (in time units) for a split. gap_width: Minimum gap size (in time units) for a split.
Returns: Returns:
...@@ -62,10 +60,7 @@ def split(all_time, all_flux, gap_width=0.75): ...@@ -62,10 +60,7 @@ def split(all_time, all_flux, gap_width=0.75):
out_flux: List of numpy arrays; the split flux arrays. out_flux: List of numpy arrays; the split flux arrays.
""" """
# Handle single-segment inputs. # Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion if isinstance(all_time, np.ndarray) and all_time.ndim == 1:
# to bool fails if all_time is a numpy array, and all_time.size is not defined
# if all_time is a list of numpy arrays.
if len(all_time) > 0 and not isinstance(all_time[0], collections.Iterable): # pylint:disable=g-explicit-length-test
all_time = [all_time] all_time = [all_time]
all_flux = [all_flux] all_flux = [all_flux]
...@@ -90,10 +85,10 @@ def remove_events(all_time, all_flux, events, width_factor=1.0): ...@@ -90,10 +85,10 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
curve (e.g. one that is split by quarter breaks or gaps in the data). curve (e.g. one that is split by quarter breaks or gaps in the data).
Args: Args:
all_time: Numpy array or list of numpy arrays; each is a sequence of time all_time: Numpy array or sequence of numpy arrays; each is a sequence of
values. time values.
all_flux: Numpy array or list of numpy arrays; each is a sequence of flux all_flux: Numpy array or sequence of numpy arrays; each is a sequence of
values of the corresponding time array. flux values of the corresponding time array.
events: List of Event objects to remove. events: List of Event objects to remove.
width_factor: Fractional multiplier of the duration of each event to remove. width_factor: Fractional multiplier of the duration of each event to remove.
...@@ -104,10 +99,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0): ...@@ -104,10 +99,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
events removed. events removed.
""" """
# Handle single-segment inputs. # Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion if isinstance(all_time, np.ndarray) and all_time.ndim == 1:
# to bool fails if all_time is a numpy array and all_time.size is not defined
# if all_time is a list of numpy arrays.
if len(all_time) > 0 and not isinstance(all_time[0], collections.Iterable): # pylint:disable=g-explicit-length-test
all_time = [all_time] all_time = [all_time]
all_flux = [all_flux] all_flux = [all_flux]
single_segment = True single_segment = True
...@@ -150,10 +142,10 @@ def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline): ...@@ -150,10 +142,10 @@ def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
interp_spline = [] interp_spline = []
for time, masked_time, masked_spline in zip( for time, masked_time, masked_spline in zip(
all_time, all_masked_time, all_masked_spline): all_time, all_masked_time, all_masked_spline):
if len(masked_time) > 0: # pylint:disable=g-explicit-length-test if masked_time.size:
interp_spline.append(np.interp(time, masked_time, masked_spline)) interp_spline.append(np.interp(time, masked_time, masked_spline))
else: else:
interp_spline.append(np.full_like(time, np.nan)) interp_spline.append(np.array([np.nan] * len(time)))
return interp_spline return interp_spline
......
...@@ -65,43 +65,63 @@ class LightCurveUtilTest(absltest.TestCase): ...@@ -65,43 +65,63 @@ class LightCurveUtilTest(absltest.TestCase):
self.assertSequenceAlmostEqual(expected, tfold) self.assertSequenceAlmostEqual(expected, tfold)
def testSplit(self): def testSplit(self):
# Single segment.
all_time = np.concatenate([np.arange(0, 1, 0.1), np.arange(1.5, 2, 0.1)])
all_flux = np.ones(15)
# Gap width 0.5.
split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
self.assertLen(split_time, 2)
self.assertLen(split_flux, 2)
self.assertSequenceAlmostEqual(np.arange(0, 1, 0.1), split_time[0])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[0])
self.assertSequenceAlmostEqual(np.arange(1.5, 2, 0.1), split_time[1])
self.assertSequenceAlmostEqual(np.ones(5), split_flux[1])
# Multi segment.
all_time = [ all_time = [
np.concatenate([ np.concatenate([
np.arange(0, 1, 0.1), np.arange(0, 1, 0.1),
np.arange(1.5, 2, 0.1), np.arange(1.5, 2, 0.1),
np.arange(3, 4, 0.1) np.arange(3, 4, 0.1)
]) ]),
np.arange(4, 5, 0.1)
] ]
all_flux = [np.array([1] * 25)] all_flux = [np.ones(25), np.ones(10)]
self.assertEqual(len(all_time), 2)
self.assertEqual(len(all_time[0]), 25)
self.assertEqual(len(all_time[1]), 10)
self.assertEqual(len(all_flux), 2)
self.assertEqual(len(all_flux[0]), 25)
self.assertEqual(len(all_flux[1]), 10)
# Gap width 0.5. # Gap width 0.5.
split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5) split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
self.assertLen(split_time, 3) self.assertLen(split_time, 4)
self.assertLen(split_flux, 3) self.assertLen(split_flux, 4)
self.assertSequenceAlmostEqual( self.assertSequenceAlmostEqual(np.arange(0, 1, 0.1), split_time[0])
[0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], split_time[0]) self.assertSequenceAlmostEqual(np.ones(10), split_flux[0])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], self.assertSequenceAlmostEqual(np.arange(1.5, 2, 0.1), split_time[1])
split_flux[0]) self.assertSequenceAlmostEqual(np.ones(5), split_flux[1])
self.assertSequenceAlmostEqual([1.5, 1.6, 1.7, 1.8, 1.9], split_time[1]) self.assertSequenceAlmostEqual(np.arange(3, 4, 0.1), split_time[2])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1], split_flux[1]) self.assertSequenceAlmostEqual(np.ones(10), split_flux[2])
self.assertSequenceAlmostEqual( self.assertSequenceAlmostEqual(np.arange(4, 5, 0.1), split_time[3])
[3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9], split_time[2]) self.assertSequenceAlmostEqual(np.ones(10), split_flux[3])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
split_flux[2])
# Gap width 1.0. # Gap width 1.0.
split_time, split_flux = util.split(all_time, all_flux, gap_width=1) split_time, split_flux = util.split(all_time, all_flux, gap_width=1)
self.assertLen(split_time, 2) self.assertLen(split_time, 3)
self.assertLen(split_flux, 2) self.assertLen(split_flux, 3)
self.assertSequenceAlmostEqual([ self.assertSequenceAlmostEqual([
0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.5, 1.6, 1.7, 1.8, 1.9 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.5, 1.6, 1.7, 1.8, 1.9
], split_time[0]) ], split_time[0])
self.assertSequenceAlmostEqual( self.assertSequenceAlmostEqual(np.ones(15), split_flux[0])
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], split_flux[0]) self.assertSequenceAlmostEqual(np.arange(3, 4, 0.1), split_time[1])
self.assertSequenceAlmostEqual( self.assertSequenceAlmostEqual(np.ones(10), split_flux[1])
[3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9], split_time[1]) self.assertSequenceAlmostEqual(np.arange(4, 5, 0.1), split_time[2])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], self.assertSequenceAlmostEqual(np.ones(10), split_flux[2])
split_flux[1])
def testRemoveEvents(self): def testRemoveEvents(self):
time = np.arange(20, dtype=np.float) time = np.arange(20, dtype=np.float)
...@@ -136,20 +156,23 @@ class LightCurveUtilTest(absltest.TestCase): ...@@ -136,20 +156,23 @@ class LightCurveUtilTest(absltest.TestCase):
all_time = [ all_time = [
np.arange(0, 10, dtype=np.float), np.arange(0, 10, dtype=np.float),
np.arange(10, 20, dtype=np.float), np.arange(10, 20, dtype=np.float),
np.arange(20, 30, dtype=np.float),
] ]
all_masked_time = [ all_masked_time = [
np.array([0, 1, 2, 3, 8, 9], dtype=np.float), # No 4, 5, 6, 7 np.array([0, 1, 2, 3, 8, 9], dtype=np.float), # No 4, 5, 6, 7
np.array([10, 11, 12, 13, 14, 15, 16], dtype=np.float), # No 17, 18, 19 np.array([10, 11, 12, 13, 14, 15, 16], dtype=np.float), # No 17, 18, 19
np.array([], dtype=np.float)
] ]
all_masked_spline = [2 * t + 100 for t in all_masked_time] all_masked_spline = [2 * t + 100 for t in all_masked_time]
interp_spline = util.interpolate_masked_spline(all_time, all_masked_time, interp_spline = util.interpolate_masked_spline(all_time, all_masked_time,
all_masked_spline) all_masked_spline)
self.assertLen(interp_spline, 2) self.assertLen(interp_spline, 3)
self.assertSequenceAlmostEqual( self.assertSequenceAlmostEqual(
[100, 102, 104, 106, 108, 110, 112, 114, 116, 118], interp_spline[0]) [100, 102, 104, 106, 108, 110, 112, 114, 116, 118], interp_spline[0])
self.assertSequenceAlmostEqual( self.assertSequenceAlmostEqual(
[120, 122, 124, 126, 128, 130, 132, 132, 132, 132], interp_spline[1]) [120, 122, 124, 126, 128, 130, 132, 132, 132, 132], interp_spline[1])
self.assertTrue(np.all(np.isnan(interp_spline[2])))
def testCountTransitPoints(self): def testCountTransitPoints(self):
time = np.concatenate([ time = np.concatenate([
......
...@@ -12,8 +12,13 @@ from pydl.pydlutils import bspline ...@@ -12,8 +12,13 @@ from pydl.pydlutils import bspline
from third_party.robust_mean import robust_mean from third_party.robust_mean import robust_mean
class InsufficientPointsError(Exception):
"""Indicates that insufficient points were available for spline fitting."""
pass
class SplineError(Exception): class SplineError(Exception):
"""Error when fitting a Kepler spline.""" """Indicates an error in the underlying spline-fitting implementation."""
pass pass
...@@ -40,7 +45,17 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3): ...@@ -40,7 +45,17 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
spline: The values of the fitted spline corresponding to the input time spline: The values of the fitted spline corresponding to the input time
values. values.
mask: Boolean mask indicating the points used to fit the final spline. mask: Boolean mask indicating the points used to fit the final spline.
Raises:
InsufficientPointsError: If there were insufficient points (after removing
outliers) for spline fitting.
SplineError: If the spline could not be fit, for example if the breakpoint
spacing is too small.
""" """
if len(time) < 4:
raise InsufficientPointsError(
"Cannot fit a spline on less than 4 points. Got %d points." % len(time))
# Rescale time into [0, 1]. # Rescale time into [0, 1].
t_min = np.min(time) t_min = np.min(time)
t_max = np.max(time) t_max = np.max(time)
...@@ -58,16 +73,26 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3): ...@@ -58,16 +73,26 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
mask = np.ones_like(time, dtype=np.bool) # Try to fit all points. mask = np.ones_like(time, dtype=np.bool) # Try to fit all points.
else: else:
# Choose points where the absolute deviation from the median residual is # Choose points where the absolute deviation from the median residual is
# less than 3*sigma, where sigma is a robust estimate of the standard # less than outlier_cut*sigma, where sigma is a robust estimate of the
# deviation of the residuals from the previous spline. # standard deviation of the residuals from the previous spline.
residuals = flux - spline residuals = flux - spline
_, _, new_mask = robust_mean.robust_mean(residuals, cut=outlier_cut) new_mask = robust_mean.robust_mean(residuals, cut=outlier_cut)[2]
if np.all(new_mask == mask): if np.all(new_mask == mask):
break # Spline converged. break # Spline converged.
mask = new_mask mask = new_mask
if np.sum(mask) < 4:
# Fewer than 4 points after removing outliers. We could plausibly return
# the spline from the previous iteration because it was fit with at least
# 4 points. However, since the outliers were such a significant fraction
# of the curve, the spline from the previous iteration is probably junk,
# and we consider this a fatal error.
raise InsufficientPointsError(
"Cannot fit a spline on less than 4 points. After removing "
"outliers, got %d points." % np.sum(mask))
try: try:
with warnings.catch_warnings(): with warnings.catch_warnings():
# Suppress warning messages printed by pydlutils.bspline. Instead we # Suppress warning messages printed by pydlutils.bspline. Instead we
...@@ -88,6 +113,31 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3): ...@@ -88,6 +113,31 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
return spline, mask return spline, mask
class SplineMetadata(object):
"""Metadata about a spline fit.
Attributes:
light_curve_mask: List of boolean numpy arrays indicating which points in
the light curve were used to fit the best-fit spline.
bkspace: The break-point spacing used for the best-fit spline.
bad_bkspaces: List of break-point spacing values that failed.
likelihood_term: The likelihood term of the Bayesian Information Criterion;
-2*ln(L), where L is the likelihood of the data given the model.
penalty_term: The penalty term for the number of parameters in the
Bayesian Information Criterion.
bic: The value of the Bayesian Information Criterion; equal to
likelihood_term + penalty_coeff * penalty_term.
"""
def __init__(self):
self.light_curve_mask = None
self.bkspace = None
self.bad_bkspaces = []
self.likelihood_term = None
self.penalty_term = None
self.bic = None
def choose_kepler_spline(all_time, def choose_kepler_spline(all_time,
all_flux, all_flux,
bkspaces, bkspaces,
...@@ -125,52 +175,65 @@ def choose_kepler_spline(all_time, ...@@ -125,52 +175,65 @@ def choose_kepler_spline(all_time,
Returns: Returns:
spline: List of numpy arrays; values of the best-fit spline corresponding to spline: List of numpy arrays; values of the best-fit spline corresponding to
the input flux arrays. the input flux arrays.
spline_mask: List of boolean numpy arrays indicating which points in the metadata: Object containing metadata about the spline fit.
flux arrays were used to fit the best-fit spline.
bkspace: The break-point spacing used for the best-fit spline.
bad_bkspaces: List of break-point spacing values that failed.
""" """
# Initialize outputs.
best_spline = None
metadata = SplineMetadata()
# Compute the assumed standard deviation of Gaussian white noise about the # Compute the assumed standard deviation of Gaussian white noise about the
# spline model. # spline model. We assume that each flux value f[i] is a Gaussian random
abs_deviations = np.concatenate([np.abs(f[1:] - f[:-1]) for f in all_flux]) # variable f[i] ~ N(s[i], sigma^2), where s is the value of the true spline
sigma = np.median(abs_deviations) * 1.48 / np.sqrt(2) # model and sigma is the constant standard deviation for all flux values.
# Moreover, we assume that s[i] ~= s[i+1]. Therefore,
# (f[i+1] - f[i]) / sqrt(2) ~ N(0, sigma^2).
scaled_diffs = np.concatenate([np.diff(f) / np.sqrt(2) for f in all_flux])
if not scaled_diffs.size:
best_spline = [np.array([np.nan] * len(f)) for f in all_flux]
metadata.light_curve_mask = [
np.zeros_like(f, dtype=np.bool) for f in all_flux
]
return best_spline, metadata
# Compute the median absolute deviation as a robust estimate of sigma. The
# conversion factor of 1.48 takes the median absolute deviation to the
# standard deviation of a normal distribution. See, e.g.
# https://www.mathworks.com/help/stats/mad.html.
sigma = np.median(np.abs(scaled_diffs)) * 1.48
best_bic = None
best_spline = None
best_spline_mask = None
best_bkspace = None
bad_bkspaces = []
for bkspace in bkspaces: for bkspace in bkspaces:
nparams = 0 # Total number of free parameters in the piecewise spline. nparams = 0 # Total number of free parameters in the piecewise spline.
npoints = 0 # Total number of data points used to fit the piecewise spline. npoints = 0 # Total number of data points used to fit the piecewise spline.
ssr = 0 # Sum of squared residuals between the model and the spline. ssr = 0 # Sum of squared residuals between the model and the spline.
spline = [] spline = []
spline_mask = [] light_curve_mask = []
bad_bkspace = False # Indicates that the current bkspace should be skipped. bad_bkspace = False # Indicates that the current bkspace should be skipped.
for time, flux in zip(all_time, all_flux): for time, flux in zip(all_time, all_flux):
# Don't fit a spline on less than 4 points.
if len(time) < 4:
spline.append(flux)
spline_mask.append(np.ones_like(flux), dtype=np.bool)
continue
# Fit B-spline to this light-curve segment. # Fit B-spline to this light-curve segment.
try: try:
spline_piece, mask = kepler_spline( spline_piece, mask = kepler_spline(
time, flux, bkspace=bkspace, maxiter=maxiter) time, flux, bkspace=bkspace, maxiter=maxiter)
except InsufficientPointsError as e:
# It's expected to get a SplineError occasionally for small values of # It's expected to occasionally see intervals with insufficient points,
# bkspace. # especially if periodic signals have been removed from the light curve.
# Skip this interval, but continue fitting the spline.
if verbose:
warnings.warn(str(e))
spline.append(np.array([np.nan] * len(flux)))
light_curve_mask.append(np.zeros_like(flux, dtype=np.bool))
continue
except SplineError as e: except SplineError as e:
# It's expected to get a SplineError occasionally for small values of
# bkspace. Skip this bkspace.
if verbose: if verbose:
warnings.warn("Bad bkspace %.4f: %s" % (bkspace, e)) warnings.warn("Bad bkspace %.4f: %s" % (bkspace, e))
bad_bkspaces.append(bkspace) metadata.bad_bkspaces.append(bkspace)
bad_bkspace = True bad_bkspace = True
break break
spline.append(spline_piece) spline.append(spline_piece)
spline_mask.append(mask) light_curve_mask.append(mask)
# Accumulate the number of free parameters. # Accumulate the number of free parameters.
total_time = np.max(time) - np.min(time) total_time = np.max(time) - np.min(time)
...@@ -181,7 +244,7 @@ def choose_kepler_spline(all_time, ...@@ -181,7 +244,7 @@ def choose_kepler_spline(all_time,
npoints += np.sum(mask) npoints += np.sum(mask)
ssr += np.sum((flux[mask] - spline_piece[mask])**2) ssr += np.sum((flux[mask] - spline_piece[mask])**2)
if bad_bkspace: if bad_bkspace or not npoints:
continue continue
# The following term is -2*ln(L), where L is the likelihood of the data # The following term is -2*ln(L), where L is the likelihood of the data
...@@ -189,13 +252,26 @@ def choose_kepler_spline(all_time, ...@@ -189,13 +252,26 @@ def choose_kepler_spline(all_time,
# Gaussian with mean 0 and standard deviation sigma. # Gaussian with mean 0 and standard deviation sigma.
likelihood_term = npoints * np.log(2 * np.pi * sigma**2) + ssr / sigma**2 likelihood_term = npoints * np.log(2 * np.pi * sigma**2) + ssr / sigma**2
# Penalty term for the number of parameters used to fit the model.
penalty_term = nparams * np.log(npoints)
# Bayesian information criterion. # Bayesian information criterion.
bic = likelihood_term + penalty_coeff * nparams * np.log(npoints) bic = likelihood_term + penalty_coeff * penalty_term
if best_bic is None or bic < best_bic: if best_spline is None or bic < metadata.bic:
best_bic = bic
best_spline = spline best_spline = spline
best_spline_mask = spline_mask metadata.light_curve_mask = light_curve_mask
best_bkspace = bkspace metadata.bkspace = bkspace
metadata.likelihood_term = likelihood_term
return best_spline, best_spline_mask, best_bkspace, bad_bkspaces metadata.penalty_term = penalty_term
metadata.bic = bic
if best_spline is None:
# All bkspaces resulted in a SplineError, or all light curve intervals had
# insufficient points.
best_spline = [np.array([np.nan] * len(f)) for f in all_flux]
metadata.light_curve_mask = [
np.zeros_like(f, dtype=np.bool) for f in all_flux
]
return best_spline, metadata
...@@ -12,7 +12,7 @@ from third_party.kepler_spline import kepler_spline ...@@ -12,7 +12,7 @@ from third_party.kepler_spline import kepler_spline
class KeplerSplineTest(absltest.TestCase): class KeplerSplineTest(absltest.TestCase):
def testKeplerSplineSine(self): def testFitSine(self):
# Fit a sine wave. # Fit a sine wave.
time = np.arange(0, 10, 0.1) time = np.arange(0, 10, 0.1)
flux = np.sin(time) flux = np.sin(time)
...@@ -46,7 +46,7 @@ class KeplerSplineTest(absltest.TestCase): ...@@ -46,7 +46,7 @@ class KeplerSplineTest(absltest.TestCase):
self.assertFalse(mask[77]) self.assertFalse(mask[77])
self.assertFalse(mask[95]) self.assertFalse(mask[95])
def testKeplerSplineCubic(self): def testFitCubic(self):
# Fit a cubic polynomial. # Fit a cubic polynomial.
time = np.arange(0, 10, 0.1) time = np.arange(0, 10, 0.1)
flux = (time - 5)**3 + 2 * (time - 5)**2 + 10 flux = (time - 5)**3 + 2 * (time - 5)**2 + 10
...@@ -61,18 +61,84 @@ class KeplerSplineTest(absltest.TestCase): ...@@ -61,18 +61,84 @@ class KeplerSplineTest(absltest.TestCase):
self.assertLess(rmse, 1e-12) self.assertLess(rmse, 1e-12)
self.assertTrue(np.all(mask)) self.assertTrue(np.all(mask))
def testKeplerSplineError(self): def testInsufficientPointsError(self):
# Big gap. # Empty light curve.
time = np.concatenate([np.arange(0, 1, 0.1), [2]]) time = np.array([])
flux = np.array([])
with self.assertRaises(kepler_spline.InsufficientPointsError):
kepler_spline.kepler_spline(time, flux, bkspace=0.5)
# Only 3 points.
time = np.array([0.1, 0.2, 0.3])
flux = np.sin(time) flux = np.sin(time)
with self.assertRaises(kepler_spline.SplineError): with self.assertRaises(kepler_spline.InsufficientPointsError):
kepler_spline.kepler_spline(time, flux, bkspace=0.5) kepler_spline.kepler_spline(time, flux, bkspace=0.5)
def testChooseKeplerSpline(self):
class ChooseKeplerSplineTest(absltest.TestCase):
def testNoPoints(self):
all_time = [np.array([])]
all_flux = [np.array([])]
# Logarithmically sample candidate break point spacings.
bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20)
spline, metadata = kepler_spline.choose_kepler_spline(
all_time, all_flux, bkspaces, penalty_coeff=1.0, verbose=False)
np.testing.assert_array_equal(spline, [[]])
np.testing.assert_array_equal(metadata.light_curve_mask, [[]])
def testTooFewPoints(self):
# Sine wave with segments of 1, 2, 3 points.
all_time = [
np.array([0.1]),
np.array([0.2, 0.3]),
np.array([0.4, 0.5, 0.6])
]
all_flux = [np.sin(t) for t in all_time]
# Logarithmically sample candidate break point spacings.
bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20)
spline, metadata = kepler_spline.choose_kepler_spline(
all_time, all_flux, bkspaces, penalty_coeff=1.0, verbose=False)
# All segments are NaN.
self.assertTrue(np.all(np.isnan(np.concatenate(spline))))
self.assertFalse(np.any(np.concatenate(metadata.light_curve_mask)))
self.assertIsNone(metadata.bkspace)
self.assertEmpty(metadata.bad_bkspaces)
self.assertIsNone(metadata.likelihood_term)
self.assertIsNone(metadata.penalty_term)
self.assertIsNone(metadata.bic)
# Add a longer segment.
all_time.append(np.arange(0.7, 2.0, 0.1))
all_flux.append(np.sin(all_time[-1]))
spline, metadata = kepler_spline.choose_kepler_spline(
all_time, all_flux, bkspaces, penalty_coeff=1.0, verbose=False)
# First 3 segments are NaN.
for i in range(3):
self.assertTrue(np.all(np.isnan(spline[i])))
self.assertFalse(np.any(metadata.light_curve_mask[i]))
# Final segment is a good fit.
self.assertTrue(np.all(np.isfinite(spline[3])))
self.assertTrue(np.all(metadata.light_curve_mask[3]))
self.assertEmpty(metadata.bad_bkspaces)
self.assertAlmostEqual(metadata.likelihood_term, -58.0794069927957)
self.assertAlmostEqual(metadata.penalty_term, 7.69484807238461)
self.assertAlmostEqual(metadata.bic, -50.3845589204111)
def testFitSine(self):
# High frequency sine wave. # High frequency sine wave.
time = [np.arange(0, 100, 0.1), np.arange(100, 200, 0.1)] all_time = [np.arange(0, 100, 0.1), np.arange(100, 200, 0.1)]
flux = [np.sin(t) for t in time] all_flux = [np.sin(t) for t in all_time]
# Logarithmically sample candidate break point spacings. # Logarithmically sample candidate break point spacings.
bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20) bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20)
...@@ -83,29 +149,38 @@ class KeplerSplineTest(absltest.TestCase): ...@@ -83,29 +149,38 @@ class KeplerSplineTest(absltest.TestCase):
return np.sqrt(np.mean((f - s)**2)) return np.sqrt(np.mean((f - s)**2))
# Penalty coefficient 1.0. # Penalty coefficient 1.0.
spline, mask, bkspace, bad_bkspaces = kepler_spline.choose_kepler_spline( spline, metadata = kepler_spline.choose_kepler_spline(
time, flux, bkspaces, penalty_coeff=1.0) all_time, all_flux, bkspaces, penalty_coeff=1.0)
self.assertAlmostEqual(_rmse(flux, spline), 0.013013) self.assertAlmostEqual(_rmse(all_flux, spline), 0.013013)
self.assertTrue(np.all(mask)) self.assertTrue(np.all(metadata.light_curve_mask))
self.assertAlmostEqual(bkspace, 1.67990914314) self.assertAlmostEqual(metadata.bkspace, 1.67990914314)
self.assertEmpty(bad_bkspaces) self.assertEmpty(metadata.bad_bkspaces)
self.assertAlmostEqual(metadata.likelihood_term, -6685.64217856480)
self.assertAlmostEqual(metadata.penalty_term, 942.51190498322)
self.assertAlmostEqual(metadata.bic, -5743.13027358158)
# Decrease penalty coefficient; allow smaller spacing for closer fit. # Decrease penalty coefficient; allow smaller spacing for closer fit.
spline, mask, bkspace, bad_bkspaces = kepler_spline.choose_kepler_spline( spline, metadata = kepler_spline.choose_kepler_spline(
time, flux, bkspaces, penalty_coeff=0.1) all_time, all_flux, bkspaces, penalty_coeff=0.1)
self.assertAlmostEqual(_rmse(flux, spline), 0.0066376) self.assertAlmostEqual(_rmse(all_flux, spline), 0.0066376)
self.assertTrue(np.all(mask)) self.assertTrue(np.all(metadata.light_curve_mask))
self.assertAlmostEqual(bkspace, 1.48817572082) self.assertAlmostEqual(metadata.bkspace, 1.48817572082)
self.assertEmpty(bad_bkspaces) self.assertEmpty(metadata.bad_bkspaces)
self.assertAlmostEqual(metadata.likelihood_term, -6731.59913975551)
self.assertAlmostEqual(metadata.penalty_term, 1064.12634433589)
self.assertAlmostEqual(metadata.bic, -6625.18650532192)
# Increase penalty coefficient; require larger spacing at the cost of worse # Increase penalty coefficient; require larger spacing at the cost of worse
# fit. # fit.
spline, mask, bkspace, bad_bkspaces = kepler_spline.choose_kepler_spline( spline, metadata = kepler_spline.choose_kepler_spline(
time, flux, bkspaces, penalty_coeff=2) all_time, all_flux, bkspaces, penalty_coeff=2)
self.assertAlmostEqual(_rmse(flux, spline), 0.026215449) self.assertAlmostEqual(_rmse(all_flux, spline), 0.026215449)
self.assertTrue(np.all(mask)) self.assertTrue(np.all(metadata.light_curve_mask))
self.assertAlmostEqual(bkspace, 1.89634509537) self.assertAlmostEqual(metadata.bkspace, 1.89634509537)
self.assertEmpty(bad_bkspaces) self.assertEmpty(metadata.bad_bkspaces)
self.assertAlmostEqual(metadata.likelihood_term, -6495.65564287904)
self.assertAlmostEqual(metadata.penalty_term, 836.099270549629)
self.assertAlmostEqual(metadata.bic, -4823.45710177978)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment