Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d1b91ba9
Unverified
Commit
d1b91ba9
authored
Jun 18, 2018
by
Chris Shallue
Committed by
GitHub
Jun 18, 2018
Browse files
Merge pull request #4554 from cshallue/master
Merge internal changes
parents
1abcee90
8ae37506
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
272 additions
and
106 deletions
+272
-106
research/astronet/light_curve_util/util.py
research/astronet/light_curve_util/util.py
+12
-20
research/astronet/light_curve_util/util_test.py
research/astronet/light_curve_util/util_test.py
+46
-23
research/astronet/third_party/kepler_spline/kepler_spline.py
research/astronet/third_party/kepler_spline/kepler_spline.py
+112
-36
research/astronet/third_party/kepler_spline/kepler_spline_test.py
.../astronet/third_party/kepler_spline/kepler_spline_test.py
+102
-27
No files found.
research/astronet/light_curve_util/util.py
View file @
d1b91ba9
...
@@ -18,8 +18,6 @@ from __future__ import absolute_import
...
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
collections
import
numpy
as
np
import
numpy
as
np
from
six.moves
import
range
# pylint:disable=redefined-builtin
from
six.moves
import
range
# pylint:disable=redefined-builtin
...
@@ -51,10 +49,10 @@ def split(all_time, all_flux, gap_width=0.75):
...
@@ -51,10 +49,10 @@ def split(all_time, all_flux, gap_width=0.75):
piecewise defined (e.g. split by quarter breaks or gaps in the data).
piecewise defined (e.g. split by quarter breaks or gaps in the data).
Args:
Args:
all_time: Numpy array or
list
of numpy arrays; each is a sequence of
time
all_time: Numpy array or
sequence
of numpy arrays; each is a sequence of
values.
time
values.
all_flux: Numpy array or
list
of numpy arrays; each is a sequence of
flux
all_flux: Numpy array or
sequence
of numpy arrays; each is a sequence of
values of the corresponding time array.
flux
values of the corresponding time array.
gap_width: Minimum gap size (in time units) for a split.
gap_width: Minimum gap size (in time units) for a split.
Returns:
Returns:
...
@@ -62,10 +60,7 @@ def split(all_time, all_flux, gap_width=0.75):
...
@@ -62,10 +60,7 @@ def split(all_time, all_flux, gap_width=0.75):
out_flux: List of numpy arrays; the split flux arrays.
out_flux: List of numpy arrays; the split flux arrays.
"""
"""
# Handle single-segment inputs.
# Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion
if
isinstance
(
all_time
,
np
.
ndarray
)
and
all_time
.
ndim
==
1
:
# to bool fails if all_time is a numpy array, and all_time.size is not defined
# if all_time is a list of numpy arrays.
if
len
(
all_time
)
>
0
and
not
isinstance
(
all_time
[
0
],
collections
.
Iterable
):
# pylint:disable=g-explicit-length-test
all_time
=
[
all_time
]
all_time
=
[
all_time
]
all_flux
=
[
all_flux
]
all_flux
=
[
all_flux
]
...
@@ -90,10 +85,10 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
...
@@ -90,10 +85,10 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
curve (e.g. one that is split by quarter breaks or gaps in the data).
curve (e.g. one that is split by quarter breaks or gaps in the data).
Args:
Args:
all_time: Numpy array or
list
of numpy arrays; each is a sequence of
time
all_time: Numpy array or
sequence
of numpy arrays; each is a sequence of
values.
time
values.
all_flux: Numpy array or
list
of numpy arrays; each is a sequence of
flux
all_flux: Numpy array or
sequence
of numpy arrays; each is a sequence of
values of the corresponding time array.
flux
values of the corresponding time array.
events: List of Event objects to remove.
events: List of Event objects to remove.
width_factor: Fractional multiplier of the duration of each event to remove.
width_factor: Fractional multiplier of the duration of each event to remove.
...
@@ -104,10 +99,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
...
@@ -104,10 +99,7 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
events removed.
events removed.
"""
"""
# Handle single-segment inputs.
# Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion
if
isinstance
(
all_time
,
np
.
ndarray
)
and
all_time
.
ndim
==
1
:
# to bool fails if all_time is a numpy array and all_time.size is not defined
# if all_time is a list of numpy arrays.
if
len
(
all_time
)
>
0
and
not
isinstance
(
all_time
[
0
],
collections
.
Iterable
):
# pylint:disable=g-explicit-length-test
all_time
=
[
all_time
]
all_time
=
[
all_time
]
all_flux
=
[
all_flux
]
all_flux
=
[
all_flux
]
single_segment
=
True
single_segment
=
True
...
@@ -150,10 +142,10 @@ def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
...
@@ -150,10 +142,10 @@ def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
interp_spline
=
[]
interp_spline
=
[]
for
time
,
masked_time
,
masked_spline
in
zip
(
for
time
,
masked_time
,
masked_spline
in
zip
(
all_time
,
all_masked_time
,
all_masked_spline
):
all_time
,
all_masked_time
,
all_masked_spline
):
if
len
(
masked_time
)
>
0
:
# pylint:disable=g-explicit-length-test
if
masked_time
.
size
:
interp_spline
.
append
(
np
.
interp
(
time
,
masked_time
,
masked_spline
))
interp_spline
.
append
(
np
.
interp
(
time
,
masked_time
,
masked_spline
))
else
:
else
:
interp_spline
.
append
(
np
.
full_like
(
time
,
np
.
nan
))
interp_spline
.
append
(
np
.
array
([
np
.
nan
]
*
len
(
time
)
))
return
interp_spline
return
interp_spline
...
...
research/astronet/light_curve_util/util_test.py
View file @
d1b91ba9
...
@@ -65,43 +65,63 @@ class LightCurveUtilTest(absltest.TestCase):
...
@@ -65,43 +65,63 @@ class LightCurveUtilTest(absltest.TestCase):
self
.
assertSequenceAlmostEqual
(
expected
,
tfold
)
self
.
assertSequenceAlmostEqual
(
expected
,
tfold
)
def
testSplit
(
self
):
def
testSplit
(
self
):
# Single segment.
all_time
=
np
.
concatenate
([
np
.
arange
(
0
,
1
,
0.1
),
np
.
arange
(
1.5
,
2
,
0.1
)])
all_flux
=
np
.
ones
(
15
)
# Gap width 0.5.
split_time
,
split_flux
=
util
.
split
(
all_time
,
all_flux
,
gap_width
=
0.5
)
self
.
assertLen
(
split_time
,
2
)
self
.
assertLen
(
split_flux
,
2
)
self
.
assertSequenceAlmostEqual
(
np
.
arange
(
0
,
1
,
0.1
),
split_time
[
0
])
self
.
assertSequenceAlmostEqual
(
np
.
ones
(
10
),
split_flux
[
0
])
self
.
assertSequenceAlmostEqual
(
np
.
arange
(
1.5
,
2
,
0.1
),
split_time
[
1
])
self
.
assertSequenceAlmostEqual
(
np
.
ones
(
5
),
split_flux
[
1
])
# Multi segment.
all_time
=
[
all_time
=
[
np
.
concatenate
([
np
.
concatenate
([
np
.
arange
(
0
,
1
,
0.1
),
np
.
arange
(
0
,
1
,
0.1
),
np
.
arange
(
1.5
,
2
,
0.1
),
np
.
arange
(
1.5
,
2
,
0.1
),
np
.
arange
(
3
,
4
,
0.1
)
np
.
arange
(
3
,
4
,
0.1
)
])
]),
np
.
arange
(
4
,
5
,
0.1
)
]
]
all_flux
=
[
np
.
array
([
1
]
*
25
)]
all_flux
=
[
np
.
ones
(
25
),
np
.
ones
(
10
)]
self
.
assertEqual
(
len
(
all_time
),
2
)
self
.
assertEqual
(
len
(
all_time
[
0
]),
25
)
self
.
assertEqual
(
len
(
all_time
[
1
]),
10
)
self
.
assertEqual
(
len
(
all_flux
),
2
)
self
.
assertEqual
(
len
(
all_flux
[
0
]),
25
)
self
.
assertEqual
(
len
(
all_flux
[
1
]),
10
)
# Gap width 0.5.
# Gap width 0.5.
split_time
,
split_flux
=
util
.
split
(
all_time
,
all_flux
,
gap_width
=
0.5
)
split_time
,
split_flux
=
util
.
split
(
all_time
,
all_flux
,
gap_width
=
0.5
)
self
.
assertLen
(
split_time
,
3
)
self
.
assertLen
(
split_time
,
4
)
self
.
assertLen
(
split_flux
,
3
)
self
.
assertLen
(
split_flux
,
4
)
self
.
assertSequenceAlmostEqual
(
self
.
assertSequenceAlmostEqual
(
np
.
arange
(
0
,
1
,
0.1
),
split_time
[
0
])
[
0.
,
0.1
,
0.2
,
0.3
,
0.4
,
0.5
,
0.6
,
0.7
,
0.8
,
0.9
],
split_time
[
0
])
self
.
assertSequenceAlmostEqual
(
np
.
ones
(
10
),
split_flux
[
0
])
self
.
assertSequenceAlmostEqual
([
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
],
self
.
assertSequenceAlmostEqual
(
np
.
arange
(
1.5
,
2
,
0.1
),
split_time
[
1
])
split_flux
[
0
])
self
.
assertSequenceAlmostEqual
(
np
.
ones
(
5
),
split_flux
[
1
])
self
.
assertSequenceAlmostEqual
([
1.5
,
1.6
,
1.7
,
1.8
,
1.9
],
split_time
[
1
])
self
.
assertSequenceAlmostEqual
(
np
.
arange
(
3
,
4
,
0.1
),
split_time
[
2
])
self
.
assertSequenceAlmostEqual
([
1
,
1
,
1
,
1
,
1
],
split_flux
[
1
])
self
.
assertSequenceAlmostEqual
(
np
.
ones
(
10
),
split_flux
[
2
])
self
.
assertSequenceAlmostEqual
(
self
.
assertSequenceAlmostEqual
(
np
.
arange
(
4
,
5
,
0.1
),
split_time
[
3
])
[
3.
,
3.1
,
3.2
,
3.3
,
3.4
,
3.5
,
3.6
,
3.7
,
3.8
,
3.9
],
split_time
[
2
])
self
.
assertSequenceAlmostEqual
(
np
.
ones
(
10
),
split_flux
[
3
])
self
.
assertSequenceAlmostEqual
([
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
],
split_flux
[
2
])
# Gap width 1.0.
# Gap width 1.0.
split_time
,
split_flux
=
util
.
split
(
all_time
,
all_flux
,
gap_width
=
1
)
split_time
,
split_flux
=
util
.
split
(
all_time
,
all_flux
,
gap_width
=
1
)
self
.
assertLen
(
split_time
,
2
)
self
.
assertLen
(
split_time
,
3
)
self
.
assertLen
(
split_flux
,
2
)
self
.
assertLen
(
split_flux
,
3
)
self
.
assertSequenceAlmostEqual
([
self
.
assertSequenceAlmostEqual
([
0.
,
0.1
,
0.2
,
0.3
,
0.4
,
0.5
,
0.6
,
0.7
,
0.8
,
0.9
,
1.5
,
1.6
,
1.7
,
1.8
,
1.9
0.
,
0.1
,
0.2
,
0.3
,
0.4
,
0.5
,
0.6
,
0.7
,
0.8
,
0.9
,
1.5
,
1.6
,
1.7
,
1.8
,
1.9
],
split_time
[
0
])
],
split_time
[
0
])
self
.
assertSequenceAlmostEqual
(
self
.
assertSequenceAlmostEqual
(
np
.
ones
(
15
),
split_flux
[
0
])
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
],
split_flux
[
0
])
self
.
assertSequenceAlmostEqual
(
np
.
arange
(
3
,
4
,
0.1
),
split_time
[
1
])
self
.
assertSequenceAlmostEqual
(
self
.
assertSequenceAlmostEqual
(
np
.
ones
(
10
),
split_flux
[
1
])
[
3.
,
3.1
,
3.2
,
3.3
,
3.4
,
3.5
,
3.6
,
3.7
,
3.8
,
3.9
],
split_time
[
1
])
self
.
assertSequenceAlmostEqual
(
np
.
arange
(
4
,
5
,
0.1
),
split_time
[
2
])
self
.
assertSequenceAlmostEqual
([
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
],
self
.
assertSequenceAlmostEqual
(
np
.
ones
(
10
),
split_flux
[
2
])
split_flux
[
1
])
def
testRemoveEvents
(
self
):
def
testRemoveEvents
(
self
):
time
=
np
.
arange
(
20
,
dtype
=
np
.
float
)
time
=
np
.
arange
(
20
,
dtype
=
np
.
float
)
...
@@ -136,20 +156,23 @@ class LightCurveUtilTest(absltest.TestCase):
...
@@ -136,20 +156,23 @@ class LightCurveUtilTest(absltest.TestCase):
all_time
=
[
all_time
=
[
np
.
arange
(
0
,
10
,
dtype
=
np
.
float
),
np
.
arange
(
0
,
10
,
dtype
=
np
.
float
),
np
.
arange
(
10
,
20
,
dtype
=
np
.
float
),
np
.
arange
(
10
,
20
,
dtype
=
np
.
float
),
np
.
arange
(
20
,
30
,
dtype
=
np
.
float
),
]
]
all_masked_time
=
[
all_masked_time
=
[
np
.
array
([
0
,
1
,
2
,
3
,
8
,
9
],
dtype
=
np
.
float
),
# No 4, 5, 6, 7
np
.
array
([
0
,
1
,
2
,
3
,
8
,
9
],
dtype
=
np
.
float
),
# No 4, 5, 6, 7
np
.
array
([
10
,
11
,
12
,
13
,
14
,
15
,
16
],
dtype
=
np
.
float
),
# No 17, 18, 19
np
.
array
([
10
,
11
,
12
,
13
,
14
,
15
,
16
],
dtype
=
np
.
float
),
# No 17, 18, 19
np
.
array
([],
dtype
=
np
.
float
)
]
]
all_masked_spline
=
[
2
*
t
+
100
for
t
in
all_masked_time
]
all_masked_spline
=
[
2
*
t
+
100
for
t
in
all_masked_time
]
interp_spline
=
util
.
interpolate_masked_spline
(
all_time
,
all_masked_time
,
interp_spline
=
util
.
interpolate_masked_spline
(
all_time
,
all_masked_time
,
all_masked_spline
)
all_masked_spline
)
self
.
assertLen
(
interp_spline
,
2
)
self
.
assertLen
(
interp_spline
,
3
)
self
.
assertSequenceAlmostEqual
(
self
.
assertSequenceAlmostEqual
(
[
100
,
102
,
104
,
106
,
108
,
110
,
112
,
114
,
116
,
118
],
interp_spline
[
0
])
[
100
,
102
,
104
,
106
,
108
,
110
,
112
,
114
,
116
,
118
],
interp_spline
[
0
])
self
.
assertSequenceAlmostEqual
(
self
.
assertSequenceAlmostEqual
(
[
120
,
122
,
124
,
126
,
128
,
130
,
132
,
132
,
132
,
132
],
interp_spline
[
1
])
[
120
,
122
,
124
,
126
,
128
,
130
,
132
,
132
,
132
,
132
],
interp_spline
[
1
])
self
.
assertTrue
(
np
.
all
(
np
.
isnan
(
interp_spline
[
2
])))
def
testCountTransitPoints
(
self
):
def
testCountTransitPoints
(
self
):
time
=
np
.
concatenate
([
time
=
np
.
concatenate
([
...
...
research/astronet/third_party/kepler_spline/kepler_spline.py
View file @
d1b91ba9
...
@@ -12,8 +12,13 @@ from pydl.pydlutils import bspline
...
@@ -12,8 +12,13 @@ from pydl.pydlutils import bspline
from
third_party.robust_mean
import
robust_mean
from
third_party.robust_mean
import
robust_mean
class
InsufficientPointsError
(
Exception
):
"""Indicates that insufficient points were available for spline fitting."""
pass
class
SplineError
(
Exception
):
class
SplineError
(
Exception
):
"""
Error when fitting a Kepler spline
."""
"""
Indicates an error in the underlying spline-fitting implementation
."""
pass
pass
...
@@ -40,7 +45,17 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
...
@@ -40,7 +45,17 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
spline: The values of the fitted spline corresponding to the input time
spline: The values of the fitted spline corresponding to the input time
values.
values.
mask: Boolean mask indicating the points used to fit the final spline.
mask: Boolean mask indicating the points used to fit the final spline.
Raises:
InsufficientPointsError: If there were insufficient points (after removing
outliers) for spline fitting.
SplineError: If the spline could not be fit, for example if the breakpoint
spacing is too small.
"""
"""
if
len
(
time
)
<
4
:
raise
InsufficientPointsError
(
"Cannot fit a spline on less than 4 points. Got %d points."
%
len
(
time
))
# Rescale time into [0, 1].
# Rescale time into [0, 1].
t_min
=
np
.
min
(
time
)
t_min
=
np
.
min
(
time
)
t_max
=
np
.
max
(
time
)
t_max
=
np
.
max
(
time
)
...
@@ -58,16 +73,26 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
...
@@ -58,16 +73,26 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
mask
=
np
.
ones_like
(
time
,
dtype
=
np
.
bool
)
# Try to fit all points.
mask
=
np
.
ones_like
(
time
,
dtype
=
np
.
bool
)
# Try to fit all points.
else
:
else
:
# Choose points where the absolute deviation from the median residual is
# Choose points where the absolute deviation from the median residual is
# less than
3
*sigma, where sigma is a robust estimate of the
standard
# less than
outlier_cut
*sigma, where sigma is a robust estimate of the
# deviation of the residuals from the previous spline.
#
standard
deviation of the residuals from the previous spline.
residuals
=
flux
-
spline
residuals
=
flux
-
spline
_
,
_
,
new_mask
=
robust_mean
.
robust_mean
(
residuals
,
cut
=
outlier_cut
)
new_mask
=
robust_mean
.
robust_mean
(
residuals
,
cut
=
outlier_cut
)
[
2
]
if
np
.
all
(
new_mask
==
mask
):
if
np
.
all
(
new_mask
==
mask
):
break
# Spline converged.
break
# Spline converged.
mask
=
new_mask
mask
=
new_mask
if
np
.
sum
(
mask
)
<
4
:
# Fewer than 4 points after removing outliers. We could plausibly return
# the spline from the previous iteration because it was fit with at least
# 4 points. However, since the outliers were such a significant fraction
# of the curve, the spline from the previous iteration is probably junk,
# and we consider this a fatal error.
raise
InsufficientPointsError
(
"Cannot fit a spline on less than 4 points. After removing "
"outliers, got %d points."
%
np
.
sum
(
mask
))
try
:
try
:
with
warnings
.
catch_warnings
():
with
warnings
.
catch_warnings
():
# Suppress warning messages printed by pydlutils.bspline. Instead we
# Suppress warning messages printed by pydlutils.bspline. Instead we
...
@@ -88,6 +113,31 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
...
@@ -88,6 +113,31 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
return
spline
,
mask
return
spline
,
mask
class
SplineMetadata
(
object
):
"""Metadata about a spline fit.
Attributes:
light_curve_mask: List of boolean numpy arrays indicating which points in
the light curve were used to fit the best-fit spline.
bkspace: The break-point spacing used for the best-fit spline.
bad_bkspaces: List of break-point spacing values that failed.
likelihood_term: The likelihood term of the Bayesian Information Criterion;
-2*ln(L), where L is the likelihood of the data given the model.
penalty_term: The penalty term for the number of parameters in the
Bayesian Information Criterion.
bic: The value of the Bayesian Information Criterion; equal to
likelihood_term + penalty_coeff * penalty_term.
"""
def
__init__
(
self
):
self
.
light_curve_mask
=
None
self
.
bkspace
=
None
self
.
bad_bkspaces
=
[]
self
.
likelihood_term
=
None
self
.
penalty_term
=
None
self
.
bic
=
None
def
choose_kepler_spline
(
all_time
,
def
choose_kepler_spline
(
all_time
,
all_flux
,
all_flux
,
bkspaces
,
bkspaces
,
...
@@ -125,52 +175,65 @@ def choose_kepler_spline(all_time,
...
@@ -125,52 +175,65 @@ def choose_kepler_spline(all_time,
Returns:
Returns:
spline: List of numpy arrays; values of the best-fit spline corresponding to
spline: List of numpy arrays; values of the best-fit spline corresponding to
the input flux arrays.
the input flux arrays.
spline_mask: List of boolean numpy arrays indicating which points in the
metadata: Object containing metadata about the spline fit.
flux arrays were used to fit the best-fit spline.
bkspace: The break-point spacing used for the best-fit spline.
bad_bkspaces: List of break-point spacing values that failed.
"""
"""
# Initialize outputs.
best_spline
=
None
metadata
=
SplineMetadata
()
# Compute the assumed standard deviation of Gaussian white noise about the
# Compute the assumed standard deviation of Gaussian white noise about the
# spline model.
# spline model. We assume that each flux value f[i] is a Gaussian random
abs_deviations
=
np
.
concatenate
([
np
.
abs
(
f
[
1
:]
-
f
[:
-
1
])
for
f
in
all_flux
])
# variable f[i] ~ N(s[i], sigma^2), where s is the value of the true spline
sigma
=
np
.
median
(
abs_deviations
)
*
1.48
/
np
.
sqrt
(
2
)
# model and sigma is the constant standard deviation for all flux values.
# Moreover, we assume that s[i] ~= s[i+1]. Therefore,
# (f[i+1] - f[i]) / sqrt(2) ~ N(0, sigma^2).
scaled_diffs
=
np
.
concatenate
([
np
.
diff
(
f
)
/
np
.
sqrt
(
2
)
for
f
in
all_flux
])
if
not
scaled_diffs
.
size
:
best_spline
=
[
np
.
array
([
np
.
nan
]
*
len
(
f
))
for
f
in
all_flux
]
metadata
.
light_curve_mask
=
[
np
.
zeros_like
(
f
,
dtype
=
np
.
bool
)
for
f
in
all_flux
]
return
best_spline
,
metadata
# Compute the median absolute deviation as a robust estimate of sigma. The
# conversion factor of 1.48 takes the median absolute deviation to the
# standard deviation of a normal distribution. See, e.g.
# https://www.mathworks.com/help/stats/mad.html.
sigma
=
np
.
median
(
np
.
abs
(
scaled_diffs
))
*
1.48
best_bic
=
None
best_spline
=
None
best_spline_mask
=
None
best_bkspace
=
None
bad_bkspaces
=
[]
for
bkspace
in
bkspaces
:
for
bkspace
in
bkspaces
:
nparams
=
0
# Total number of free parameters in the piecewise spline.
nparams
=
0
# Total number of free parameters in the piecewise spline.
npoints
=
0
# Total number of data points used to fit the piecewise spline.
npoints
=
0
# Total number of data points used to fit the piecewise spline.
ssr
=
0
# Sum of squared residuals between the model and the spline.
ssr
=
0
# Sum of squared residuals between the model and the spline.
spline
=
[]
spline
=
[]
splin
e_mask
=
[]
light_curv
e_mask
=
[]
bad_bkspace
=
False
# Indicates that the current bkspace should be skipped.
bad_bkspace
=
False
# Indicates that the current bkspace should be skipped.
for
time
,
flux
in
zip
(
all_time
,
all_flux
):
for
time
,
flux
in
zip
(
all_time
,
all_flux
):
# Don't fit a spline on less than 4 points.
if
len
(
time
)
<
4
:
spline
.
append
(
flux
)
spline_mask
.
append
(
np
.
ones_like
(
flux
),
dtype
=
np
.
bool
)
continue
# Fit B-spline to this light-curve segment.
# Fit B-spline to this light-curve segment.
try
:
try
:
spline_piece
,
mask
=
kepler_spline
(
spline_piece
,
mask
=
kepler_spline
(
time
,
flux
,
bkspace
=
bkspace
,
maxiter
=
maxiter
)
time
,
flux
,
bkspace
=
bkspace
,
maxiter
=
maxiter
)
except
InsufficientPointsError
as
e
:
# It's expected to get a SplineError occasionally for small values of
# It's expected to occasionally see intervals with insufficient points,
# bkspace.
# especially if periodic signals have been removed from the light curve.
# Skip this interval, but continue fitting the spline.
if
verbose
:
warnings
.
warn
(
str
(
e
))
spline
.
append
(
np
.
array
([
np
.
nan
]
*
len
(
flux
)))
light_curve_mask
.
append
(
np
.
zeros_like
(
flux
,
dtype
=
np
.
bool
))
continue
except
SplineError
as
e
:
except
SplineError
as
e
:
# It's expected to get a SplineError occasionally for small values of
# bkspace. Skip this bkspace.
if
verbose
:
if
verbose
:
warnings
.
warn
(
"Bad bkspace %.4f: %s"
%
(
bkspace
,
e
))
warnings
.
warn
(
"Bad bkspace %.4f: %s"
%
(
bkspace
,
e
))
bad_bkspaces
.
append
(
bkspace
)
metadata
.
bad_bkspaces
.
append
(
bkspace
)
bad_bkspace
=
True
bad_bkspace
=
True
break
break
spline
.
append
(
spline_piece
)
spline
.
append
(
spline_piece
)
splin
e_mask
.
append
(
mask
)
light_curv
e_mask
.
append
(
mask
)
# Accumulate the number of free parameters.
# Accumulate the number of free parameters.
total_time
=
np
.
max
(
time
)
-
np
.
min
(
time
)
total_time
=
np
.
max
(
time
)
-
np
.
min
(
time
)
...
@@ -181,7 +244,7 @@ def choose_kepler_spline(all_time,
...
@@ -181,7 +244,7 @@ def choose_kepler_spline(all_time,
npoints
+=
np
.
sum
(
mask
)
npoints
+=
np
.
sum
(
mask
)
ssr
+=
np
.
sum
((
flux
[
mask
]
-
spline_piece
[
mask
])
**
2
)
ssr
+=
np
.
sum
((
flux
[
mask
]
-
spline_piece
[
mask
])
**
2
)
if
bad_bkspace
:
if
bad_bkspace
or
not
npoints
:
continue
continue
# The following term is -2*ln(L), where L is the likelihood of the data
# The following term is -2*ln(L), where L is the likelihood of the data
...
@@ -189,13 +252,26 @@ def choose_kepler_spline(all_time,
...
@@ -189,13 +252,26 @@ def choose_kepler_spline(all_time,
# Gaussian with mean 0 and standard deviation sigma.
# Gaussian with mean 0 and standard deviation sigma.
likelihood_term
=
npoints
*
np
.
log
(
2
*
np
.
pi
*
sigma
**
2
)
+
ssr
/
sigma
**
2
likelihood_term
=
npoints
*
np
.
log
(
2
*
np
.
pi
*
sigma
**
2
)
+
ssr
/
sigma
**
2
# Penalty term for the number of parameters used to fit the model.
penalty_term
=
nparams
*
np
.
log
(
npoints
)
# Bayesian information criterion.
# Bayesian information criterion.
bic
=
likelihood_term
+
penalty_coeff
*
nparams
*
np
.
log
(
npoints
)
bic
=
likelihood_term
+
penalty_coeff
*
penalty_term
if
best_bic
is
None
or
bic
<
best_bic
:
if
best_spline
is
None
or
bic
<
metadata
.
bic
:
best_bic
=
bic
best_spline
=
spline
best_spline
=
spline
best_spline_mask
=
spline_mask
metadata
.
light_curve_mask
=
light_curve_mask
best_bkspace
=
bkspace
metadata
.
bkspace
=
bkspace
metadata
.
likelihood_term
=
likelihood_term
return
best_spline
,
best_spline_mask
,
best_bkspace
,
bad_bkspaces
metadata
.
penalty_term
=
penalty_term
metadata
.
bic
=
bic
if
best_spline
is
None
:
# All bkspaces resulted in a SplineError, or all light curve intervals had
# insufficient points.
best_spline
=
[
np
.
array
([
np
.
nan
]
*
len
(
f
))
for
f
in
all_flux
]
metadata
.
light_curve_mask
=
[
np
.
zeros_like
(
f
,
dtype
=
np
.
bool
)
for
f
in
all_flux
]
return
best_spline
,
metadata
research/astronet/third_party/kepler_spline/kepler_spline_test.py
View file @
d1b91ba9
...
@@ -12,7 +12,7 @@ from third_party.kepler_spline import kepler_spline
...
@@ -12,7 +12,7 @@ from third_party.kepler_spline import kepler_spline
class
KeplerSplineTest
(
absltest
.
TestCase
):
class
KeplerSplineTest
(
absltest
.
TestCase
):
def
test
KeplerSpline
Sine
(
self
):
def
test
Fit
Sine
(
self
):
# Fit a sine wave.
# Fit a sine wave.
time
=
np
.
arange
(
0
,
10
,
0.1
)
time
=
np
.
arange
(
0
,
10
,
0.1
)
flux
=
np
.
sin
(
time
)
flux
=
np
.
sin
(
time
)
...
@@ -46,7 +46,7 @@ class KeplerSplineTest(absltest.TestCase):
...
@@ -46,7 +46,7 @@ class KeplerSplineTest(absltest.TestCase):
self
.
assertFalse
(
mask
[
77
])
self
.
assertFalse
(
mask
[
77
])
self
.
assertFalse
(
mask
[
95
])
self
.
assertFalse
(
mask
[
95
])
def
test
KeplerSpline
Cubic
(
self
):
def
test
Fit
Cubic
(
self
):
# Fit a cubic polynomial.
# Fit a cubic polynomial.
time
=
np
.
arange
(
0
,
10
,
0.1
)
time
=
np
.
arange
(
0
,
10
,
0.1
)
flux
=
(
time
-
5
)
**
3
+
2
*
(
time
-
5
)
**
2
+
10
flux
=
(
time
-
5
)
**
3
+
2
*
(
time
-
5
)
**
2
+
10
...
@@ -61,18 +61,84 @@ class KeplerSplineTest(absltest.TestCase):
...
@@ -61,18 +61,84 @@ class KeplerSplineTest(absltest.TestCase):
self
.
assertLess
(
rmse
,
1e-12
)
self
.
assertLess
(
rmse
,
1e-12
)
self
.
assertTrue
(
np
.
all
(
mask
))
self
.
assertTrue
(
np
.
all
(
mask
))
def
testKeplerSplineError
(
self
):
def
testInsufficientPointsError
(
self
):
# Big gap.
# Empty light curve.
time
=
np
.
concatenate
([
np
.
arange
(
0
,
1
,
0.1
),
[
2
]])
time
=
np
.
array
([])
flux
=
np
.
array
([])
with
self
.
assertRaises
(
kepler_spline
.
InsufficientPointsError
):
kepler_spline
.
kepler_spline
(
time
,
flux
,
bkspace
=
0.5
)
# Only 3 points.
time
=
np
.
array
([
0.1
,
0.2
,
0.3
])
flux
=
np
.
sin
(
time
)
flux
=
np
.
sin
(
time
)
with
self
.
assertRaises
(
kepler_spline
.
Spline
Error
):
with
self
.
assertRaises
(
kepler_spline
.
InsufficientPoints
Error
):
kepler_spline
.
kepler_spline
(
time
,
flux
,
bkspace
=
0.5
)
kepler_spline
.
kepler_spline
(
time
,
flux
,
bkspace
=
0.5
)
def
testChooseKeplerSpline
(
self
):
class
ChooseKeplerSplineTest
(
absltest
.
TestCase
):
def
testNoPoints
(
self
):
all_time
=
[
np
.
array
([])]
all_flux
=
[
np
.
array
([])]
# Logarithmically sample candidate break point spacings.
bkspaces
=
np
.
logspace
(
np
.
log10
(
0.5
),
np
.
log10
(
5
),
num
=
20
)
spline
,
metadata
=
kepler_spline
.
choose_kepler_spline
(
all_time
,
all_flux
,
bkspaces
,
penalty_coeff
=
1.0
,
verbose
=
False
)
np
.
testing
.
assert_array_equal
(
spline
,
[[]])
np
.
testing
.
assert_array_equal
(
metadata
.
light_curve_mask
,
[[]])
def
testTooFewPoints
(
self
):
# Sine wave with segments of 1, 2, 3 points.
all_time
=
[
np
.
array
([
0.1
]),
np
.
array
([
0.2
,
0.3
]),
np
.
array
([
0.4
,
0.5
,
0.6
])
]
all_flux
=
[
np
.
sin
(
t
)
for
t
in
all_time
]
# Logarithmically sample candidate break point spacings.
bkspaces
=
np
.
logspace
(
np
.
log10
(
0.5
),
np
.
log10
(
5
),
num
=
20
)
spline
,
metadata
=
kepler_spline
.
choose_kepler_spline
(
all_time
,
all_flux
,
bkspaces
,
penalty_coeff
=
1.0
,
verbose
=
False
)
# All segments are NaN.
self
.
assertTrue
(
np
.
all
(
np
.
isnan
(
np
.
concatenate
(
spline
))))
self
.
assertFalse
(
np
.
any
(
np
.
concatenate
(
metadata
.
light_curve_mask
)))
self
.
assertIsNone
(
metadata
.
bkspace
)
self
.
assertEmpty
(
metadata
.
bad_bkspaces
)
self
.
assertIsNone
(
metadata
.
likelihood_term
)
self
.
assertIsNone
(
metadata
.
penalty_term
)
self
.
assertIsNone
(
metadata
.
bic
)
# Add a longer segment.
all_time
.
append
(
np
.
arange
(
0.7
,
2.0
,
0.1
))
all_flux
.
append
(
np
.
sin
(
all_time
[
-
1
]))
spline
,
metadata
=
kepler_spline
.
choose_kepler_spline
(
all_time
,
all_flux
,
bkspaces
,
penalty_coeff
=
1.0
,
verbose
=
False
)
# First 3 segments are NaN.
for
i
in
range
(
3
):
self
.
assertTrue
(
np
.
all
(
np
.
isnan
(
spline
[
i
])))
self
.
assertFalse
(
np
.
any
(
metadata
.
light_curve_mask
[
i
]))
# Final segment is a good fit.
self
.
assertTrue
(
np
.
all
(
np
.
isfinite
(
spline
[
3
])))
self
.
assertTrue
(
np
.
all
(
metadata
.
light_curve_mask
[
3
]))
self
.
assertEmpty
(
metadata
.
bad_bkspaces
)
self
.
assertAlmostEqual
(
metadata
.
likelihood_term
,
-
58.0794069927957
)
self
.
assertAlmostEqual
(
metadata
.
penalty_term
,
7.69484807238461
)
self
.
assertAlmostEqual
(
metadata
.
bic
,
-
50.3845589204111
)
def
testFitSine
(
self
):
# High frequency sine wave.
# High frequency sine wave.
time
=
[
np
.
arange
(
0
,
100
,
0.1
),
np
.
arange
(
100
,
200
,
0.1
)]
all_
time
=
[
np
.
arange
(
0
,
100
,
0.1
),
np
.
arange
(
100
,
200
,
0.1
)]
flux
=
[
np
.
sin
(
t
)
for
t
in
time
]
all_
flux
=
[
np
.
sin
(
t
)
for
t
in
all_
time
]
# Logarithmically sample candidate break point spacings.
# Logarithmically sample candidate break point spacings.
bkspaces
=
np
.
logspace
(
np
.
log10
(
0.5
),
np
.
log10
(
5
),
num
=
20
)
bkspaces
=
np
.
logspace
(
np
.
log10
(
0.5
),
np
.
log10
(
5
),
num
=
20
)
...
@@ -83,29 +149,38 @@ class KeplerSplineTest(absltest.TestCase):
...
@@ -83,29 +149,38 @@ class KeplerSplineTest(absltest.TestCase):
return
np
.
sqrt
(
np
.
mean
((
f
-
s
)
**
2
))
return
np
.
sqrt
(
np
.
mean
((
f
-
s
)
**
2
))
# Penalty coefficient 1.0.
# Penalty coefficient 1.0.
spline
,
mask
,
bkspace
,
bad_bkspaces
=
kepler_spline
.
choose_kepler_spline
(
spline
,
metadata
=
kepler_spline
.
choose_kepler_spline
(
time
,
flux
,
bkspaces
,
penalty_coeff
=
1.0
)
all_time
,
all_flux
,
bkspaces
,
penalty_coeff
=
1.0
)
self
.
assertAlmostEqual
(
_rmse
(
flux
,
spline
),
0.013013
)
self
.
assertAlmostEqual
(
_rmse
(
all_flux
,
spline
),
0.013013
)
self
.
assertTrue
(
np
.
all
(
mask
))
self
.
assertTrue
(
np
.
all
(
metadata
.
light_curve_mask
))
self
.
assertAlmostEqual
(
bkspace
,
1.67990914314
)
self
.
assertAlmostEqual
(
metadata
.
bkspace
,
1.67990914314
)
self
.
assertEmpty
(
bad_bkspaces
)
self
.
assertEmpty
(
metadata
.
bad_bkspaces
)
self
.
assertAlmostEqual
(
metadata
.
likelihood_term
,
-
6685.64217856480
)
self
.
assertAlmostEqual
(
metadata
.
penalty_term
,
942.51190498322
)
self
.
assertAlmostEqual
(
metadata
.
bic
,
-
5743.13027358158
)
# Decrease penalty coefficient; allow smaller spacing for closer fit.
# Decrease penalty coefficient; allow smaller spacing for closer fit.
spline
,
mask
,
bkspace
,
bad_bkspaces
=
kepler_spline
.
choose_kepler_spline
(
spline
,
metadata
=
kepler_spline
.
choose_kepler_spline
(
time
,
flux
,
bkspaces
,
penalty_coeff
=
0.1
)
all_time
,
all_flux
,
bkspaces
,
penalty_coeff
=
0.1
)
self
.
assertAlmostEqual
(
_rmse
(
flux
,
spline
),
0.0066376
)
self
.
assertAlmostEqual
(
_rmse
(
all_flux
,
spline
),
0.0066376
)
self
.
assertTrue
(
np
.
all
(
mask
))
self
.
assertTrue
(
np
.
all
(
metadata
.
light_curve_mask
))
self
.
assertAlmostEqual
(
bkspace
,
1.48817572082
)
self
.
assertAlmostEqual
(
metadata
.
bkspace
,
1.48817572082
)
self
.
assertEmpty
(
bad_bkspaces
)
self
.
assertEmpty
(
metadata
.
bad_bkspaces
)
self
.
assertAlmostEqual
(
metadata
.
likelihood_term
,
-
6731.59913975551
)
self
.
assertAlmostEqual
(
metadata
.
penalty_term
,
1064.12634433589
)
self
.
assertAlmostEqual
(
metadata
.
bic
,
-
6625.18650532192
)
# Increase penalty coefficient; require larger spacing at the cost of worse
# Increase penalty coefficient; require larger spacing at the cost of worse
# fit.
# fit.
spline
,
mask
,
bkspace
,
bad_bkspaces
=
kepler_spline
.
choose_kepler_spline
(
spline
,
metadata
=
kepler_spline
.
choose_kepler_spline
(
time
,
flux
,
bkspaces
,
penalty_coeff
=
2
)
all_time
,
all_flux
,
bkspaces
,
penalty_coeff
=
2
)
self
.
assertAlmostEqual
(
_rmse
(
flux
,
spline
),
0.026215449
)
self
.
assertAlmostEqual
(
_rmse
(
all_flux
,
spline
),
0.026215449
)
self
.
assertTrue
(
np
.
all
(
mask
))
self
.
assertTrue
(
np
.
all
(
metadata
.
light_curve_mask
))
self
.
assertAlmostEqual
(
bkspace
,
1.89634509537
)
self
.
assertAlmostEqual
(
metadata
.
bkspace
,
1.89634509537
)
self
.
assertEmpty
(
bad_bkspaces
)
self
.
assertEmpty
(
metadata
.
bad_bkspaces
)
self
.
assertAlmostEqual
(
metadata
.
likelihood_term
,
-
6495.65564287904
)
self
.
assertAlmostEqual
(
metadata
.
penalty_term
,
836.099270549629
)
self
.
assertAlmostEqual
(
metadata
.
bic
,
-
4823.45710177978
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment