Unverified Commit 2c181308 authored by Chris Shallue's avatar Chris Shallue Committed by GitHub
Browse files

Merge pull request #5862 from cshallue/master

Move tensorflow_models/research/astronet to google-research/exoplanet-ml
parents caafb6d1 62704f06
MIT License
Copyright (c) 2017 avanderburg
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
\ No newline at end of file
"""Functions for computing normalization splines for Kepler light curves."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
import numpy as np
from pydl.pydlutils import bspline
from third_party.robust_mean import robust_mean
class InsufficientPointsError(Exception):
  """Raised when too few points are available to fit a spline."""
class SplineError(Exception):
  """Raised when the underlying spline-fitting implementation fails."""
def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
  """Computes a best-fit spline curve for a light curve segment.

  The spline is fit using an iterative process to remove outliers that may cause
  the spline to be "pulled" by discrepant points. In each iteration the spline
  is fit, and if there are any points where the absolute deviation from the
  median residual exceeds outlier_cut*sigma (where sigma is a robust estimate
  of the standard deviation of the residuals), those points are removed and
  the spline is re-fit.

  Args:
    time: Numpy array; the time values of the light curve.
    flux: Numpy array; the flux (brightness) values of the light curve.
    bkspace: Spline break point spacing in time units.
    maxiter: Maximum number of attempts to fit the spline after removing badly
      fit points.
    outlier_cut: The maximum number of standard deviations from the median
      spline residual before a point is considered an outlier.

  Returns:
    spline: The values of the fitted spline corresponding to the input time
      values.
    mask: Boolean mask indicating the points used to fit the final spline.

  Raises:
    InsufficientPointsError: If there were insufficient points (after removing
      outliers) for spline fitting.
    SplineError: If the spline could not be fit, for example if the breakpoint
      spacing is too small.
  """
  if len(time) < 4:
    raise InsufficientPointsError(
        "Cannot fit a spline on less than 4 points. Got {} points.".format(
            len(time)))

  # Rescale time into [0, 1].
  # NOTE(review): if every time value is identical, t_max - t_min is zero and
  # the two divisions below divide by zero -- confirm callers never pass such
  # a segment.
  t_min = np.min(time)
  t_max = np.max(time)
  time = (time - t_min) / (t_max - t_min)
  bkspace /= (t_max - t_min)  # Rescale bucket spacing.

  # Values of the best fitting spline evaluated at the time points.
  spline = None

  # Mask indicating the points used to fit the spline.
  mask = None

  for _ in range(maxiter):
    if spline is None:
      # Try to fit all points. The builtin bool is used because the np.bool
      # alias was removed from NumPy.
      mask = np.ones_like(time, dtype=bool)
    else:
      # Choose points where the absolute deviation from the median residual is
      # at most outlier_cut*sigma, where sigma is a robust estimate of the
      # standard deviation of the residuals from the previous spline.
      residuals = flux - spline
      new_mask = robust_mean.robust_mean(residuals, cut=outlier_cut)[2]

      if np.all(new_mask == mask):
        break  # Spline converged.

      mask = new_mask

    if np.sum(mask) < 4:
      # Fewer than 4 points after removing outliers. We could plausibly return
      # the spline from the previous iteration because it was fit with at least
      # 4 points. However, since the outliers were such a significant fraction
      # of the curve, the spline from the previous iteration is probably junk,
      # and we consider this a fatal error.
      raise InsufficientPointsError(
          "Cannot fit a spline on less than 4 points. After removing "
          "outliers, got {} points.".format(np.sum(mask)))

    try:
      with warnings.catch_warnings():
        # Suppress warning messages printed by pydlutils.bspline. Instead we
        # catch any exception and raise a more informative error.
        warnings.simplefilter("ignore")

        # Fit the spline on non-outlier points.
        curve = bspline.iterfit(time[mask], flux[mask], bkspace=bkspace)[0]

      # Evaluate spline at the time points.
      spline = curve.value(time)[0]
    except (IndexError, TypeError) as e:
      raise SplineError(
          "Fitting spline failed with error: '{}'. This might be caused by the "
          "breakpoint spacing being too small, and/or there being insufficient "
          "points to fit the spline in one of the intervals.".format(e))

  return spline, mask
class SplineMetadata(object):
  """Container describing the outcome of a spline fit.

  Every field starts out as None, except bad_bkspaces, which starts out as an
  empty list.

  Attributes:
    light_curve_mask: List of boolean numpy arrays indicating which points in
      the light curve were used to fit the best-fit spline.
    bkspace: The break-point spacing used for the best-fit spline.
    bad_bkspaces: List of break-point spacing values that failed.
    likelihood_term: The likelihood term of the Bayesian Information Criterion;
      -2*ln(L), where L is the likelihood of the data given the model.
    penalty_term: The penalty term for the number of parameters in the Bayesian
      Information Criterion.
    bic: The value of the Bayesian Information Criterion; equal to
      likelihood_term + penalty_coeff * penalty_term.
  """

  def __init__(self):
    # Nothing has been fit yet, so the best-fit fields are unset.
    self.light_curve_mask = self.bkspace = None
    # A fresh list per instance, so instances never share failure records.
    self.bad_bkspaces = []
    self.likelihood_term = self.penalty_term = self.bic = None
def choose_kepler_spline(all_time,
                         all_flux,
                         bkspaces,
                         maxiter=5,
                         penalty_coeff=1.0,
                         verbose=True):
  """Computes the best-fit Kepler spline across a set of break-point spacings.

  Some Kepler light curves have low-frequency variability, while others have
  very high-frequency variability (e.g. due to rapid rotation). Therefore, it is
  suboptimal to use the same break-point spacing for every star. This function
  computes the best-fit spline by fitting splines with different break-point
  spacings, calculating the Bayesian Information Criterion (BIC) for each
  spline, and choosing the break-point spacing that minimizes the BIC.

  This function assumes a piecewise light curve, that is, a light curve that is
  divided into different segments (e.g. split by quarter breaks or gaps in the
  data). A separate spline is fit for each segment.

  Args:
    all_time: List of 1D numpy arrays; the time values of the light curve.
    all_flux: List of 1D numpy arrays; the flux values of the light curve.
    bkspaces: List of break-point spacings to try.
    maxiter: Maximum number of attempts to fit each spline after removing badly
      fit points.
    penalty_coeff: Coefficient of the penalty term for using more parameters in
      the Bayesian Information Criterion. Decreasing this value will allow more
      parameters to be used (i.e. smaller break-point spacing), and vice-versa.
    verbose: Whether to log individual spline errors. Note that if bkspaces
      contains many values (particularly small ones) then this may cause logging
      pollution if calling this function for many light curves.

  Returns:
    spline: List of numpy arrays; values of the best-fit spline corresponding
      to the input flux arrays. Segments that could not be fit are all NaN.
    metadata: Object containing metadata about the spline fit.
  """
  # Initialize outputs.
  best_spline = None
  metadata = SplineMetadata()

  # Compute the assumed standard deviation of Gaussian white noise about the
  # spline model. We assume that each flux value f[i] is a Gaussian random
  # variable f[i] ~ N(s[i], sigma^2), where s is the value of the true spline
  # model and sigma is the constant standard deviation for all flux values.
  # Moreover, we assume that s[i] ~= s[i+1]. Therefore,
  # (f[i+1] - f[i]) / sqrt(2) ~ N(0, sigma^2).
  scaled_diffs = [np.diff(f) / np.sqrt(2) for f in all_flux]
  scaled_diffs = np.concatenate(scaled_diffs) if scaled_diffs else np.array([])
  if not scaled_diffs.size:
    # No segment has two or more points, so sigma cannot be estimated.
    best_spline = [np.array([np.nan] * len(f)) for f in all_flux]
    metadata.light_curve_mask = [
        np.zeros_like(f, dtype=bool) for f in all_flux
    ]
    return best_spline, metadata

  # Compute the median absolute deviation as a robust estimate of sigma. The
  # conversion factor of 1.48 takes the median absolute deviation to the
  # standard deviation of a normal distribution. See, e.g.
  # https://www.mathworks.com/help/stats/mad.html.
  # NOTE(review): sigma is zero when at least half of the successive flux
  # differences are zero; the likelihood term below then divides by zero --
  # confirm whether such inputs can occur.
  sigma = np.median(np.abs(scaled_diffs)) * 1.48

  for bkspace in bkspaces:
    nparams = 0  # Total number of free parameters in the piecewise spline.
    npoints = 0  # Total number of data points used to fit the piecewise spline.
    ssr = 0  # Sum of squared residuals between the model and the spline.

    spline = []
    light_curve_mask = []
    bad_bkspace = False  # Indicates that the current bkspace should be skipped.
    for time, flux in zip(all_time, all_flux):
      # Fit B-spline to this light-curve segment.
      try:
        spline_piece, mask = kepler_spline(
            time, flux, bkspace=bkspace, maxiter=maxiter)
      except InsufficientPointsError as e:
        # It's expected to occasionally see intervals with insufficient points,
        # especially if periodic signals have been removed from the light curve.
        # Skip this interval, but continue fitting the spline.
        if verbose:
          warnings.warn(str(e))
        spline.append(np.array([np.nan] * len(flux)))
        light_curve_mask.append(np.zeros_like(flux, dtype=bool))
        continue
      except SplineError as e:
        # It's expected to get a SplineError occasionally for small values of
        # bkspace. Skip this bkspace.
        if verbose:
          warnings.warn("Bad bkspace {}: {}".format(bkspace, e))
        metadata.bad_bkspaces.append(bkspace)
        bad_bkspace = True
        break

      spline.append(spline_piece)
      light_curve_mask.append(mask)

      # Accumulate the number of free parameters.
      total_time = np.max(time) - np.min(time)
      nknots = int(total_time / bkspace) + 1  # From the bspline implementation.
      nparams += nknots + 3 - 1  # number of knots + degree of spline - 1

      # Accumulate the number of points and the squared residuals.
      npoints += np.sum(mask)
      ssr += np.sum((flux[mask] - spline_piece[mask])**2)

    if bad_bkspace or not npoints:
      continue

    # The following term is -2*ln(L), where L is the likelihood of the data
    # given the model, under the assumption that the model errors are iid
    # Gaussian with mean 0 and standard deviation sigma.
    likelihood_term = npoints * np.log(2 * np.pi * sigma**2) + ssr / sigma**2

    # Penalty term for the number of parameters used to fit the model.
    penalty_term = nparams * np.log(npoints)

    # Bayesian information criterion.
    bic = likelihood_term + penalty_coeff * penalty_term

    if best_spline is None or bic < metadata.bic:
      best_spline = spline
      metadata.light_curve_mask = light_curve_mask
      metadata.bkspace = bkspace
      metadata.likelihood_term = likelihood_term
      metadata.penalty_term = penalty_term
      metadata.bic = bic

  if best_spline is None:
    # All bkspaces resulted in a SplineError, or all light curve intervals had
    # insufficient points.
    best_spline = [np.array([np.nan] * len(f)) for f in all_flux]
    metadata.light_curve_mask = [
        np.zeros_like(f, dtype=bool) for f in all_flux
    ]

  return best_spline, metadata
def fit_kepler_spline(all_time,
                      all_flux,
                      bkspace_min=0.5,
                      bkspace_max=20,
                      bkspace_num=20,
                      maxiter=5,
                      penalty_coeff=1.0,
                      verbose=True):
  """Fits a Kepler spline, trying log-spaced candidate breakpoint spacings.

  Builds bkspace_num candidate spacings, evenly spaced in log10 between
  bkspace_min and bkspace_max, and delegates the selection among them to
  choose_kepler_spline().

  Args:
    all_time: List of 1D numpy arrays; the time values of the light curve.
    all_flux: List of 1D numpy arrays; the flux values of the light curve.
    bkspace_min: Minimum breakpoint spacing to try.
    bkspace_max: Maximum breakpoint spacing to try.
    bkspace_num: Number of breakpoint spacings to try.
    maxiter: Maximum number of attempts to fit each spline after removing badly
      fit points.
    penalty_coeff: Coefficient of the penalty term for using more parameters in
      the Bayesian Information Criterion. Decreasing this value will allow more
      parameters to be used (i.e. smaller break-point spacing), and vice-versa.
    verbose: Whether to log individual spline errors. Note that if many
      spacings are tried (particularly small ones) then this may cause logging
      pollution if calling this function for many light curves.

  Returns:
    spline: List of numpy arrays; values of the best-fit spline corresponding
      to the input flux arrays.
    metadata: Object containing metadata about the spline fit.
  """
  # Endpoints of the candidate grid, expressed as base-10 exponents.
  exponent_lo = np.log10(bkspace_min)
  exponent_hi = np.log10(bkspace_max)
  candidate_bkspaces = np.logspace(exponent_lo, exponent_hi, num=bkspace_num)

  return choose_kepler_spline(
      all_time,
      all_flux,
      candidate_bkspaces,
      maxiter=maxiter,
      penalty_coeff=penalty_coeff,
      verbose=verbose)
"""Tests for kepler_spline.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from third_party.kepler_spline import kepler_spline
class KeplerSplineTest(absltest.TestCase):
  """Tests for kepler_spline.kepler_spline()."""

  @staticmethod
  def _masked_rmse(flux, spline, mask):
    """Returns the RMS residual over the points selected by mask."""
    return np.sqrt(np.mean((flux[mask] - spline[mask])**2))

  def testFitSine(self):
    # Noise-free sine wave.
    time = np.arange(0, 10, 0.1)
    flux = np.sin(time)

    # Clean data: expect a very close fit with every point retained.
    spline, mask = kepler_spline.kepler_spline(time, flux, bkspace=0.5)
    self.assertLess(self._masked_rmse(flux, spline, mask), 1e-4)
    self.assertTrue(np.all(mask))

    # Overwrite three samples with outlying values.
    outliers = ((35, 10), (77, -3), (95, 2.9))
    for index, value in outliers:
      flux[index] = value

    # First the same breakpoint spacing (close fit once the outliers are
    # dropped), then a larger spacing (fit is not quite as close). In both
    # cases exactly the three injected outliers must be masked out.
    for bkspace, rmse_bound in ((0.5, 1e-4), (1, 2e-3)):
      spline, mask = kepler_spline.kepler_spline(time, flux, bkspace=bkspace)
      self.assertLess(self._masked_rmse(flux, spline, mask), rmse_bound)
      self.assertEqual(np.sum(mask), 97)
      for index, _ in outliers:
        self.assertFalse(mask[index])

  def testFitCubic(self):
    # Cubic polynomial.
    time = np.arange(0, 10, 0.1)
    shifted = time - 5
    flux = shifted**3 + 2 * shifted**2 + 10

    # Expect very close fit with no outliers removed. We choose maxiter=1,
    # because a cubic spline will fit a cubic polynomial ~exactly, so the
    # standard deviation of residuals will be ~0, which will cause some closely
    # fit points to be rejected.
    spline, mask = kepler_spline.kepler_spline(
        time, flux, bkspace=0.5, maxiter=1)
    self.assertLess(self._masked_rmse(flux, spline, mask), 1e-12)
    self.assertTrue(np.all(mask))

  def testInsufficientPointsError(self):
    # An empty light curve, then one with only 3 points: both are too short.
    for time in (np.array([]), np.array([0.1, 0.2, 0.3])):
      flux = np.sin(time)
      with self.assertRaises(kepler_spline.InsufficientPointsError):
        kepler_spline.kepler_spline(time, flux, bkspace=0.5)
class ChooseKeplerSplineTest(absltest.TestCase):
  """Tests for kepler_spline.choose_kepler_spline()."""

  def setUp(self):
    super(ChooseKeplerSplineTest, self).setUp()
    # Logarithmically sample candidate break point spacings.
    self.bkspaces = np.logspace(np.log10(0.5), np.log10(5), num=20)

  @staticmethod
  def _rmse(all_flux, all_spline):
    """Returns the RMS residual over all segments concatenated."""
    residuals = np.concatenate(all_flux) - np.concatenate(all_spline)
    return np.sqrt(np.mean(residuals**2))

  def testEmptyInput(self):
    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time=[],
        all_flux=[],
        bkspaces=self.bkspaces,
        penalty_coeff=1.0,
        verbose=False)
    np.testing.assert_array_equal(spline, [])
    np.testing.assert_array_equal(metadata.light_curve_mask, [])

  def testNoPoints(self):
    # A single segment that contains no points.
    all_time = [np.array([])]
    all_flux = [np.array([])]
    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time, all_flux, self.bkspaces, penalty_coeff=1.0, verbose=False)
    np.testing.assert_array_equal(spline, [[]])
    np.testing.assert_array_equal(metadata.light_curve_mask, [[]])

  def testTooFewPoints(self):
    # Sine wave with segments of 1, 2, 3 points.
    all_time = [
        np.array([0.1]),
        np.array([0.2, 0.3]),
        np.array([0.4, 0.5, 0.6]),
    ]
    all_flux = [np.sin(t) for t in all_time]

    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time, all_flux, self.bkspaces, penalty_coeff=1.0, verbose=False)

    # All segments are NaN and no metadata was recorded.
    self.assertTrue(np.all(np.isnan(np.concatenate(spline))))
    self.assertFalse(np.any(np.concatenate(metadata.light_curve_mask)))
    self.assertIsNone(metadata.bkspace)
    self.assertEmpty(metadata.bad_bkspaces)
    self.assertIsNone(metadata.likelihood_term)
    self.assertIsNone(metadata.penalty_term)
    self.assertIsNone(metadata.bic)

    # Add a longer segment.
    long_time = np.arange(0.7, 2.0, 0.1)
    all_time.append(long_time)
    all_flux.append(np.sin(long_time))

    spline, metadata = kepler_spline.choose_kepler_spline(
        all_time, all_flux, self.bkspaces, penalty_coeff=1.0, verbose=False)

    # First 3 segments are NaN.
    for short_spline, short_mask in zip(spline[:3],
                                        metadata.light_curve_mask[:3]):
      self.assertTrue(np.all(np.isnan(short_spline)))
      self.assertFalse(np.any(short_mask))

    # Final segment is a good fit.
    self.assertTrue(np.all(np.isfinite(spline[3])))
    self.assertTrue(np.all(metadata.light_curve_mask[3]))
    self.assertEmpty(metadata.bad_bkspaces)
    self.assertAlmostEqual(metadata.likelihood_term, -58.0794069927957)
    self.assertAlmostEqual(metadata.penalty_term, 7.69484807238461)
    self.assertAlmostEqual(metadata.bic, -50.3845589204111)

  def testFitSine(self):
    # High frequency sine wave.
    all_time = [np.arange(0, 100, 0.1), np.arange(100, 200, 0.1)]
    all_flux = [np.sin(t) for t in all_time]

    # Each row: penalty_coeff, then the expected rmse, bkspace,
    # likelihood_term, penalty_term and bic.
    #   1.0: baseline.
    #   0.1: decreased penalty; allows smaller spacing for a closer fit.
    #   2:   increased penalty; requires larger spacing at the cost of a
    #        worse fit.
    expectations = (
        (1.0, 0.013013, 1.67990914314, -6685.64217856480, 942.51190498322,
         -5743.13027358158),
        (0.1, 0.0066376, 1.48817572082, -6731.59913975551, 1064.12634433589,
         -6625.18650532192),
        (2, 0.026215449, 1.89634509537, -6495.65564287904, 836.099270549629,
         -4823.45710177978),
    )
    for coeff, rmse, bkspace, likelihood, penalty, bic in expectations:
      spline, metadata = kepler_spline.choose_kepler_spline(
          all_time, all_flux, self.bkspaces, penalty_coeff=coeff)
      self.assertAlmostEqual(self._rmse(all_flux, spline), rmse)
      self.assertTrue(np.all(metadata.light_curve_mask))
      self.assertAlmostEqual(metadata.bkspace, bkspace)
      self.assertEmpty(metadata.bad_bkspaces)
      self.assertAlmostEqual(metadata.likelihood_term, likelihood)
      self.assertAlmostEqual(metadata.penalty_term, penalty)
      self.assertAlmostEqual(metadata.bic, bic)
# Run the test cases defined above when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
# Package-wide default: targets declared here are visible to every package.
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # BSD
# Allow other packages to reference this package's LICENSE file.
exports_files(["LICENSE"])
# Library target wrapping robust_mean.py; declared compatible with both
# Python 2 and Python 3 sources.
py_library(
name = "robust_mean",
srcs = ["robust_mean.py"],
srcs_version = "PY2AND3",
)
# Unit test for :robust_mean. The pre-generated sample data module is listed
# as a source of the test itself rather than as a separate library target.
py_test(
name = "robust_mean_test",
size = "small",
srcs = [
"robust_mean_test.py",
"test_data/random_normal.py",
],
srcs_version = "PY2AND3",
deps = [":robust_mean"],
)
Copyright (c) 2014, Wayne Landsman
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
"""Function for computing a robust mean estimate in the presence of outliers.
This is a modified Python implementation of this file:
https://idlastro.gsfc.nasa.gov/ftp/pro/robust/resistant_mean.pro
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def robust_mean(y, cut):
  """Estimates the mean of y while ignoring outliers.

  An initial robust scale estimate is derived from the median absolute
  deviation of y. Outliers are then clipped and the scale re-estimated from the
  surviving points, twice over, before the mean of the final survivors is
  taken.

  Args:
    y: 1D numpy array. Assumed to be normally distributed with outliers.
    cut: Points more than this number of standard deviations from the median are
      ignored.

  Returns:
    mean: A robust estimate of the mean of y.
    mean_stddev: The standard deviation of the mean.
    mask: Boolean array with the same length as y. Values corresponding to
      outliers in y are False. All other values are True.
  """
  deviations = np.abs(y - np.median(y))

  # Initial scale estimate from the median absolute deviation. The factor
  # 1.4826 converts it to the standard deviation of a normal distribution.
  # See, e.g. https://www.mathworks.com/help/stats/mad.html.
  scale = 1.4826 * np.median(deviations)
  if scale < 1.0e-24:
    # The median absolute deviation is (effectively) zero, so fall back to the
    # mean absolute deviation, whose conversion factor is 1.253.
    scale = 1.253 * np.mean(deviations)

  # Trimming outliers biases the sample standard deviation low. The polynomial
  # below approximately compensates for that, see
  # http://w.astro.berkeley.edu/~johnjohn/idlprocs/robust_mean.pro. It is only
  # applied for cuts of at most 4.5 standard deviations.
  trim = np.max([cut, 1.0])
  if trim <= 4.5:
    trim_correction = (
        -0.15405 + 0.90723 * trim - 0.23584 * trim**2 + 0.020142 * trim**3)
  else:
    trim_correction = None

  # Two passes: flag outliers using the current scale estimate, then
  # re-estimate the scale from the sample standard deviation of the points
  # that survived.
  for _ in range(2):
    mask = deviations <= cut * scale
    scale = np.std(y[mask])
    if trim_correction is not None:
      scale /= trim_correction

  # Final estimate is the sample mean with outliers removed.
  mean = np.mean(y[mask])
  mean_stddev = scale / np.sqrt(len(y) - 1.0)
  return mean, mean_stddev, mask
"""Tests for robust_mean.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from third_party.robust_mean import robust_mean
from third_party.robust_mean.test_data import random_normal
class RobustMeanTest(absltest.TestCase):
  """Tests for robust_mean.robust_mean()."""

  def _check(self, result, mean, mean_stddev, size, num_kept):
    """Asserts on a robust_mean() result tuple and returns its mask."""
    actual_mean, actual_mean_stddev, mask = result
    self.assertAlmostEqual(actual_mean, mean)
    self.assertAlmostEqual(actual_mean_stddev, mean_stddev)
    self.assertLen(mask, size)
    self.assertEqual(np.sum(mask), num_kept)
    return mask

  def testRobustMean(self):
    # To avoid non-determinism in the unit test, we use a pre-generated vector
    # of length 1,000. Each entry is independently sampled from a random normal
    # distribution with mean 2 and standard deviation 1. The maximum value of
    # y is 6.075 (+4.075 sigma from the mean) and the minimum value is -1.54
    # (-3.54 sigma from the mean).
    y = np.array(random_normal.RANDOM_NORMAL)
    self.assertAlmostEqual(np.mean(y), 2.00336615850485)
    self.assertAlmostEqual(np.std(y), 1.01690907798)

    # High cut. No points rejected, so the mean should be the sample mean, and
    # the mean standard deviation should be the sample standard deviation
    # divided by sqrt(1000 - 1).
    self._check(
        robust_mean.robust_mean(y, cut=5),
        mean=2.00336615850485,
        mean_stddev=0.032173579,
        size=1000,
        num_kept=1000)

    # Cut of 3 standard deviations. Exactly 3 points are rejected; they have
    # indices 12, 220, 344.
    mask = self._check(
        robust_mean.robust_mean(y, cut=3),
        mean=2.0059050070632178,
        mean_stddev=0.03197075302321066,
        size=1000,
        num_kept=997)
    self.assertFalse(np.any(mask[[12, 220, 344]]))

    # Add outliers. This corrupts the sample mean to 2.082.
    mask = self._check(
        robust_mean.robust_mean(y=np.concatenate([y, [10] * 10]), cut=5),
        mean=2.0033661585048681,
        mean_stddev=0.032013749413590531,
        size=1010,
        num_kept=1000)
    self.assertFalse(np.any(mask[1000:1010]))

    # Add an outlier. This corrupts the mean to 1.002.
    mask = self._check(
        robust_mean.robust_mean(y=np.concatenate([y, [-1000]]), cut=5),
        mean=2.0033661585048681,
        mean_stddev=0.032157488597211903,
        size=1001,
        num_kept=1000)
    self.assertFalse(mask[1000])
# Run the test cases defined above when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
"""This file contains 1,000 points of a random normal distribution.
The mean of the distribution is 2, and the standard deviation is 1.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
RANDOM_NORMAL = [
0.692741320869,
1.948556207658,
-0.140158117639,
2.680322906859,
1.876671492867,
2.885286232509,
1.482222151802,
2.234349266246,
2.437427989583,
1.573053624952,
2.169198249367,
2.791023059619,
-1.053798286951,
1.796497126664,
3.806070621390,
1.744055208958,
3.474399140181,
1.564447665560,
1.143107137921,
1.618376255615,
2.615632139609,
1.413239777404,
1.047237320108,
3.190489536636,
2.918428435434,
1.268789896280,
0.931181066003,
3.797790627792,
0.493025834330,
1.866146169585,
0.949927834893,
1.439666857958,
2.705521702500,
1.815406073907,
1.570503841718,
1.834429337005,
2.903916580263,
-0.110549195467,
2.065338922749,
1.119498053048,
0.427627428035,
3.025052175045,
2.645448868784,
1.442644951218,
0.774681298962,
2.247561418494,
1.743438974941,
1.184440017832,
1.643691193885,
1.947748675186,
2.178309991836,
2.815355272672,
2.207620544168,
2.077889048169,
2.915504366132,
2.440862146850,
2.804729838623,
0.534712595625,
1.956491042766,
2.230542009671,
2.186536281651,
3.694129968231,
3.313526598170,
2.170240599444,
2.793531289796,
1.454464312809,
1.197463589804,
0.713332299712,
1.180965411999,
2.180022174106,
2.861107927091,
1.795223865106,
1.730056153040,
1.431404424890,
1.839372935334,
1.271871740741,
3.773103777671,
1.026069424885,
2.006070770486,
1.276836142291,
1.414098998873,
1.749117068374,
2.040006827147,
1.815581326626,
2.892666735522,
3.093934769003,
2.129166907135,
1.260521633663,
3.259431640120,
1.879415647487,
1.368769201985,
2.236653714367,
2.293120875655,
2.361086097355,
2.140675892497,
2.860288793716,
3.109921655205,
2.142509743586,
0.661829413359,
0.620852115030,
2.279817287885,
2.077609300700,
1.917031492891,
2.549328729021,
1.402961147881,
2.989802645752,
2.126646549508,
0.581285045065,
3.226987223858,
1.790860716921,
0.998661497130,
2.125771640271,
2.186096892741,
2.160189267804,
2.206460323846,
3.366179111195,
-0.125206283025,
0.645228886619,
0.505553980622,
4.494406059555,
1.291690417806,
2.977896904657,
2.869240282824,
3.344192278881,
2.487041683297,
4.236730343795,
3.007206122800,
1.210065291965,
-0.053847768077,
1.108953782402,
1.843857008095,
2.374767801329,
1.472199059501,
3.332198116275,
2.027084082885,
2.305065331530,
3.387400013580,
1.493365795517,
2.344295515065,
2.898632740793,
3.307836869328,
1.892766317783,
2.348033912288,
1.288522200888,
2.178559140529,
2.366037265891,
3.468023805733,
1.910134543982,
1.750500687923,
1.506717073807,
1.345976221745,
1.898226480175,
2.362688287820,
2.176558673313,
1.716475335783,
1.109563102324,
1.824697060483,
2.290331853365,
3.660496355225,
3.695990930547,
0.995131810353,
2.083740307542,
2.515409175245,
1.734919119633,
0.186488629263,
3.470910728743,
3.503515673097,
2.225335667636,
4.925211524431,
3.176405299532,
2.938260408825,
2.336603901159,
2.218333712640,
3.269148549824,
1.921171637456,
3.876114839719,
1.492216718705,
2.792835112200,
3.563198188748,
2.728530961520,
3.231549893645,
2.209018339760,
1.081828242171,
0.754161622090,
1.948018149260,
2.413945024183,
1.425023717183,
2.005406706788,
0.964987890314,
1.603414847296,
0.132077263346,
1.789327371404,
1.423488299029,
2.590160851192,
3.131340836085,
2.325779171436,
2.129789552692,
1.876126153813,
2.667783873354,
-0.220464828097,
2.285158851436,
1.188664672684,
1.968980968179,
2.510328726654,
1.690300427857,
2.041495293673,
2.471293710293,
1.660589811070,
1.801640276851,
2.200864460731,
1.489583958038,
1.545725376492,
4.208130184998,
2.428489533380,
3.539990060815,
1.317090333595,
0.785936916712,
0.809688718378,
1.265062896735,
2.749291333938,
6.075297866258,
2.165845459075,
2.055273600728,
2.584618009430,
2.782654850307,
0.967100649409,
2.267394795463,
2.783350629984,
0.238340558296,
1.566536380829,
1.165403279885,
3.409015124349,
1.047853632456,
2.100798231132,
1.824776518459,
1.517825551662,
2.148972385365,
1.818426298006,
1.954355115973,
2.428393037760,
2.225660788849,
1.287880002052,
3.083900598687,
2.561457835470,
2.547146477110,
-0.060868513691,
1.917876348341,
1.194823858275,
1.237685798924,
2.500081029116,
0.605823016300,
1.341027488293,
1.357719149407,
3.959221361786,
1.457342301661,
1.450552596247,
3.152966485077,
1.755910034199,
2.252303064393,
2.315145292843,
2.092889154866,
2.044536701039,
3.078226379252,
1.940374989780,
0.981160719305,
1.801484599888,
4.599412580952,
3.029815652986,
2.234894233100,
1.884862677960,
2.703542617621,
2.188894869734,
1.031225637544,
4.487470294014,
1.916903861878,
2.178877764206,
2.001204233385,
1.668533128794,
0.118714387565,
1.236342841750,
0.697779517270,
4.061304247309,
1.873047854221,
0.529730720609,
0.772303413290,
1.734928501976,
0.830164961083,
3.674107591296,
3.027005867653,
2.798171180697,
2.754769626808,
2.287213251879,
0.224122591017,
1.996907607820,
2.272196861888,
1.423156951562,
2.649423732022,
2.410425004883,
2.348764499112,
4.188086272873,
2.592584804958,
1.360716155533,
1.089292416194,
0.877166635938,
2.923298927077,
1.699602289582,
1.764010718116,
0.851384613856,
1.362786130903,
4.014401248962,
2.004378924317,
2.680507997712,
4.162602009325,
2.080304752717,
0.758782969232,
0.896584126809,
1.907281638800,
2.753415491620,
1.571468221472,
1.510571435517,
3.133254430892,
1.314198176176,
2.871092309494,
0.505771497509,
0.608771053519,
0.099600620869,
2.202314023992,
1.561845986404,
1.935860544395,
4.227485606155,
2.507702606518,
1.966897273255,
3.462827375982,
2.297865682096,
2.018310409281,
2.231512822040,
2.912164920958,
0.391926284930,
3.233896921158,
2.270671144478,
2.151928087898,
1.169376547635,
1.410447269758,
1.104075308499,
-1.542633116467,
1.153006815104,
1.825678952144,
3.170518866440,
4.259372395300,
2.991591841969,
2.936827860147,
1.621450416535,
2.022035327270,
3.512668911326,
2.840069655471,
0.445725474197,
0.462229554454,
0.318918997270,
2.764048560322,
1.707769041832,
0.354635293838,
1.422103811424,
1.567812002847,
1.024884046523,
3.417171077354,
1.638428319488,
3.241722761084,
1.903144274531,
2.386560261127,
1.089737760638,
0.709565288091,
-0.055267123709,
2.220171017505,
2.676992914119,
1.795938808643,
0.857048483230,
2.341277450146,
0.597747299826,
2.172474110279,
1.658595631706,
1.984212673322,
3.348561751121,
1.548578130896,
1.072758349253,
0.032774558165,
1.706534108602,
0.755870027998,
3.896324551791,
1.893154948782,
1.610014175651,
1.689869260730,
0.837788921577,
2.072889043390,
1.849998133863,
2.312731199777,
3.080340939707,
3.091164029551,
0.393260795576,
2.003732199463,
1.850471999505,
1.674223365447,
0.950827266616,
0.893144704712,
4.088054520567,
2.494717669405,
2.940185915913,
2.362745036344,
4.420918853822,
1.196829910488,
0.751131585724,
2.572876053732,
4.783864101630,
1.371390533975,
2.265749507496,
0.980731387353,
1.194594017621,
1.167489912193,
1.964259577764,
2.911981147100,
3.425120291588,
-0.257485591786,
2.472881717211,
3.053440640390,
0.762578570358,
1.132958189893,
2.182874371350,
3.052476057575,
1.277863138274,
1.639136886663,
3.068422388091,
4.082802262329,
3.817537635954,
0.097850368917,
1.833230262781,
1.868086753582,
1.887983463294,
1.651402760749,
1.139536636683,
1.506983113398,
2.136499510660,
1.554089544528,
1.817657472715,
1.881949483974,
1.694259012321,
2.466961181010,
0.934064795958,
1.780986169906,
2.370334192643,
1.384364501906,
2.661332270733,
0.505133486534,
3.377981661644,
4.528998885749,
1.374759695170,
4.722230929249,
2.457241846607,
1.089449047113,
2.069442203989,
2.922047383746,
2.418239643214,
2.102706555829,
3.402947114317,
0.796159207619,
2.632564349894,
4.200094165974,
1.193215106012,
3.096716644943,
2.876829603210,
1.921543697812,
1.061475001173,
1.539143201636,
1.962758648643,
2.295863280945,
1.165666782756,
2.795634751169,
1.964614100540,
2.881578005533,
2.637037175067,
2.892982065300,
1.370612909045,
2.259776562066,
2.613792094772,
1.906706647250,
1.557148053231,
1.133780390845,
1.143533122599,
4.117191444375,
0.188018096004,
0.214460776257,
1.603522547618,
1.983864405185,
2.699735877141,
3.298710632472,
1.487986613587,
1.991452281956,
4.766298265537,
2.586190101112,
1.065148174656,
2.145271421777,
1.522765549323,
4.396072859233,
1.606095900438,
1.031438092798,
2.960703649068,
3.605096318253,
0.738490507896,
3.432893690638,
1.851816325195,
1.776083706595,
2.626748456025,
3.271038565612,
-0.070612895830,
1.287744477416,
1.695753443501,
2.436753770247,
2.202642339826,
1.408938299450,
2.016160912374,
0.898288687497,
1.289813203546,
2.817550580649,
1.633807646433,
0.729748534247,
3.731152016795,
4.390636356521,
0.960267728055,
2.664438938132,
3.353944579683,
3.269381121224,
1.172165464812,
2.400938318282,
0.807728448589,
1.354368440944,
0.710514838580,
3.856287086165,
1.844610913086,
0.998911164773,
2.675023215997,
2.434288374696,
3.159869294202,
2.216260317913,
1.045656776305,
2.009335242758,
-0.709434674480,
1.331363427990,
1.413333913927,
0.929006400187,
1.733184360011,
1.592926769452,
2.244402061817,
-0.685165133598,
1.588319181570,
2.023137693888,
2.371988587924,
3.467422121044,
1.853048214949,
2.619054775096,
2.800663994934,
3.602262658213,
2.739826236907,
1.657450448996,
2.635627781551,
3.426068129480,
1.151293859797,
2.126497682712,
1.381359388969,
2.247544386446,
1.174628386484,
0.958611426223,
3.704308164799,
1.239169433782,
1.113358149317,
2.955595260742,
4.724760227403,
1.980693866973,
1.178308425749,
1.764128983382,
3.075383932276,
1.825832517572,
0.832366525583,
1.141057225660,
1.525888435245,
2.385324115908,
1.358714765290,
0.551137984167,
3.731145479114,
2.129026962460,
2.573317010477,
1.976940325893,
2.420475072025,
0.684154421404,
2.802696725755,
2.541095686615,
2.058591811473,
1.640112285999,
1.856989192038,
2.614611193561,
3.469421336856,
1.053557146407,
3.032200283499,
2.573750297024,
3.083216685185,
2.404219296708,
0.346271398570,
2.589361666010,
2.774804416246,
0.445877540011,
0.905444077995,
0.063823875188,
2.931316420485,
1.682860197161,
2.972795382257,
2.597485175923,
1.554827252582,
0.938640710601,
1.554015447012,
0.698644188586,
2.957760202695,
2.706304471141,
2.642415006150,
1.464184232137,
2.765792229162,
2.039393447616,
1.582779254230,
1.722697961910,
0.354842490538,
0.839688308674,
3.250316830782,
3.993268587677,
1.831751003414,
3.737987669486,
3.837008408003,
3.656452995704,
1.378085850241,
3.992366605685,
3.063520565655,
1.829600671075,
2.853149829083,
1.948008763331,
2.489355654745,
2.039149456991,
1.723308108929,
1.530719515047,
1.390322318375,
1.015161747970,
1.902647551975,
0.587714373573,
2.419343238401,
2.037241109090,
1.989108845487,
2.555164211364,
2.145078634562,
2.453495232937,
1.572091583978,
3.017196269239,
1.359683738353,
1.905697793148,
1.745346338494,
2.410960789923,
1.688108090624,
2.041661869959,
2.261146892703,
0.108311227666,
2.261198590438,
1.205414457068,
0.815680644627,
2.373638547036,
0.314446220857,
2.407160216258,
1.767921824455,
1.812649838016,
1.981483340407,
2.294353826751,
1.219794724258,
4.384759314526,
3.362912919095,
0.358020839800,
2.416111296383,
0.772765268291,
3.036908153028,
3.499839475422,
2.504401672085,
1.000612753791,
1.031364523108,
2.905950640378,
3.816584440139,
0.846980443659,
2.806102343007,
1.662302388297,
2.146698147213,
2.247505312463,
1.485016638026,
3.139004503074,
0.525587710167,
1.271023854890,
1.255521130946,
1.814043296479,
1.216959307975,
0.978300004743,
1.793024541935,
2.436108214253,
1.805501508380,
2.289362542667,
3.103303146056,
3.070219780476,
1.928865588661,
0.671011951957,
1.892825933013,
2.777602823529,
0.491871575583,
2.240415846966,
2.375489703208,
2.709091612473,
1.454643174490,
0.932692202068,
1.330312119137,
2.127413235976,
1.317902934165,
1.580714395448,
1.008090918724,
0.713394722097,
0.414109615934,
1.497415539366,
2.768670845431,
2.432044584164,
3.549640318635,
1.342147007285,
1.490711835094,
2.215255822048,
2.953966963699,
1.397115922550,
1.378544056641,
0.295634610499,
0.310858641177,
1.767513113561,
3.434648852323,
1.491911596249,
4.374871362485,
1.373675010945,
0.738310910553,
1.234191541434,
2.155481175306,
2.958616497624,
1.540019317971,
1.890919744323,
0.015363864461,
0.611976171745,
2.048461203755,
0.905204536881,
1.952638485996,
2.425065685214,
2.237320237401,
1.261771567053,
2.589404719675,
2.475731886267,
1.327151422229,
1.535419742175,
1.208022763652,
3.436317329939,
1.228705365902,
1.566902016116,
2.085976618587,
3.604339608476,
1.070321479131,
2.085842592869,
2.524588738973,
0.844371573275,
0.666896658382,
3.021051396651,
1.304696763442,
2.885533651158,
1.076496681998,
2.291817051246,
1.297874264925,
2.181105432748,
2.017938562177,
1.688714920892,
2.778982555151,
2.464336632348,
2.351157814765,
0.328615108417,
2.045729524665,
1.097213832650,
1.821737535398,
1.493311782343,
2.664732414197,
2.409262056717,
3.113216373314,
-0.736586322990,
2.622696405459,
2.182304470993,
0.270476186892,
2.628658716523,
1.938159535287,
3.336990992309,
1.843049113291,
4.411231162672,
2.385234723754,
2.814668480292,
0.835204823261,
1.486415794299,
4.428506992622,
2.010951035701,
2.386497902232,
-0.054343008749,
1.897798394390,
2.128378005079,
2.003863798772,
1.414857649717,
3.058382706867,
0.734980076150,
0.128402966890,
0.075621261496,
2.062530850812,
2.257054591626,
3.098405063129,
1.184503294303,
1.927098840462,
3.590105219538,
2.324770104189,
2.920547827923,
3.774469427430,
0.643975980439,
2.972011913013,
1.545636773552,
1.284276446577,
2.116456504846,
2.334765924705,
1.476322485264,
0.333938454195,
1.740780860437,
0.809641636242,
2.114359904589,
3.495010537745,
1.394057058959,
2.099880999687,
1.723136191694,
1.824145520142,
2.206175560435,
2.217935160800,
3.184151380365,
-0.165277107839,
2.066902569755,
4.109207223806,
2.639922346758,
2.869289441530,
2.992666432223,
2.628328010580,
1.318946413946,
3.437097382310,
2.043254488710,
0.244000873823,
1.857441713051,
1.302602111278,
2.850286225242,
1.988609208476,
0.406765856788,
1.691073499692,
0.918912799942,
1.943198487145,
3.174415802822,
1.916755816708,
2.196550794119,
0.930720476044,
2.032189015326,
-0.777034945338,
1.406753268550,
0.870345844705,
1.793195283464,
2.066120080070,
2.916729217526,
0.642313449142,
2.617529572000,
2.396572272668,
1.942111542427,
2.435603256612,
3.898219347884,
1.979409214342,
1.235681010137,
-0.802441600645,
1.927883866070,
0.852232772749,
2.626513188209,
1.994232584644,
2.677125120554,
2.945149227801,
0.344859114264,
2.988484765052,
2.221699681734,
1.157038942208,
2.703070759809,
1.410436365113,
3.056534135285,
0.975232183559,
1.032651705560,
1.787301763233,
1.587502529729,
1.425207628405,
2.443158189935,
3.786205343468,
0.240451061053,
2.993759767949,
2.527525916677,
2.990777291349,
1.458774147434,
4.293524428909,
-0.116618748162,
1.674243883127,
2.434026351267,
3.129729749455,
1.532120640786,
3.584008627649,
2.126682783899,
0.784920593215,
1.954841166456,
2.659877218373,
2.639038968190,
3.009597452617,
3.820422929562,
2.950718556164,
2.942026969809,
2.899140330708,
-0.003511099104,
0.780849789152,
2.375904463772,
1.034820493941,
2.010379907777,
2.273452795908,
2.508893511243,
0.495773521197,
3.010585297044,
3.029210010516,
3.973880821070,
2.416599047057,
1.320773195864,
-0.296283555404,
3.112367101202,
0.568454165534,
3.950197901953,
3.040255296379,
2.892209686169,
1.355195417805,
2.139684432822,
2.920582903729,
1.588899963320,
2.235959314499,
1.768769790964,
2.854126298598,
-0.132279995647,
1.984901818097,
0.667875459687,
1.338320000460,
1.394322304858,
0.299399873843,
1.062140558649,
1.283070416608,
2.043726915244,
1.426628725160,
1.763445352183,
2.517159283156,
1.334042379945,
1.120896888394,
1.890222921582,
0.565772476594,
1.106579774451,
1.419654511698,
2.809182593659,
2.500132279723,
2.818415931740,
2.302096389328,
2.700248827229,
2.649016952991,
3.051337084118,
1.040839832658,
1.068258609432,
3.917982023425,
0.893981534117,
1.258354231702,
0.154108914546,
1.578873281706,
1.438841948587,
0.854591114516,
3.199994919222,
0.946279793861,
1.911631903701,
3.821301368668,
1.417597738430,
2.979514439880,
1.202734703576,
1.724395921518,
3.069121580556,
1.024707141488,
3.309560047408,
2.147433150108,
2.173493008341,
3.875831804386,
2.160166379458,
2.017326408423,
3.941632320127,
2.583832116740,
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment