Commit 8b200ebb authored by Christopher Shallue's avatar Christopher Shallue
Browse files

Initial AstroNet commit.

parent 0b74e527
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_VIEW_GENERATOR_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_VIEW_GENERATOR_H_
#include <memory>
#include <string>
#include <vector>
namespace astronet {
// Helper class for phase-folding a light curve and then generating "views" of
// the light curve using a median filter.
//
// This class wraps the functions in light_curve_util.h for intended use as a
// a Python extension. It keeps the phase-folded light curve in the class state
// to minimize expensive copies between the language barrier.
class ViewGenerator {
public:
// Factory function to create a new ViewGenerator.
//
// Input args:
// time: Vector of time values, not phase-folded.
// flux: Vector of flux values with the same size as time.
// period: The period to fold over.
// t0: The center of the resulting folded vector; this value is mapped to 0.
//
// Output args:
// error: String indicating an error (e.g. time and flux are different
// sizes).
//
// Returns:
// A ViewGenerator. May be a nullptr in the case of an error; see the
// "error" string if so.
static std::unique_ptr<ViewGenerator> Create(const std::vector<double>& time,
const std::vector<double>& flux,
double period, double t0,
std::string* error);
// Generates a "view" of the phase-folded light curve using a median filter.
//
// Note that the time values of the phase-folded light curve are in the range
// [-period / 2, period / 2).
//
// This function applies astronet::MedianFilter() to the phase-folded and
// sorted light curve, followed optionally by
// astronet::NormalizeMedianAndMinimum(). See the comments on those
// functions for more details.
//
// Input args:
// num_bins: The number of intervals to divide the time axis into. Must be
// at least 2.
// bin_width: The width of each bin on the time axis. Must be positive, and
// less than t_max - t_min.
// t_min: The inclusive leftmost value to consider on the time axis. This
// should probably be at least -period / 2, which is the minimum
// possible value of the phase-folded light curve. Must be less than the
// largest value of the phase-folded time axis.
// t_max: The exclusive rightmost value to consider on the time axis. This
// should probably be at most period / 2, which is the maximum possible
// value of the phase-folded light curve. Must be greater than t_min.
// normalize: Whether to normalize the output vector to have median 0 and
// minimum -1.
//
// Output args:
// result: Vector of size num_bins containing the median flux values of
// uniformly spaced bins on the phase-folded time axis.
// error: String indicating an error (e.g. an invalid argument).
//
// Returns:
// true if the algorithm succeeded. If false, see "error".
bool GenerateView(int num_bins, double bin_width, double t_min, double t_max,
bool normalize, std::vector<double>* result,
std::string* error);
protected:
// This class can only be constructed by Create().
ViewGenerator(std::vector<double> time, std::vector<double> flux);
// phase-folded light curve, sorted by time in ascending order.
std::vector<double> time_;
std::vector<double> flux_;
};
} // namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_VIEW_GENERATOR_H_
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/view_generator.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "light_curve_util/cc/test_util.h"
using std::vector;
using testing::Pointwise;
namespace astronet {
namespace {
TEST(ViewGenerator, CreationError) {
vector<double> time = {1, 2, 3};
vector<double> flux = {2, 3};
std::string error;
std::unique_ptr<ViewGenerator> generator =
ViewGenerator::Create(time, flux, 1, 0.5, &error);
EXPECT_EQ(nullptr, generator);
EXPECT_FALSE(error.empty());
}
TEST(ViewGenerator, GenerateViews) {
vector<double> time = range(0, 2, 0.1);
vector<double> flux = range(0, 20, 1);
std::string error;
// Create the ViewGenerator.
std::unique_ptr<ViewGenerator> generator =
ViewGenerator::Create(time, flux, 2.0, 0.15, &error);
EXPECT_NE(nullptr, generator);
EXPECT_TRUE(error.empty());
vector<double> result;
// Error: t_max <= t_min. We do not test all failure cases here since they
// are tested in light_curve_util_test.cc.
EXPECT_FALSE(generator->GenerateView(10, 1, -1, -1, false, &result, &error));
EXPECT_FALSE(error.empty());
error.clear();
// Global view, unnormalized.
EXPECT_TRUE(generator->GenerateView(10, 0.2, -1, 1, false, &result, &error));
EXPECT_TRUE(error.empty());
vector<double> expected = {12.5, 14.5, 16.5, 18.5, 0.5,
2.5, 4.5, 6.5, 8.5, 10.5};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
// Global view, normalized.
EXPECT_TRUE(generator->GenerateView(10, 0.2, -1, 1, true, &result, &error));
EXPECT_TRUE(error.empty());
expected = {3.0 / 9, 5.0 / 9, 7.0 / 9, 9.0 / 9, -9.0 / 9,
-7.0 / 9, -5.0 / 9, -3.0 / 9, -1.0 / 9, 1.0 / 9};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
// Local view, unnormalized.
EXPECT_TRUE(
generator->GenerateView(5, 0.2, -0.5, 0.5, false, &result, &error));
EXPECT_TRUE(error.empty());
expected = {17.5, 9.5, 1.5, 3.5, 5.5};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
// Local view, normalized.
EXPECT_TRUE(
generator->GenerateView(5, 0.2, -0.5, 0.5, true, &result, &error));
EXPECT_TRUE(error.empty());
expected = {3, 1, -1, -0.5, 0};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
} // namespace
} // namespace astronet
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for reading Kepler data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
from astropy.io import fits
import numpy as np
LONG_CADENCE_TIME_DELTA_DAYS = 0.02043422 # Approximately 29.4 minutes.
# Quarter index to filename prefix for long cadence Kepler data.
# Reference: https://archive.stsci.edu/kepler/software/get_kepler.py
LONG_CADENCE_QUARTER_PREFIXES = {
0: ["2009131105131"],
1: ["2009166043257"],
2: ["2009259160929"],
3: ["2009350155506"],
4: ["2010078095331", "2010009091648"],
5: ["2010174085026"],
6: ["2010265121752"],
7: ["2010355172524"],
8: ["2011073133259"],
9: ["2011177032512"],
10: ["2011271113734"],
11: ["2012004120508"],
12: ["2012088054726"],
13: ["2012179063303"],
14: ["2012277125453"],
15: ["2013011073258"],
16: ["2013098041711"],
17: ["2013131215648"]
}
# Quarter index to filename prefix for short cadence Kepler data.
# Reference: https://archive.stsci.edu/kepler/software/get_kepler.py
SHORT_CADENCE_QUARTER_PREFIXES = {
0: ["2009131110544"],
1: ["2009166044711"],
2: ["2009201121230", "2009231120729", "2009259162342"],
3: ["2009291181958", "2009322144938", "2009350160919"],
4: ["2010009094841", "2010019161129", "2010049094358", "2010078100744"],
5: ["2010111051353", "2010140023957", "2010174090439"],
6: ["2010203174610", "2010234115140", "2010265121752"],
7: ["2010296114515", "2010326094124", "2010355172524"],
8: ["2011024051157", "2011053090032", "2011073133259"],
9: ["2011116030358", "2011145075126", "2011177032512"],
10: ["2011208035123", "2011240104155", "2011271113734"],
11: ["2011303113607", "2011334093404", "2012004120508"],
12: ["2012032013838", "2012060035710", "2012088054726"],
13: ["2012121044856", "2012151031540", "2012179063303"],
14: ["2012211050319", "2012242122129", "2012277125453"],
15: ["2012310112549", "2012341132017", "2013011073258"],
16: ["2013017113907", "2013065031647", "2013098041711"],
17: ["2013121191144", "2013131215648"]
}
def kepler_filenames(base_dir,
kep_id,
long_cadence=True,
quarters=None,
injected_group=None,
check_existence=True):
"""Returns the light curve filenames for a Kepler target star.
This function assumes the directory structure of the Mikulski Archive for
Space Telescopes (http://archive.stsci.edu/pub/kepler/lightcurves).
Specifically, the filenames for a particular Kepler target star have the
following format:
${kep_id:0:4}/${kep_id}/kplr${kep_id}-${quarter_prefix}_${type}.fits,
where:
kep_id is the Kepler id left-padded with zeros to length 9;
quarter_prefix is the filename quarter prefix;
type is one of "llc" (long cadence light curve) or "slc" (short cadence
light curve).
Args:
base_dir: Base directory containing Kepler data.
kep_id: Id of the Kepler target star. May be an int or a possibly zero-
padded string.
long_cadence: Whether to read a long cadence (~29.4 min / measurement) light
curve as opposed to a short cadence (~1 min / measurement) light curve.
quarters: Optional list of integers in [0, 17]; the quarters of the Kepler
mission to return.
injected_group: Optional string indicating injected light curves. One of
"inj1", "inj2", "inj3".
check_existence: If True, only return filenames corresponding to files that
exist (not all stars have data for all quarters).
Returns:
A list of filenames.
"""
# Pad the Kepler id with zeros to length 9.
kep_id = "%.9d" % int(kep_id)
quarter_prefixes, cadence_suffix = ((LONG_CADENCE_QUARTER_PREFIXES, "llc")
if long_cadence else
(SHORT_CADENCE_QUARTER_PREFIXES, "slc"))
if quarters is None:
quarters = quarter_prefixes.keys()
quarters = sorted(quarters) # Sort quarters chronologically.
filenames = []
base_dir = os.path.join(base_dir, kep_id[0:4], kep_id)
for quarter in quarters:
for quarter_prefix in quarter_prefixes[quarter]:
if injected_group:
base_name = "kplr%s-%s_INJECTED-%s_%s.fits" % (kep_id, quarter_prefix,
injected_group,
cadence_suffix)
else:
base_name = "kplr%s-%s_%s.fits" % (kep_id, quarter_prefix,
cadence_suffix)
filename = os.path.join(base_dir, base_name)
# Not all stars have data for all quarters.
if not check_existence or os.path.isfile(filename):
filenames.append(filename)
return filenames
def read_kepler_light_curve(filenames,
light_curve_extension="LIGHTCURVE",
invert=False):
"""Reads time and flux measurements for a Kepler target star.
Args:
filenames: A list of .fits files containing time and flux measurements.
light_curve_extension: Name of the HDU 1 extension containing light curves.
invert: Whether to invert the flux measurements by multiplying by -1.
Returns:
all_time: A list of numpy arrays; the time values of the light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in
all_time.
"""
all_time = []
all_flux = []
for filename in filenames:
with fits.open(open(filename, "r")) as hdu_list:
light_curve = hdu_list[light_curve_extension].data
time = light_curve.TIME
flux = light_curve.PDCSAP_FLUX
# Remove NaN flux values.
valid_indices = np.where(np.isfinite(flux))
time = time[valid_indices]
flux = flux[valid_indices]
if invert:
flux *= -1
if time.size:
all_time.append(time)
all_flux.append(flux)
return all_time, all_flux
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kepler_io.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
from absl import flags
from absl.testing import absltest
from light_curve_util import kepler_io
FLAGS = flags.FLAGS
_DATA_DIR = "light_curve_util/test_data/"
class KeplerIoTest(absltest.TestCase):
def setUp(self):
self.data_dir = os.path.join(FLAGS.test_srcdir, _DATA_DIR)
def testKeplerFilenames(self):
# All quarters.
filenames = kepler_io.kepler_filenames(
"/my/dir/", 1234567, check_existence=False)
self.assertItemsEqual([
"/my/dir/0012/001234567/kplr001234567-2009131105131_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2009166043257_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2009259160929_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2009350155506_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010078095331_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010009091648_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010174085026_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010265121752_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010355172524_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2011073133259_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2011177032512_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2011271113734_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2012004120508_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2012088054726_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2012179063303_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2012277125453_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2013011073258_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2013098041711_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2013131215648_llc.fits"
], filenames)
# Subset of quarters.
filenames = kepler_io.kepler_filenames(
"/my/dir/", 1234567, quarters=[3, 4], check_existence=False)
self.assertItemsEqual([
"/my/dir/0012/001234567/kplr001234567-2009350155506_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010078095331_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010009091648_llc.fits"
], filenames)
# Injected group.
filenames = kepler_io.kepler_filenames(
"/my/dir/",
1234567,
quarters=[3, 4],
injected_group="inj1",
check_existence=False)
# pylint:disable=line-too-long
self.assertItemsEqual([
"/my/dir/0012/001234567/kplr001234567-2009350155506_INJECTED-inj1_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010078095331_INJECTED-inj1_llc.fits",
"/my/dir/0012/001234567/kplr001234567-2010009091648_INJECTED-inj1_llc.fits"
], filenames)
# pylint:enable=line-too-long
# Short cadence.
filenames = kepler_io.kepler_filenames(
"/my/dir/",
1234567,
long_cadence=False,
quarters=[0, 1],
check_existence=False)
self.assertItemsEqual([
"/my/dir/0012/001234567/kplr001234567-2009131110544_slc.fits",
"/my/dir/0012/001234567/kplr001234567-2009166044711_slc.fits"
], filenames)
# Check existence.
filenames = kepler_io.kepler_filenames(
self.data_dir, 11442793, check_existence=True)
expected_filenames = [
os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
% q for q in ["2009350155506", "2010009091648", "2010174085026"]
]
self.assertItemsEqual(expected_filenames, filenames)
def testReadKeplerLightCurve(self):
filenames = [
os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
% q for q in ["2009350155506", "2010009091648", "2010174085026"]
]
all_time, all_flux = kepler_io.read_kepler_light_curve(filenames)
self.assertLen(all_time, 3)
self.assertLen(all_flux, 3)
self.assertLen(all_time[0], 4134)
self.assertLen(all_flux[0], 4134)
self.assertLen(all_time[1], 1008)
self.assertLen(all_flux[1], 1008)
self.assertLen(all_time[2], 4486)
self.assertLen(all_flux[2], 4486)
if __name__ == "__main__":
FLAGS.test_srcdir = ""
absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility function for smoothing data using a median filter."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None):
"""Computes the median y-value in uniform intervals (bins) along the x-axis.
The interval [x_min, x_max) is divided into num_bins uniformly spaced
intervals of width bin_width. The value computed for each bin is the median
of all y-values whose corresponding x-value is in the interval.
NOTE: x must be sorted in ascending order or the results will be incorrect.
Args:
x: 1D array of x-coordinates sorted in ascending order. Must have at least 2
elements, and all elements cannot be the same value.
y: 1D array of y-coordinates with the same size as x.
num_bins: The number of intervals to divide the x-axis into. Must be at
least 2.
bin_width: The width of each bin on the x-axis. Must be positive, and less
than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
x_min: The inclusive leftmost value to consider on the x-axis. Must be less
than or equal to the largest value of x. Defaults to min(x).
x_max: The exclusive rightmost value to consider on the x-axis. Must be
greater than x_min. Defaults to max(x).
Returns:
1D NumPy array of size num_bins containing the median y-values of uniformly
spaced bins on the x-axis.
Raises:
ValueError: If an argument has an inappropriate value.
"""
if num_bins < 2:
raise ValueError("num_bins must be at least 2. Got: %d" % num_bins)
# Validate the lengths of x and y.
x_len = len(x)
if x_len < 2:
raise ValueError("len(x) must be at least 2. Got: %s" % x_len)
if x_len != len(y):
raise ValueError("len(x) (got: %d) must equal len(y) (got: %d)" % (x_len,
len(y)))
# Validate x_min and x_max.
x_min = x_min if x_min is not None else x[0]
x_max = x_max if x_max is not None else x[-1]
if x_min >= x_max:
raise ValueError("x_min (got: %d) must be less than x_max (got: %d)" %
(x_min, x_max))
if x_min > x[-1]:
raise ValueError(
"x_min (got: %d) must be less than or equal to the largest value of x "
"(got: %d)" % (x_min, x[-1]))
# Validate bin_width.
bin_width = bin_width if bin_width is not None else (x_max - x_min) / num_bins
if bin_width <= 0:
raise ValueError("bin_width must be positive. Got: %d" % bin_width)
if bin_width >= x_max - x_min:
raise ValueError(
"bin_width (got: %d) must be less than x_max - x_min (got: %d)" %
(bin_width, x_max - x_min))
bin_spacing = (x_max - x_min - bin_width) / (num_bins - 1)
# Bins with no y-values will fall back to the global median.
result = np.repeat(np.median(y), num_bins)
# Find the first element of x >= x_min. This loop is guaranteed to produce
# a valid index because we know that x_min <= x[-1].
x_start = 0
while x[x_start] < x_min:
x_start += 1
# The bin at index i is the median of all elements y[j] such that
# bin_min <= x[j] < bin_max, where bin_min and bin_max are the endpoints of
# bin i.
bin_min = x_min # Left endpoint of the current bin.
bin_max = x_min + bin_width # Right endpoint of the current bin.
j_start = x_start # Inclusive left index of the current bin.
j_end = x_start # Exclusive end index of the current bin.
for i in range(num_bins):
# Move j_start to the first index of x >= bin_min.
while j_start < x_len and x[j_start] < bin_min:
j_start += 1
# Move j_end to the first index of x >= bin_max (exclusive end index).
while j_end < x_len and x[j_end] < bin_max:
j_end += 1
if j_end > j_start:
# Compute and insert the median bin value.
result[i] = np.median(y[j_start:j_end])
# Advance the bin.
bin_min += bin_spacing
bin_max += bin_spacing
return result
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for median_filter.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve_util import median_filter
class MedianFilterTest(absltest.TestCase):
def testErrors(self):
# x size less than 2.
x = [1]
y = [2]
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=1, x_min=0, x_max=2)
# x and y not the same size.
x = [1, 2]
y = [4, 5, 6]
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=1, x_min=0, x_max=2)
# x_min not less than x_max.
x = [1, 2, 3]
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=1, x_min=-1, x_max=-1)
# x_min greater than the last element of x.
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=0.25, x_min=3.5, x_max=4)
# bin_width nonpositive.
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=0, x_min=1, x_max=3)
# bin_width greater than or equal to x_max - x_min.
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=1, x_min=1.5, x_max=2.5)
# num_bins less than 2.
x = [1, 2, 3]
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=1, bin_width=1, x_min=0, x_max=2)
def testBucketBoundaries(self):
x = np.array([-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=2, x_min=-5, x_max=5)
np.testing.assert_array_equal([2.5, 4.5, 6.5, 8.5, 10.5], result)
def testMultiSizeBins(self):
# Construct bins with size 0, 1, 2, 3, 4, 5, 10, respectively.
x = np.array([
1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6,
6
])
y = np.array([
0, -1, 1, 4, 5, 6, 2, 2, 4, 4, 1, 1, 1, 1, -1, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10
])
result = median_filter.median_filter(
x, y, num_bins=7, bin_width=1, x_min=0, x_max=7)
np.testing.assert_array_equal([3, 0, 0, 5, 3, 1, 5.5], result)
def testMedian(self):
x = np.array([-4, -2, -2, 0, 0, 0, 2, 2, 2, 2, 3, 3, 3, 3, 3])
y = np.array([0, -1, 1, 4, 5, 6, 2, 2, 4, 4, 1, 1, 1, 1, -1])
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=2, x_min=-5, x_max=5)
np.testing.assert_array_equal([0, 0, 5, 3, 1], result)
def testWideBins(self):
x = np.array([-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=6, x_min=-7, x_max=7)
np.testing.assert_array_equal([3, 4.5, 6.5, 8.5, 10.5], result)
def testNarrowBins(self):
x = np.array([-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=1, x_min=-4.5, x_max=4.5)
np.testing.assert_array_equal([3, 5, 7, 9, 11], result)
def testEmptyBins(self):
x = np.array([-1, 0, 1])
y = np.array([1, 2, 3])
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=2, x_min=-5, x_max=5)
np.testing.assert_array_equal([2, 2, 1.5, 3, 2], result)
def testDefaultArgs(self):
x = np.array([-4, -2, -2, 0, 0, 0, 2, 2, 2, 2, 3, 3, 3, 3, 3])
y = np.array([7, -1, 3, 4, 5, 6, 2, 2, 4, 4, 1, 1, 1, 1, -1])
result = median_filter.median_filter(x, y, num_bins=5)
np.testing.assert_array_equal([7, 1, 5, 2, 3], result)
if __name__ == '__main__':
absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Event class, which represents a periodic event in a light curve."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
class Event(object):
"""Represents a periodic event in a light curve."""
def __init__(self, period, duration, t0):
"""Initializes the Event.
Args:
period: Period of the event, in days.
duration: Duration of the event, in days.
t0: Time of the first occurrence of the event, in days.
"""
self._period = period
self._duration = duration
self._t0 = t0
@property
def period(self):
return self._period
@property
def duration(self):
return self._duration
@property
def t0(self):
return self._t0
def equals(self, other_event, period_rtol=0.001, t0_durations=1):
"""Compares this Event to another Event, within the given tolerance.
Args:
other_event: An Event.
period_rtol: Relative tolerance in matching the periods.
t0_durations: Tolerance in matching the t0 values, in units of the other
Event's duration.
Returns:
True if this Event is the same as other_event, within the given tolerance.
"""
# First compare the periods.
period_match = np.isclose(
self.period, other_event.period, rtol=period_rtol, atol=1e-8)
if not period_match:
return False
# To compare t0, we must consider that self.t0 and other_event.t0 may be at
# different phases. Just comparing mod(self.t0, period) to
# mod(other_event.t0, period) does not work because two similar values could
# end up at different ends of [0, period).
#
# Define t0_diff to be the absolute difference, up to multiples of period.
# This value is always in [0, period/2).
t0_diff = np.mod(self.t0 - other_event.t0, other_event.period)
if t0_diff > other_event.period / 2:
t0_diff = other_event.period - t0_diff
return t0_diff < t0_durations * other_event.duration
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for periodic_event.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
from light_curve_util.periodic_event import Event
class EventTest(absltest.TestCase):
def testEquals(self):
event = Event(period=100, duration=5, t0=2)
# Varying periods.
self.assertFalse(event.equals(Event(period=0, duration=5, t0=2)))
self.assertFalse(event.equals(Event(period=50, duration=5, t0=2)))
self.assertFalse(event.equals(Event(period=99.89, duration=5, t0=2)))
self.assertTrue(event.equals(Event(period=99.91, duration=5, t0=2)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=2)))
self.assertTrue(event.equals(Event(period=100.01, duration=5, t0=2)))
self.assertFalse(event.equals(Event(period=101, duration=5, t0=2)))
# Different period tolerance.
self.assertTrue(
event.equals(Event(period=99.1, duration=5, t0=2), period_rtol=0.01))
self.assertTrue(
event.equals(Event(period=100.9, duration=5, t0=2), period_rtol=0.01))
self.assertFalse(
event.equals(Event(period=98.9, duration=5, t0=2), period_rtol=0.01))
self.assertFalse(
event.equals(Event(period=101.1, duration=5, t0=2), period_rtol=0.01))
# Varying t0.
self.assertTrue(event.equals(Event(period=100, duration=5, t0=0)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=2)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=6.9)))
self.assertFalse(event.equals(Event(period=100, duration=5, t0=7.1)))
# t0 at the other end of [0, period).
self.assertFalse(event.equals(Event(period=100, duration=5, t0=96.9)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=97.1)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=100)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=102)))
self.assertFalse(event.equals(Event(period=100, duration=5, t0=107.1)))
# Varying duration.
self.assertFalse(event.equals(Event(period=100, duration=5, t0=10)))
self.assertFalse(event.equals(Event(period=100, duration=7, t0=10)))
self.assertTrue(event.equals(Event(period=100, duration=9, t0=10)))
# Different duration tolerance.
self.assertFalse(
event.equals(Event(period=100, duration=5, t0=10), t0_durations=1))
self.assertTrue(
event.equals(Event(period=100, duration=5, t0=10), t0_durations=2))
if __name__ == '__main__':
absltest.main()
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Light curve utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import itertools
import numpy as np
from six.moves import range # pylint:disable=redefined-builtin
def phase_fold_time(time, period, t0):
"""Creates a phase-folded time vector.
result[i] is the unique number in [-period / 2, period / 2)
such that result[i] = time[i] - t0 + k_i * period, for some integer k_i.
Args:
time: 1D numpy array of time values.
period: A positive real scalar; the period to fold over.
t0: The center of the resulting folded vector; this value is mapped to 0.
Returns:
A 1D numpy array.
"""
half_period = period / 2
result = np.mod(time + (half_period - t0), period)
result -= half_period
return result
def split(all_time, all_flux, gap_width=0.75):
"""Splits a light curve on discontinuities (gaps).
This function accepts a light curve that is either a single segment, or is
piecewise defined (e.g. split by quarter breaks or gaps in the in the data).
Args:
all_time: Numpy array or list of numpy arrays; each is a sequence of time
values.
all_flux: Numpy array or list of numpy arrays; each is a sequence of flux
values of the corresponding time array.
gap_width: Minimum gap size (in time units) for a split.
Returns:
out_time: List of numpy arrays; the split time arrays.
out_flux: List of numpy arrays; the split flux arrays.
"""
# Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion
# to bool fails if all_time is a numpy array, and all_time.size is not defined
# if all_time is a list of numpy arrays.
if len(all_time) > 0 and not isinstance(all_time[0], collections.Iterable): # pylint:disable=g-explicit-length-test
all_time = [all_time]
all_flux = [all_flux]
out_time = []
out_flux = []
for time, flux in itertools.izip(all_time, all_flux):
start = 0
for end in range(1, len(time) + 1):
# Choose the largest endpoint such that time[start:end] has no gaps.
if end == len(time) or time[end] - time[end - 1] > gap_width:
out_time.append(time[start:end])
out_flux.append(flux[start:end])
start = end
return out_time, out_flux
def remove_events(all_time, all_flux, events, width_factor=1.0):
"""Removes events from a light curve.
This function accepts either a single-segment or piecewise-defined light
curve (e.g. one that is split by quarter breaks or gaps in the in the data).
Args:
all_time: Numpy array or list of numpy arrays; each is a sequence of time
values.
all_flux: Numpy array or list of numpy arrays; each is a sequence of flux
values of the corresponding time array.
events: List of Event objects to remove.
width_factor: Fractional multiplier of the duration of each event to remove.
Returns:
output_time: Numpy array or list of numpy arrays; the time arrays with
events removed.
output_flux: Numpy array or list of numpy arrays; the flux arrays with
events removed.
"""
# Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion
# to bool fails if all_time is a numpy array and all_time.size is not defined
# if all_time is a list of numpy arrays.
if len(all_time) > 0 and not isinstance(all_time[0], collections.Iterable): # pylint:disable=g-explicit-length-test
all_time = [all_time]
all_flux = [all_flux]
single_segment = True
else:
single_segment = False
output_time = []
output_flux = []
for time, flux in itertools.izip(all_time, all_flux):
mask = np.ones_like(time, dtype=np.bool)
for event in events:
transit_dist = np.abs(phase_fold_time(time, event.period, event.t0))
mask = np.logical_and(mask,
transit_dist > 0.5 * width_factor * event.duration)
if single_segment:
output_time = time[mask]
output_flux = flux[mask]
else:
output_time.append(time[mask])
output_flux.append(flux[mask])
return output_time, output_flux
def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
"""Linearly interpolates spline values across masked points.
Args:
all_time: List of numpy arrays; each is a sequence of time values.
all_masked_time: List of numpy arrays; each is a sequence of time values
with some values missing (masked).
all_masked_spline: List of numpy arrays; the masked spline values
corresponding to all_masked_time.
Returns:
interp_spline: List of numpy arrays; each is the masked spline with missing
points linearly interpolated.
"""
interp_spline = []
for time, masked_time, masked_spline in itertools.izip(
all_time, all_masked_time, all_masked_spline):
if len(masked_time) > 0: # pylint:disable=g-explicit-length-test
interp_spline.append(np.interp(time, masked_time, masked_spline))
else:
interp_spline.append(np.full_like(time, np.nan))
return interp_spline
def count_transit_points(time, event):
"""Computes the number of points in each transit of a given event.
Args:
time: Sorted numpy array of time values.
event: An Event object.
Returns:
A numpy array containing the number of time points "in transit" for each
transit occurring between the first and last time values.
Raises:
ValueError: If there are more than 10**6 transits.
"""
t_min = np.min(time)
t_max = np.max(time)
# Tiny periods or erroneous time values could make this loop take forever.
if (t_max - t_min) / event.period > 10**6:
raise ValueError(
"Too many transits! Time range is [%.2f, %.2f] and period is %.2e." %
(t_min, t_max, event.period))
# Make sure t0 is in [t_min, t_min + period).
t0 = np.mod(event.t0 - t_min, event.period) + t_min
# Prepare loop variables.
points_in_transit = []
i, j = 0, 0
for transit_midpoint in np.arange(t0, t_max, event.period):
transit_begin = transit_midpoint - event.duration / 2
transit_end = transit_midpoint + event.duration / 2
# Move time[i] to the first point >= transit_begin.
while time[i] < transit_begin:
# transit_begin is guaranteed to be < np.max(t) (provided duration >= 0).
# Therefore, i cannot go out of range.
i += 1
# Move time[j] to the first point > transit_end.
while time[j] <= transit_end:
j += 1
# j went out of range. We're finished.
if j >= len(time):
break
# The points in the current transit duration are precisely time[i:j].
# Since j is an exclusive index, there are exactly j-i points in transit.
points_in_transit.append(j - i)
return np.array(points_in_transit)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for util.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve_util import periodic_event
from light_curve_util import util
class LightCurveUtilTest(absltest.TestCase):
def testPhaseFoldTime(self):
time = np.arange(0, 2, 0.1)
# Simple.
tfold = util.phase_fold_time(time, period=1, t0=0.45)
expected = [
-0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45,
-0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45
]
self.assertSequenceAlmostEqual(expected, tfold)
# Large t0.
tfold = util.phase_fold_time(time, period=1, t0=1.25)
expected = [
-0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35, -0.25,
-0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35
]
self.assertSequenceAlmostEqual(expected, tfold)
# Negative t0.
tfold = util.phase_fold_time(time, period=1, t0=-1.65)
expected = [
-0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35,
-0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45
]
self.assertSequenceAlmostEqual(expected, tfold)
# Negative time.
time = np.arange(-3, -1, 0.1)
tfold = util.phase_fold_time(time, period=1, t0=0.55)
expected = [
0.45, -0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45,
-0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35
]
self.assertSequenceAlmostEqual(expected, tfold)
def testSplit(self):
all_time = [
np.concatenate([
np.arange(0, 1, 0.1),
np.arange(1.5, 2, 0.1),
np.arange(3, 4, 0.1)
])
]
all_flux = [np.array([1] * 25)]
# Gap width 0.5.
split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
self.assertLen(split_time, 3)
self.assertLen(split_flux, 3)
self.assertSequenceAlmostEqual(
[0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], split_time[0])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
split_flux[0])
self.assertSequenceAlmostEqual([1.5, 1.6, 1.7, 1.8, 1.9], split_time[1])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1], split_flux[1])
self.assertSequenceAlmostEqual(
[3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9], split_time[2])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
split_flux[2])
# Gap width 1.0.
split_time, split_flux = util.split(all_time, all_flux, gap_width=1)
self.assertLen(split_time, 2)
self.assertLen(split_flux, 2)
self.assertSequenceAlmostEqual([
0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.5, 1.6, 1.7, 1.8, 1.9
], split_time[0])
self.assertSequenceAlmostEqual(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], split_flux[0])
self.assertSequenceAlmostEqual(
[3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9], split_time[1])
self.assertSequenceAlmostEqual([1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
split_flux[1])
def testRemoveEvents(self):
time = np.arange(20, dtype=np.float)
flux = 10 * time
# One event.
events = [periodic_event.Event(period=4, duration=1.5, t0=3.5)]
output_time, output_flux = util.remove_events(time, flux, events)
self.assertSequenceAlmostEqual([1, 2, 5, 6, 9, 10, 13, 14, 17, 18],
output_time)
self.assertSequenceAlmostEqual(
[10, 20, 50, 60, 90, 100, 130, 140, 170, 180], output_flux)
# Two events.
events.append(periodic_event.Event(period=7, duration=1.5, t0=6.5))
output_time, output_flux = util.remove_events(time, flux, events)
self.assertSequenceAlmostEqual([1, 2, 5, 9, 10, 17, 18], output_time)
self.assertSequenceAlmostEqual([10, 20, 50, 90, 100, 170, 180], output_flux)
# Multi segment light curve.
time = [np.arange(10, dtype=np.float), np.arange(10, 20, dtype=np.float)]
flux = [10 * t for t in time]
output_time, output_flux = util.remove_events(time, flux, events)
self.assertLen(output_time, 2)
self.assertLen(output_flux, 2)
self.assertSequenceAlmostEqual([1, 2, 5, 9], output_time[0])
self.assertSequenceAlmostEqual([10, 20, 50, 90], output_flux[0])
self.assertSequenceAlmostEqual([10, 17, 18], output_time[1])
self.assertSequenceAlmostEqual([100, 170, 180], output_flux[1])
def testInterpolateMaskedSpline(self):
all_time = [
np.arange(0, 10, dtype=np.float),
np.arange(10, 20, dtype=np.float),
]
all_masked_time = [
np.array([0, 1, 2, 3, 8, 9], dtype=np.float), # No 4, 5, 6, 7
np.array([10, 11, 12, 13, 14, 15, 16], dtype=np.float), # No 17, 18, 19
]
all_masked_spline = [2 * t + 100 for t in all_masked_time]
interp_spline = util.interpolate_masked_spline(all_time, all_masked_time,
all_masked_spline)
self.assertLen(interp_spline, 2)
self.assertSequenceAlmostEqual(
[100, 102, 104, 106, 108, 110, 112, 114, 116, 118], interp_spline[0])
self.assertSequenceAlmostEqual(
[120, 122, 124, 126, 128, 130, 132, 132, 132, 132], interp_spline[1])
def testCountTransitPoints(self):
time = np.concatenate([
np.arange(0, 10, 0.1, dtype=np.float),
np.arange(15, 30, 0.1, dtype=np.float),
np.arange(50, 100, 0.1, dtype=np.float)
])
event = periodic_event.Event(period=10, duration=5, t0=9.95)
points_in_transit = util.count_transit_points(time, event)
np.testing.assert_array_equal([25, 50, 25, 0, 25, 50, 50, 50, 50],
points_in_transit)
if __name__ == "__main__":
absltest.main()
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # MIT
exports_files(["LICENSE"])
py_library(
name = "kepler_spline",
srcs = ["kepler_spline.py"],
srcs_version = "PY2AND3",
deps = ["//third_party/robust_mean"],
)
py_test(
name = "kepler_spline_test",
size = "small",
srcs = ["kepler_spline_test.py"],
srcs_version = "PY2AND3",
deps = [":kepler_spline"],
)
MIT License
Copyright (c) 2017 avanderburg
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
\ No newline at end of file
"""Functions for computing normalization splines for Kepler light curves."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import warnings
import numpy as np
from pydl.pydlutils import bspline
from third_party.robust_mean import robust_mean
class SplineError(Exception):
"""Error when fitting a Kepler spline."""
pass
def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
"""Computes a best-fit spline curve for a light curve segment.
The spline is fit using an iterative process to remove outliers that may cause
the spline to be "pulled" by discrepent points. In each iteration the spline
is fit, and if there are any points where the absolute deviation from the
median residual is at least 3*sigma (where sigma is a robust estimate of the
standard deviation of the residuals), those points are removed and the spline
is re-fit.
Args:
time: Numpy array; the time values of the light curve.
flux: Numpy array; the flux (brightness) values of the light curve.
bkspace: Spline break point spacing in time units.
maxiter: Maximum number of attempts to fit the spline after removing badly
fit points.
outlier_cut: The maximum number of standard deviations from the median
spline residual before a point is considered an outlier.
Returns:
spline: The values of the fitted spline corresponding to the input time
values.
mask: Boolean mask indicating the points used to fit the final spline.
"""
# Rescale time into [0, 1].
t_min = np.min(time)
t_max = np.max(time)
time = (time - t_min) / (t_max - t_min)
bkspace /= (t_max - t_min) # Rescale bucket spacing.
# Values of the best fitting spline evaluated at the time points.
spline = None
# Mask indicating the points used to fit the spline.
mask = None
for _ in range(maxiter):
if spline is None:
mask = np.ones_like(time, dtype=np.bool) # Try to fit all points.
else:
# Choose points where the absolute deviation from the median residual is
# less than 3*sigma, where sigma is a robust estimate of the standard
# deviation of the residuals from the previous spline.
residuals = flux - spline
_, _, new_mask = robust_mean.robust_mean(residuals, cut=outlier_cut)
if np.all(new_mask == mask):
break # Spline converged.
mask = new_mask
try:
with warnings.catch_warnings():
# Suppress warning messages printed by pydlutils.bspline. Instead we
# catch any exception and raise a more informative error.
warnings.simplefilter("ignore")
# Fit the spline on non-outlier points.
curve = bspline.iterfit(time[mask], flux[mask], bkspace=bkspace)[0]
# Evaluate spline at the time points.
spline = curve.value(time)[0]
except (IndexError, TypeError) as e:
raise SplineError(
"Fitting spline failed with error: '%s'. This might be caused by the "
"breakpoint spacing being too small, and/or there being insufficient "
"points to fit the spline in one of the intervals." % e)
return spline, mask
def choose_kepler_spline(all_time,
all_flux,
bkspaces,
maxiter=5,
penalty_coeff=1.0,
verbose=True):
"""Computes the best-fit Kepler spline across a break-point spacings.
Some Kepler light curves have low-frequency variability, while others have
very high-frequency variability (e.g. due to rapid rotation). Therefore, it is
suboptimal to use the same break-point spacing for every star. This function
computes the best-fit spline by fitting splines with different break-point
spacings, calculating the Bayesian Information Criterion (BIC) for each
spline, and choosing the break-point spacing that minimizes the BIC.
This function assumes a piecewise light curve, that is, a light curve that is
divided into different segments (e.g. split by quarter breaks or gaps in the
in the data). A separate spline is fit for each segment.
Args:
all_time: List of 1D numpy arrays; the time values of the light curve.
all_flux: List of 1D numpy arrays; the flux (brightness) values of the light
curve.
bkspaces: List of break-point spacings to try.
maxiter: Maximum number of attempts to fit each spline after removing badly
fit points.
penalty_coeff: Coefficient of the penalty term for using more parameters in
the Bayesian Information Criterion. Decreasing this value will allow
more parameters to be used (i.e. smaller break-point spacing), and
vice-versa.
verbose: Whether to log individual spline errors. Note that if bkspaces
contains many values (particularly small ones) then this may cause
logging pollution if calling this function for many light curves.
Returns:
spline: List of numpy arrays; values of the best-fit spline corresponding to
to the input flux arrays.
spline_mask: List of boolean numpy arrays indicating which points in the
flux arrays were used to fit the best-fit spline.
bkspace: The break-point spacing used for the best-fit spline.
bad_bkspaces: List of break-point spacing values that failed.
"""
# Compute the assumed standard deviation of Gaussian white noise about the
# spline model.
abs_deviations = np.concatenate([np.abs(f[1:] - f[:-1]) for f in all_flux])
sigma = np.median(abs_deviations) * 1.48 / np.sqrt(2)
best_bic = None
best_spline = None
best_spline_mask = None
best_bkspace = None
bad_bkspaces = []
for bkspace in bkspaces:
nparams = 0 # Total number of free parameters in the piecewise spline.
npoints = 0 # Total number of data points used to fit the piecewise spline.
ssr = 0 # Sum of squared residuals between the model and the spline.
spline = []
spline_mask = []
bad_bkspace = False # Indicates that the current bkspace should be skipped.
for time, flux in itertools.izip(all_time, all_flux):
# Don't fit a spline on less than 4 points.
if len(time) < 4:
spline.append(flux)
spline_mask.append(np.ones_like(flux), dtype=np.bool)
continue
# Fit B-spline to this light-curve segment.
try:
spline_piece, mask = kepler_spline(
time, flux, bkspace=bkspace, maxiter=maxiter)
# It's expected to get a SplineError occasionally for small values of
# bkspace.
except SplineError as e:
if verbose:
warnings.warn("Bad bkspace %.4f: %s" % (bkspace, e))
bad_bkspaces.append(bkspace)
bad_bkspace = True
break
spline.append(spline_piece)
spline_mask.append(mask)
# Accumulate the number of free parameters.
total_time = np.max(time) - np.min(time)
nknots = int(total_time / bkspace) + 1 # From the bspline implementation.
nparams += nknots + 3 - 1 # number of knots + degree of spline - 1
# Accumulate the number of points and the squared residuals.
npoints += np.sum(mask)
ssr += np.sum((flux[mask] - spline_piece[mask])**2)
if bad_bkspace:
continue
# The following term is -2*ln(L), where L is the likelihood of the data
# given the model, under the assumption that the model errors are iid
# Gaussian with mean 0 and standard deviation sigma.
likelihood_term = npoints * np.log(2 * np.pi * sigma**2) + ssr / sigma**2
# Bayesian information criterion.
bic = likelihood_term + penalty_coeff * nparams * np.log(npoints)
if best_bic is None or bic < best_bic:
best_bic = bic
best_spline = spline
best_spline_mask = spline_mask
best_bkspace = bkspace
return best_spline, best_spline_mask, best_bkspace, bad_bkspaces
"""Tests for kepler_spline.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from third_party.kepler_spline import kepler_spline
class KeplerSplineTest(absltest.TestCase):
def testKeplerSplineSine(self):
# Fit a sine wave.
time = np.arange(0, 10, 0.1)
flux = np.sin(time)
# Expect very close fit with no outliers removed.
spline, mask = kepler_spline.kepler_spline(time, flux, bkspace=0.5)
rmse = np.sqrt(np.mean((flux[mask] - spline[mask])**2))
self.assertLess(rmse, 1e-4)
self.assertTrue(np.all(mask))
# Add some outliers.
flux[35] = 10
flux[77] = -3
flux[95] = 2.9
# Expect a close fit with outliers removed.
spline, mask = kepler_spline.kepler_spline(time, flux, bkspace=0.5)
rmse = np.sqrt(np.mean((flux[mask] - spline[mask])**2))
self.assertLess(rmse, 1e-4)
self.assertEqual(np.sum(mask), 97)
self.assertFalse(mask[35])
self.assertFalse(mask[77])
self.assertFalse(mask[95])
# Increase breakpoint spacing. Fit is not quite as close.
spline, mask = kepler_spline.kepler_spline(time, flux, bkspace=1)
rmse = np.sqrt(np.mean((flux[mask] - spline[mask])**2))
self.assertLess(rmse, 2e-3)
self.assertEqual(np.sum(mask), 97)
self.assertFalse(mask[35])
self.assertFalse(mask[77])
self.assertFalse(mask[95])
def testKeplerSplineCubic(self):
# Fit a cubic polynomial.
time = np.arange(0, 10, 0.1)
flux = (time - 5)**3 + 2 * (time - 5)**2 + 10
# Expect very close fit with no outliers removed. We choose maxiter=1,
# because a cubic spline will fit a cubic polynomial ~exactly, so the
# standard deviation of residuals will be ~0, which will cause some closely
# fit points to be rejected.
spline, mask = kepler_spline.kepler_spline(
time, flux, bkspace=0.5, maxiter=1)
rmse = np.sqrt(np.mean((flux[mask] - spline[mask])**2))
self.assertLess(rmse, 1e-12)
self.assertTrue(np.all(mask))
def testKeplerSplineError(self):
# Big gap.
time = np.concatenate([np.arange(0, 1, 0.1), [2]])
flux = np.sin(time)
with self.assertRaises(kepler_spline.SplineError):
kepler_spline.kepler_spline(time, flux, bkspace=0.5)
def testChooseKeplerSpline(self):
# High frequency sine wave.
time = [np.arange(0, 100, 0.1), np.arange(100, 200, 0.1)]
flux = [np.sin(t) for t in time]
# Logarithmically sample candidate break point spacings.
bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20)
def _rmse(all_flux, all_spline):
f = np.concatenate(all_flux)
s = np.concatenate(all_spline)
return np.sqrt(np.mean((f - s)**2))
# Penalty coefficient 1.0.
spline, mask, bkspace, bad_bkspaces = kepler_spline.choose_kepler_spline(
time, flux, bkspaces, penalty_coeff=1.0)
self.assertAlmostEqual(_rmse(flux, spline), 0.013013)
self.assertTrue(np.all(mask))
self.assertAlmostEqual(bkspace, 1.67990914314)
self.assertEmpty(bad_bkspaces)
# Decrease penalty coefficient; allow smaller spacing for closer fit.
spline, mask, bkspace, bad_bkspaces = kepler_spline.choose_kepler_spline(
time, flux, bkspaces, penalty_coeff=0.1)
self.assertAlmostEqual(_rmse(flux, spline), 0.0066376)
self.assertTrue(np.all(mask))
self.assertAlmostEqual(bkspace, 1.48817572082)
self.assertEmpty(bad_bkspaces)
# Increase penalty coefficient; require larger spacing at the cost of worse
# fit.
spline, mask, bkspace, bad_bkspaces = kepler_spline.choose_kepler_spline(
time, flux, bkspaces, penalty_coeff=2)
self.assertAlmostEqual(_rmse(flux, spline), 0.026215449)
self.assertTrue(np.all(mask))
self.assertAlmostEqual(bkspace, 1.89634509537)
self.assertEmpty(bad_bkspaces)
if __name__ == "__main__":
absltest.main()
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # BSD
exports_files(["LICENSE"])
py_library(
name = "robust_mean",
srcs = ["robust_mean.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "robust_mean_test",
size = "small",
srcs = [
"robust_mean_test.py",
"test_data/random_normal.py",
],
srcs_version = "PY2AND3",
deps = [":robust_mean"],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment