Unverified Commit 2c181308 authored by Chris Shallue and committed by GitHub

Merge pull request #5862 from cshallue/master

Move tensorflow_models/research/astronet to google-research/exoplanet-ml
parents caafb6d1 62704f06
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility function for smoothing data using a median filter."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None):
"""Computes the median y-value in uniform intervals (bins) along the x-axis.
The interval [x_min, x_max) is divided into num_bins uniformly spaced
intervals of width bin_width. The value computed for each bin is the median
of all y-values whose corresponding x-value is in the interval.
NOTE: x must be sorted in ascending order or the results will be incorrect.
Args:
x: 1D array of x-coordinates sorted in ascending order. Must have at least 2
elements, and its elements must not all be equal.
y: 1D array of y-coordinates with the same size as x.
num_bins: The number of intervals to divide the x-axis into. Must be at
least 2.
bin_width: The width of each bin on the x-axis. Must be positive, and less
than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
x_min: The inclusive leftmost value to consider on the x-axis. Must be less
than or equal to the largest value of x. Defaults to min(x).
x_max: The exclusive rightmost value to consider on the x-axis. Must be
greater than x_min. Defaults to max(x).
Returns:
1D NumPy array of size num_bins containing the median y-values of uniformly
spaced bins on the x-axis.
Raises:
ValueError: If an argument has an inappropriate value.
"""
if num_bins < 2:
raise ValueError("num_bins must be at least 2. Got: {}".format(num_bins))
# Validate the lengths of x and y.
x_len = len(x)
if x_len < 2:
raise ValueError("len(x) must be at least 2. Got: {}".format(x_len))
if x_len != len(y):
raise ValueError("len(x) (got: {}) must equal len(y) (got: {})".format(
x_len, len(y)))
# Validate x_min and x_max.
x_min = x_min if x_min is not None else x[0]
x_max = x_max if x_max is not None else x[-1]
if x_min >= x_max:
raise ValueError("x_min (got: {}) must be less than x_max (got: {})".format(
x_min, x_max))
if x_min > x[-1]:
raise ValueError(
"x_min (got: {}) must be less than or equal to the largest value of x "
"(got: {})".format(x_min, x[-1]))
# Validate bin_width.
bin_width = bin_width if bin_width is not None else (x_max - x_min) / num_bins
if bin_width <= 0:
raise ValueError("bin_width must be positive. Got: {}".format(bin_width))
if bin_width >= x_max - x_min:
raise ValueError(
"bin_width (got: {}) must be less than x_max - x_min (got: {})".format(
bin_width, x_max - x_min))
bin_spacing = (x_max - x_min - bin_width) / (num_bins - 1)
# Bins with no y-values will fall back to the global median.
result = np.repeat(np.median(y), num_bins)
# Find the first element of x >= x_min. This loop is guaranteed to produce
# a valid index because we know that x_min <= x[-1].
x_start = 0
while x[x_start] < x_min:
x_start += 1
# The bin at index i is the median of all elements y[j] such that
# bin_min <= x[j] < bin_max, where bin_min and bin_max are the endpoints of
# bin i.
bin_min = x_min # Left endpoint of the current bin.
bin_max = x_min + bin_width # Right endpoint of the current bin.
j_start = x_start # Inclusive left index of the current bin.
j_end = x_start # Exclusive end index of the current bin.
for i in range(num_bins):
# Move j_start to the first index of x >= bin_min.
while j_start < x_len and x[j_start] < bin_min:
j_start += 1
# Move j_end to the first index of x >= bin_max (exclusive end index).
while j_end < x_len and x[j_end] < bin_max:
j_end += 1
if j_end > j_start:
# Compute and insert the median bin value.
result[i] = np.median(y[j_start:j_end])
# Advance the bin.
bin_min += bin_spacing
bin_max += bin_spacing
return result
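# A minimal usage sketch for median_filter (values are illustrative, not part
# of the original module):
#
#   x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
#   y = np.array([1.0, 3.0, 2.0, 5.0, 4.0])
#   median_filter(x, y, num_bins=2)
#   # With the default bin_width = (x_max - x_min) / num_bins = 2, bin [0, 2)
#   # holds y = [1, 3] and bin [2, 4) holds y = [2, 5], so the result is
#   # array([2., 3.5]).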
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for median_filter.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve import median_filter
class MedianFilterTest(absltest.TestCase):
def testErrors(self):
# x size less than 2.
x = [1]
y = [2]
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=1, x_min=0, x_max=2)
# x and y not the same size.
x = [1, 2]
y = [4, 5, 6]
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=1, x_min=0, x_max=2)
# x_min not less than x_max.
x = [1, 2, 3]
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=1, x_min=-1, x_max=-1)
# x_min greater than the last element of x.
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=0.25, x_min=3.5, x_max=4)
# bin_width nonpositive.
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=0, x_min=1, x_max=3)
# bin_width greater than or equal to x_max - x_min.
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=1, x_min=1.5, x_max=2.5)
# num_bins less than 2.
x = [1, 2, 3]
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=1, bin_width=1, x_min=0, x_max=2)
def testBucketBoundaries(self):
x = np.array([-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=2, x_min=-5, x_max=5)
np.testing.assert_array_equal([2.5, 4.5, 6.5, 8.5, 10.5], result)
def testMultiSizeBins(self):
# Construct bins with size 0, 1, 2, 3, 4, 5, 10, respectively.
x = np.array([
1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6,
6
])
y = np.array([
0, -1, 1, 4, 5, 6, 2, 2, 4, 4, 1, 1, 1, 1, -1, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10
])
result = median_filter.median_filter(
x, y, num_bins=7, bin_width=1, x_min=0, x_max=7)
np.testing.assert_array_equal([3, 0, 0, 5, 3, 1, 5.5], result)
def testMedian(self):
x = np.array([-4, -2, -2, 0, 0, 0, 2, 2, 2, 2, 3, 3, 3, 3, 3])
y = np.array([0, -1, 1, 4, 5, 6, 2, 2, 4, 4, 1, 1, 1, 1, -1])
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=2, x_min=-5, x_max=5)
np.testing.assert_array_equal([0, 0, 5, 3, 1], result)
def testWideBins(self):
x = np.array([-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=6, x_min=-7, x_max=7)
np.testing.assert_array_equal([3, 4.5, 6.5, 8.5, 10.5], result)
def testNarrowBins(self):
x = np.array([-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6])
y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=1, x_min=-4.5, x_max=4.5)
np.testing.assert_array_equal([3, 5, 7, 9, 11], result)
def testEmptyBins(self):
x = np.array([-1, 0, 1])
y = np.array([1, 2, 3])
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=2, x_min=-5, x_max=5)
np.testing.assert_array_equal([2, 2, 1.5, 3, 2], result)
def testDefaultArgs(self):
x = np.array([-4, -2, -2, 0, 0, 0, 2, 2, 2, 2, 3, 3, 3, 3, 3])
y = np.array([7, -1, 3, 4, 5, 6, 2, 2, 4, 4, 1, 1, 1, 1, -1])
result = median_filter.median_filter(x, y, num_bins=5)
np.testing.assert_array_equal([7, 1, 5, 2, 3], result)
if __name__ == "__main__":
absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Event class, which represents a periodic event in a light curve."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
class Event(object):
"""Represents a periodic event in a light curve."""
def __init__(self, period, duration, t0):
"""Initializes the Event.
Args:
period: Period of the event, in days.
duration: Duration of the event, in days.
t0: Time of the first occurrence of the event, in days.
"""
self._period = period
self._duration = duration
self._t0 = t0
def __str__(self):
return "<period={}, duration={}, t0={}>".format(self.period, self.duration,
self.t0)
def __repr__(self):
return "Event({})".format(str(self))
@property
def period(self):
return self._period
@property
def duration(self):
return self._duration
@property
def t0(self):
return self._t0
def equals(self, other_event, period_rtol=0.001, t0_durations=1):
"""Compares this Event to another Event, within the given tolerance.
Args:
other_event: An Event.
period_rtol: Relative tolerance in matching the periods.
t0_durations: Tolerance in matching the t0 values, in units of the other
Event's duration.
Returns:
True if this Event is the same as other_event, within the given tolerance.
"""
# First compare the periods.
period_match = np.isclose(
self.period, other_event.period, rtol=period_rtol, atol=1e-8)
if not period_match:
return False
# To compare t0, we must consider that self.t0 and other_event.t0 may be at
# different phases. Just comparing mod(self.t0, period) to
# mod(other_event.t0, period) does not work because two similar values could
# end up at different ends of [0, period).
#
# Define t0_diff to be the absolute difference, up to multiples of period.
# This value is always in [0, period/2].
t0_diff = np.mod(self.t0 - other_event.t0, other_event.period)
if t0_diff > other_event.period / 2:
t0_diff = other_event.period - t0_diff
return t0_diff < t0_durations * other_event.duration
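# A quick sketch of how Event.equals treats phase (values are illustrative):
#
#   e1 = Event(period=100, duration=5, t0=2)
#   e2 = Event(period=100.05, duration=5, t0=101.9)
#   e1.equals(e2)  # True: the periods agree within 0.1%, and the t0 values
#                  # differ by only ~0.15 days modulo the period, well within
#                  # one duration.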
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for periodic_event.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
from light_curve.periodic_event import Event
class EventTest(absltest.TestCase):
def testStr(self):
self.assertEqual(str(Event(1, 2, 3)), "<period=1, duration=2, t0=3>")
def testRepr(self):
self.assertEqual(
repr(Event(1, 2, 3)), "Event(<period=1, duration=2, t0=3>)")
def testEquals(self):
event = Event(period=100, duration=5, t0=2)
# Varying periods.
self.assertFalse(event.equals(Event(period=0, duration=5, t0=2)))
self.assertFalse(event.equals(Event(period=50, duration=5, t0=2)))
self.assertFalse(event.equals(Event(period=99.89, duration=5, t0=2)))
self.assertTrue(event.equals(Event(period=99.91, duration=5, t0=2)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=2)))
self.assertTrue(event.equals(Event(period=100.01, duration=5, t0=2)))
self.assertFalse(event.equals(Event(period=101, duration=5, t0=2)))
# Different period tolerance.
self.assertTrue(
event.equals(Event(period=99.1, duration=5, t0=2), period_rtol=0.01))
self.assertTrue(
event.equals(Event(period=100.9, duration=5, t0=2), period_rtol=0.01))
self.assertFalse(
event.equals(Event(period=98.9, duration=5, t0=2), period_rtol=0.01))
self.assertFalse(
event.equals(Event(period=101.1, duration=5, t0=2), period_rtol=0.01))
# Varying t0.
self.assertTrue(event.equals(Event(period=100, duration=5, t0=0)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=2)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=6.9)))
self.assertFalse(event.equals(Event(period=100, duration=5, t0=7.1)))
# t0 at the other end of [0, period).
self.assertFalse(event.equals(Event(period=100, duration=5, t0=96.9)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=97.1)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=100)))
self.assertTrue(event.equals(Event(period=100, duration=5, t0=102)))
self.assertFalse(event.equals(Event(period=100, duration=5, t0=107.1)))
# Varying duration.
self.assertFalse(event.equals(Event(period=100, duration=5, t0=10)))
self.assertFalse(event.equals(Event(period=100, duration=7, t0=10)))
self.assertTrue(event.equals(Event(period=100, duration=9, t0=10)))
# Different duration tolerance.
self.assertFalse(
event.equals(Event(period=100, duration=5, t0=10), t0_durations=1))
self.assertTrue(
event.equals(Event(period=100, duration=5, t0=10), t0_durations=2))
if __name__ == "__main__":
absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Light curve utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import scipy.interpolate
from six.moves import range # pylint:disable=redefined-builtin
def phase_fold_time(time, period, t0):
"""Creates a phase-folded time vector.
result[i] is the unique number in [-period / 2, period / 2)
such that result[i] = time[i] - t0 + k_i * period, for some integer k_i.
Args:
time: 1D numpy array of time values.
period: A positive real scalar; the period to fold over.
t0: The center of the resulting folded vector; this value is mapped to 0.
Returns:
A 1D numpy array.
"""
half_period = period / 2
result = np.mod(time + (half_period - t0), period)
result -= half_period
return result
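# A minimal example of phase folding (values are illustrative):
#
#   phase_fold_time(np.array([0.0, 1.0, 2.0]), period=2.0, t0=0.5)
#   # -> array([-0.5,  0.5, -0.5]); each time maps into [-1.0, 1.0) relative
#   # to the nearest occurrence of t0.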
def split(all_time, all_flux, gap_width=0.75):
"""Splits a light curve on discontinuities (gaps).
This function accepts a light curve that is either a single segment, or is
piecewise defined (e.g. split by quarter breaks or gaps in the data).
Args:
all_time: Numpy array or sequence of numpy arrays; each is a sequence of
time values.
all_flux: Numpy array or sequence of numpy arrays; each is a sequence of
flux values of the corresponding time array.
gap_width: Minimum gap size (in time units) for a split.
Returns:
out_time: List of numpy arrays; the split time arrays.
out_flux: List of numpy arrays; the split flux arrays.
"""
# Handle single-segment inputs.
if isinstance(all_time, np.ndarray) and all_time.ndim == 1:
all_time = [all_time]
all_flux = [all_flux]
out_time = []
out_flux = []
for time, flux in zip(all_time, all_flux):
start = 0
for end in range(1, len(time) + 1):
# Choose the largest endpoint such that time[start:end] has no gaps.
if end == len(time) or time[end] - time[end - 1] > gap_width:
out_time.append(time[start:end])
out_flux.append(flux[start:end])
start = end
return out_time, out_flux
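# A minimal example of split (values are illustrative):
#
#   time = np.array([0.0, 0.1, 0.2, 5.0, 5.1])
#   flux = np.array([1.0, 1.0, 1.0, 2.0, 2.0])
#   out_time, out_flux = split(time, flux, gap_width=0.75)
#   # The 4.8-unit gap exceeds gap_width, so out_time is
#   # [array([0., 0.1, 0.2]), array([5., 5.1])].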
def remove_events(all_time,
all_flux,
events,
width_factor=1.0,
include_empty_segments=True):
"""Removes events from a light curve.
This function accepts either a single-segment or piecewise-defined light
curve (e.g. one that is split by quarter breaks or gaps in the data).
Args:
all_time: Numpy array or sequence of numpy arrays; each is a sequence of
time values.
all_flux: Numpy array or sequence of numpy arrays; each is a sequence of
flux values of the corresponding time array.
events: List of Event objects to remove.
width_factor: Fractional multiplier of the duration of each event to remove.
include_empty_segments: Whether to include empty segments in the output.
Returns:
output_time: Numpy array or list of numpy arrays; the time arrays with
events removed.
output_flux: Numpy array or list of numpy arrays; the flux arrays with
events removed.
"""
# Handle single-segment inputs.
if isinstance(all_time, np.ndarray) and all_time.ndim == 1:
all_time = [all_time]
all_flux = [all_flux]
single_segment = True
else:
single_segment = False
output_time = []
output_flux = []
for time, flux in zip(all_time, all_flux):
mask = np.ones_like(time, dtype=bool)
for event in events:
transit_dist = np.abs(phase_fold_time(time, event.period, event.t0))
mask = np.logical_and(mask,
transit_dist > 0.5 * width_factor * event.duration)
if single_segment:
output_time = time[mask]
output_flux = flux[mask]
elif include_empty_segments or np.any(mask):
output_time.append(time[mask])
output_flux.append(flux[mask])
return output_time, output_flux
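# A minimal example of remove_events (values are illustrative; Event comes
# from periodic_event.py):
#
#   time = np.arange(10, dtype=float)
#   flux = 10.0 * time
#   events = [Event(period=4, duration=1.5, t0=3.5)]
#   out_time, out_flux = remove_events(time, flux, events)
#   # out_time -> array([1., 2., 5., 6., 9.]); every point within 0.75 days
#   # of a transit midpoint (..., -0.5, 3.5, 7.5, ...) is dropped.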
def interpolate_missing_time(time, cadence_no=None, fill_value="extrapolate"):
"""Interpolates missing (NaN or Inf) time values.
Args:
time: A numpy array of monotonically increasing values, with missing values
denoted by NaN or Inf.
cadence_no: Optional numpy array of cadence numbers corresponding to the
time values. If not provided, missing time values are assumed to be evenly
spaced between present time values.
fill_value: Specifies how missing time values should be treated at the
beginning and end of the array. See scipy.interpolate.interp1d.
Returns:
A numpy array of the same length as the input time array, with NaN/Inf
values replaced with interpolated values.
Raises:
ValueError: If fewer than 2 values of time are finite.
"""
if cadence_no is None:
cadence_no = np.arange(len(time))
is_finite = np.isfinite(time)
num_finite = np.sum(is_finite)
if num_finite < 2:
raise ValueError(
"Cannot interpolate time with fewer than 2 finite values. Got "
"len(time) = {} with {} finite values.".format(len(time), num_finite))
interpolate_fn = scipy.interpolate.interp1d(
cadence_no[is_finite],
time[is_finite],
copy=False,
bounds_error=False,
fill_value=fill_value,
assume_sorted=True)
return interpolate_fn(cadence_no)
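# A minimal example of interpolate_missing_time (values are illustrative):
#
#   interpolate_missing_time(np.array([1.0, np.nan, 3.0, 4.0]))
#   # -> array([1., 2., 3., 4.]); the NaN is linearly interpolated over the
#   # implicit cadence numbers [0, 1, 2, 3].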
def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
"""Linearly interpolates spline values across masked points.
Args:
all_time: List of numpy arrays; each is a sequence of time values.
all_masked_time: List of numpy arrays; each is a sequence of time values
with some values missing (masked).
all_masked_spline: List of numpy arrays; the masked spline values
corresponding to all_masked_time.
Returns:
interp_spline: List of numpy arrays; each is the masked spline with missing
points linearly interpolated.
"""
interp_spline = []
for time, masked_time, masked_spline in zip(all_time, all_masked_time,
all_masked_spline):
if masked_time.size:
interp_spline.append(np.interp(time, masked_time, masked_spline))
else:
interp_spline.append(np.array([np.nan] * len(time)))
return interp_spline
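# A minimal example of interpolate_masked_spline (values are illustrative):
#
#   all_time = [np.array([0.0, 1.0, 2.0])]
#   all_masked_time = [np.array([0.0, 2.0])]
#   all_masked_spline = [np.array([10.0, 30.0])]
#   interpolate_masked_spline(all_time, all_masked_time, all_masked_spline)
#   # -> [array([10., 20., 30.])]; the masked point at time 1.0 is filled in
#   # by linear interpolation.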
def reshard_arrays(xs, ys):
"""Reshards arrays in xs to match the lengths of arrays in ys.
Args:
xs: List of 1d numpy arrays with the same total length as ys.
ys: List of 1d numpy arrays with the same total length as xs.
Returns:
A list of numpy arrays containing the same elements as xs, in the same
order, but with array lengths matching the pairwise array in ys.
Raises:
ValueError: If xs and ys do not have the same total length.
"""
# Compute indices of boundaries between segments of ys, plus the end boundary.
boundaries = np.cumsum([len(y) for y in ys])
concat_x = np.concatenate(xs)
if len(concat_x) != boundaries[-1]:
raise ValueError(
"xs and ys do not have the same total length ({} vs. {}).".format(
len(concat_x), boundaries[-1]))
boundaries = boundaries[:-1] # Remove exclusive end boundary.
return np.split(concat_x, boundaries)
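# A minimal example of reshard_arrays (values are illustrative):
#
#   xs = [np.array([1, 2, 3]), np.array([4, 5])]
#   ys = [np.array([0, 0]), np.array([0, 0, 0])]
#   reshard_arrays(xs, ys)
#   # -> [array([1, 2]), array([3, 4, 5])]; same elements as xs, resharded to
#   # the lengths of the arrays in ys.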
def uniform_cadence_light_curve(cadence_no, time, flux):
"""Combines data into a single light curve with uniform cadence numbers.
Args:
cadence_no: numpy array; the cadence numbers of the light curve.
time: numpy array; the time values of the light curve.
flux: numpy array; the flux values of the light curve.
Returns:
cadence_no: numpy array; the cadence numbers of the light curve with no
gaps. It starts and ends at the minimum and maximum cadence numbers in the
input light curve, respectively.
time: numpy array; the time values of the light curve. Missing data points
have value zero and correspond to a False value in the mask.
flux: numpy array; the flux values of the light curve. Missing data points
have value zero and correspond to a False value in the mask.
mask: Boolean numpy array; False indicates missing data points, where
missing data points are those that have no corresponding cadence number in
the input or those where at least one of the cadence number, time value,
or flux value is NaN/Inf.
Raises:
ValueError: If there are duplicate cadence numbers in the input.
"""
min_cadence_no = np.min(cadence_no)
max_cadence_no = np.max(cadence_no)
out_cadence_no = np.arange(
min_cadence_no, max_cadence_no + 1, dtype=cadence_no.dtype)
out_time = np.zeros_like(out_cadence_no, dtype=time.dtype)
out_flux = np.zeros_like(out_cadence_no, dtype=flux.dtype)
out_mask = np.zeros_like(out_cadence_no, dtype=bool)
for c, t, f in zip(cadence_no, time, flux):
if np.isfinite(c) and np.isfinite(t) and np.isfinite(f):
i = int(c - min_cadence_no)
if out_mask[i]:
raise ValueError("Duplicate cadence number: {}".format(c))
out_time[i] = t
out_flux[i] = f
out_mask[i] = True
return out_cadence_no, out_time, out_flux, out_mask
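# A minimal example of uniform_cadence_light_curve (values are illustrative):
#
#   cadence_no = np.array([4, 6, 7])
#   time = np.array([40.0, 60.0, 70.0])
#   flux = np.array([1.0, 3.0, 4.0])
#   uniform_cadence_light_curve(cadence_no, time, flux)
#   # -> cadence_no [4, 5, 6, 7], time [40, 0, 60, 70], flux [1, 0, 3, 4],
#   #    mask [True, False, True, True]; the missing cadence 5 is zero-filled.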
def count_transit_points(time, event):
"""Computes the number of points in each transit of a given event.
Args:
time: Sorted numpy array of time values.
event: An Event object.
Returns:
A numpy array containing the number of time points "in transit" for each
transit occurring between the first and last time values.
Raises:
ValueError: If there are more than 10**6 transits.
"""
t_min = np.min(time)
t_max = np.max(time)
# Tiny periods or erroneous time values could make this loop take forever.
if (t_max - t_min) / event.period > 10**6:
raise ValueError(
"Too many transits! Time range is [{:.4f}, {:.4f}] and period is "
"{:.4e}.".format(t_min, t_max, event.period))
# Make sure t0 is in [t_min, t_min + period).
t0 = np.mod(event.t0 - t_min, event.period) + t_min
# Prepare loop variables.
points_in_transit = []
i, j = 0, 0
for transit_midpoint in np.arange(t0, t_max, event.period):
transit_begin = transit_midpoint - event.duration / 2
transit_end = transit_midpoint + event.duration / 2
# Move time[i] to the first point >= transit_begin.
while time[i] < transit_begin:
# transit_begin is guaranteed to be < np.max(time) (provided duration >= 0).
# Therefore, i cannot go out of range.
i += 1
# Move time[j] to the first point > transit_end.
while time[j] <= transit_end:
j += 1
# j went out of range. We're finished.
if j >= len(time):
break
# The points in the current transit duration are precisely time[i:j].
# Since j is an exclusive index, there are exactly j-i points in transit.
points_in_transit.append(j - i)
return np.array(points_in_transit)
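# A minimal example of count_transit_points (values are illustrative; Event
# comes from periodic_event.py):
#
#   time = np.arange(0.0, 20.0, 1.0)
#   event = Event(period=10, duration=2, t0=5)
#   count_transit_points(time, event)
#   # -> array([3, 3]); each transit window, [4, 6] and [14, 16], contains
#   # three integer time points.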
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for util.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve import periodic_event
from light_curve import util
class LightCurveUtilTest(absltest.TestCase):
def testPhaseFoldTime(self):
time = np.arange(0, 2, 0.1)
# Simple.
tfold = util.phase_fold_time(time, period=1, t0=0.45)
expected = [
-0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45,
-0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45
]
self.assertSequenceAlmostEqual(expected, tfold)
# Large t0.
tfold = util.phase_fold_time(time, period=1, t0=1.25)
expected = [
-0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35, -0.25,
-0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35
]
self.assertSequenceAlmostEqual(expected, tfold)
# Negative t0.
tfold = util.phase_fold_time(time, period=1, t0=-1.65)
expected = [
-0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35,
-0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45
]
self.assertSequenceAlmostEqual(expected, tfold)
# Negative time.
time = np.arange(-3, -1, 0.1)
tfold = util.phase_fold_time(time, period=1, t0=0.55)
expected = [
0.45, -0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45,
-0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35
]
self.assertSequenceAlmostEqual(expected, tfold)
def testSplit(self):
# Single segment.
all_time = np.concatenate([np.arange(0, 1, 0.1), np.arange(1.5, 2, 0.1)])
all_flux = np.ones(15)
# Gap width 0.5.
split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
self.assertLen(split_time, 2)
self.assertLen(split_flux, 2)
self.assertSequenceAlmostEqual(np.arange(0, 1, 0.1), split_time[0])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[0])
self.assertSequenceAlmostEqual(np.arange(1.5, 2, 0.1), split_time[1])
self.assertSequenceAlmostEqual(np.ones(5), split_flux[1])
# Multi segment.
all_time = [
np.concatenate([
np.arange(0, 1, 0.1),
np.arange(1.5, 2, 0.1),
np.arange(3, 4, 0.1)
]),
np.arange(4, 5, 0.1)
]
all_flux = [np.ones(25), np.ones(10)]
self.assertLen(all_time, 2)
self.assertLen(all_time[0], 25)
self.assertLen(all_time[1], 10)
self.assertLen(all_flux, 2)
self.assertLen(all_flux[0], 25)
self.assertLen(all_flux[1], 10)
# Gap width 0.5.
split_time, split_flux = util.split(all_time, all_flux, gap_width=0.5)
self.assertLen(split_time, 4)
self.assertLen(split_flux, 4)
self.assertSequenceAlmostEqual(np.arange(0, 1, 0.1), split_time[0])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[0])
self.assertSequenceAlmostEqual(np.arange(1.5, 2, 0.1), split_time[1])
self.assertSequenceAlmostEqual(np.ones(5), split_flux[1])
self.assertSequenceAlmostEqual(np.arange(3, 4, 0.1), split_time[2])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[2])
self.assertSequenceAlmostEqual(np.arange(4, 5, 0.1), split_time[3])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[3])
# Gap width 1.0.
split_time, split_flux = util.split(all_time, all_flux, gap_width=1)
self.assertLen(split_time, 3)
self.assertLen(split_flux, 3)
self.assertSequenceAlmostEqual([
0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.5, 1.6, 1.7, 1.8, 1.9
], split_time[0])
self.assertSequenceAlmostEqual(np.ones(15), split_flux[0])
self.assertSequenceAlmostEqual(np.arange(3, 4, 0.1), split_time[1])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[1])
self.assertSequenceAlmostEqual(np.arange(4, 5, 0.1), split_time[2])
self.assertSequenceAlmostEqual(np.ones(10), split_flux[2])
def testRemoveEvents(self):
time = np.arange(20, dtype=float)
flux = 10 * time
# One event.
events = [periodic_event.Event(period=4, duration=1.5, t0=3.5)]
output_time, output_flux = util.remove_events(time, flux, events)
self.assertSequenceAlmostEqual([1, 2, 5, 6, 9, 10, 13, 14, 17, 18],
output_time)
self.assertSequenceAlmostEqual(
[10, 20, 50, 60, 90, 100, 130, 140, 170, 180], output_flux)
# Two events.
events.append(periodic_event.Event(period=7, duration=1.5, t0=6.5))
output_time, output_flux = util.remove_events(time, flux, events)
self.assertSequenceAlmostEqual([1, 2, 5, 9, 10, 17, 18], output_time)
self.assertSequenceAlmostEqual([10, 20, 50, 90, 100, 170, 180], output_flux)
# Multi segment light curve.
time = [np.arange(10, dtype=float), np.arange(10, 20, dtype=float)]
flux = [10 * t for t in time]
output_time, output_flux = util.remove_events(time, flux, events)
self.assertLen(output_time, 2)
self.assertLen(output_flux, 2)
self.assertSequenceAlmostEqual([1, 2, 5, 9], output_time[0])
self.assertSequenceAlmostEqual([10, 20, 50, 90], output_flux[0])
self.assertSequenceAlmostEqual([10, 17, 18], output_time[1])
self.assertSequenceAlmostEqual([100, 170, 180], output_flux[1])
# One segment totally removed with include_empty_segments = True.
time = [np.arange(5, dtype=float), np.arange(10, 20, dtype=float)]
flux = [10 * t for t in time]
events = [periodic_event.Event(period=10, duration=2, t0=2.5)]
output_time, output_flux = util.remove_events(
time, flux, events, width_factor=3, include_empty_segments=True)
self.assertLen(output_time, 2)
self.assertLen(output_flux, 2)
self.assertSequenceEqual([], output_time[0])
self.assertSequenceEqual([], output_flux[0])
self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[1])
self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[1])
# One segment totally removed with include_empty_segments = False.
time = [np.arange(5, dtype=float), np.arange(10, 20, dtype=float)]
flux = [10 * t for t in time]
events = [periodic_event.Event(period=10, duration=2, t0=2.5)]
output_time, output_flux = util.remove_events(
time, flux, events, width_factor=3, include_empty_segments=False)
self.assertLen(output_time, 1)
self.assertLen(output_flux, 1)
self.assertSequenceAlmostEqual([16, 17, 18, 19], output_time[0])
self.assertSequenceAlmostEqual([160, 170, 180, 190], output_flux[0])
def testInterpolateMissingTime(self):
# Fewer than 2 finite values.
with self.assertRaises(ValueError):
util.interpolate_missing_time(np.array([]))
with self.assertRaises(ValueError):
util.interpolate_missing_time(np.array([5.0]))
with self.assertRaises(ValueError):
util.interpolate_missing_time(np.array([5.0, np.nan]))
with self.assertRaises(ValueError):
util.interpolate_missing_time(np.array([np.nan, np.nan, np.nan]))
# Small time arrays.
self.assertSequenceAlmostEqual([0.5, 0.6],
util.interpolate_missing_time(
np.array([0.5, 0.6])))
self.assertSequenceAlmostEqual([0.5, 0.6, 0.7],
util.interpolate_missing_time(
np.array([0.5, np.nan, 0.7])))
# Time array of length 20 with some values NaN.
time = np.array([
np.nan, 0.5, 1.0, 1.5, 2.0, 2.5, np.nan, 3.5, 4.0, 4.5, 5.0, np.nan,
np.nan, np.nan, np.nan, 7.5, 8.0, 8.5, np.nan, np.nan
])
interp_time = util.interpolate_missing_time(time)
self.assertSequenceAlmostEqual([
0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5,
7.0, 7.5, 8.0, 8.5, 9.0, 9.5
], interp_time)
# Fill with 0.0 for missing values at the beginning and end.
interp_time = util.interpolate_missing_time(time, fill_value=0.0)
self.assertSequenceAlmostEqual([
0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5,
7.0, 7.5, 8.0, 8.5, 0.0, 0.0
], interp_time)
# Interpolate with cadences.
cadences = np.array([
100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113,
114, 115, 116, 117, 118, 119
])
interp_time = util.interpolate_missing_time(time, cadences)
self.assertSequenceAlmostEqual([
0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5,
7.0, 7.5, 8.0, 8.5, 9.0, 9.5
], interp_time)
# Interpolate with missing cadences.
time = np.array([0.6, 0.7, np.nan, np.nan, np.nan, 1.3, 1.4, 1.5])
cadences = np.array([106, 107, 108, 109, 110, 113, 114, 115])
interp_time = util.interpolate_missing_time(time, cadences)
self.assertSequenceAlmostEqual([0.6, 0.7, 0.8, 0.9, 1.0, 1.3, 1.4, 1.5],
interp_time)
def testInterpolateMaskedSpline(self):
all_time = [
np.arange(0, 10, dtype=float),
np.arange(10, 20, dtype=float),
np.arange(20, 30, dtype=float),
]
all_masked_time = [
np.array([0, 1, 2, 3, 8, 9], dtype=float),  # No 4, 5, 6, 7
np.array([10, 11, 12, 13, 14, 15, 16], dtype=float),  # No 17, 18, 19
np.array([], dtype=float)
]
all_masked_spline = [2 * t + 100 for t in all_masked_time]
interp_spline = util.interpolate_masked_spline(all_time, all_masked_time,
all_masked_spline)
self.assertLen(interp_spline, 3)
self.assertSequenceAlmostEqual(
[100, 102, 104, 106, 108, 110, 112, 114, 116, 118], interp_spline[0])
self.assertSequenceAlmostEqual(
[120, 122, 124, 126, 128, 130, 132, 132, 132, 132], interp_spline[1])
self.assertTrue(np.all(np.isnan(interp_spline[2])))
def testReshardArrays(self):
xs = [
np.array([1, 2, 3]),
np.array([4]),
np.array([5, 6, 7, 8, 9]),
np.array([]),
]
ys = [
np.array([]),
np.array([10, 20]),
np.array([30, 40, 50, 60]),
np.array([70]),
np.array([80, 90]),
]
reshard_xs = util.reshard_arrays(xs, ys)
self.assertLen(reshard_xs, 5)
np.testing.assert_array_equal([], reshard_xs[0])
np.testing.assert_array_equal([1, 2], reshard_xs[1])
np.testing.assert_array_equal([3, 4, 5, 6], reshard_xs[2])
np.testing.assert_array_equal([7], reshard_xs[3])
np.testing.assert_array_equal([8, 9], reshard_xs[4])
with self.assertRaisesRegexp(ValueError,
"xs and ys do not have the same total length"):
util.reshard_arrays(xs, [np.array([10, 20, 30]), np.array([40, 50])])
def testUniformCadenceLightCurve(self):
input_cadence_no = np.array([13, 4, 5, 6, 8, 9, 11, 12])
input_time = np.array([130, 40, 50, 60, 80, 90, 110, 120])
input_flux = np.array([1300, 400, 500, 600, 800, np.nan, 1100, 1200])
cadence_no, time, flux, mask = util.uniform_cadence_light_curve(
input_cadence_no, input_time, input_flux)
np.testing.assert_array_equal([4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
cadence_no)
np.testing.assert_array_equal([40, 50, 60, 0, 80, 0, 0, 110, 120, 130],
time)
np.testing.assert_array_equal(
[400, 500, 600, 0, 800, 0, 0, 1100, 1200, 1300], flux)
np.testing.assert_array_equal([1, 1, 1, 0, 1, 0, 0, 1, 1, 1], mask)
# Add duplicate cadence number.
input_cadence_no = np.concatenate([input_cadence_no, np.array([13, 14])])
input_time = np.concatenate([input_time, np.array([130, 140])])
input_flux = np.concatenate([input_flux, np.array([1300, 1400])])
with self.assertRaisesRegexp(ValueError, "Duplicate cadence number"):
util.uniform_cadence_light_curve(input_cadence_no, input_time, input_flux)
def testCountTransitPoints(self):
time = np.concatenate([
np.arange(0, 10, 0.1, dtype=float),
np.arange(15, 30, 0.1, dtype=float),
np.arange(50, 100, 0.1, dtype=float)
])
event = periodic_event.Event(period=10, duration=5, t0=9.95)
points_in_transit = util.count_transit_points(time, event)
np.testing.assert_array_equal([25, 50, 25, 0, 25, 50, 50, 50, 50],
points_in_transit)
if __name__ == "__main__":
absltest.main()
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "configdict",
srcs = ["configdict.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "configdict_test",
size = "small",
srcs = ["configdict_test.py"],
srcs_version = "PY2AND3",
deps = [":configdict"],
)
py_library(
name = "config_util",
srcs = ["config_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "config_util_test",
size = "small",
srcs = ["config_util_test.py"],
srcs_version = "PY2AND3",
deps = [":config_util"],
)
py_library(
name = "estimator_runner",
srcs = ["estimator_runner.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os.path
import tensorflow as tf
def parse_json(json_string_or_file):
"""Parses values from a JSON string or JSON file.
This function is useful for command line flags containing configuration
overrides. Using this function, the flag can be passed either as a JSON string
(e.g. '{"learning_rate": 1.0}') or the path to a JSON configuration file.
Args:
json_string_or_file: A JSON serialized string OR the path to a JSON file.
Returns:
A dictionary; the parsed JSON.
Raises:
ValueError: If the JSON could not be parsed.
"""
# First, attempt to parse the string as a JSON dict.
try:
json_dict = json.loads(json_string_or_file)
except ValueError as literal_json_parsing_error:
try:
# Otherwise, try to use it as a path to a JSON file.
with tf.gfile.Open(json_string_or_file) as f:
json_dict = json.load(f)
except ValueError as json_file_parsing_error:
raise ValueError("Unable to parse the content of the json file {}. "
"Parsing error: {}.".format(
json_string_or_file,
json_file_parsing_error.message))
except tf.gfile.FileError:
message = ("Unable to parse the input parameter neither as literal "
"JSON nor as the name of a file that exists.\n"
"JSON parsing error: {}\n\n Input parameter:\n{}.".format(
literal_json_parsing_error.message, json_string_or_file))
raise ValueError(message)
return json_dict
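# A minimal usage sketch for parse_json (the file path is illustrative):
#
#   parse_json('{"learning_rate": 0.1}')  # -> {'learning_rate': 0.1}
#   parse_json("/tmp/config.json")        # parses the file's JSON contents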
def to_json(config):
"""Converts a JSON-serializable configuration object to a JSON string."""
if hasattr(config, "to_json") and callable(config.to_json):
return config.to_json(indent=2)
else:
return json.dumps(config, indent=2)
def log_and_save_config(config, output_dir):
"""Logs and writes a JSON-serializable configuration object.
Args:
config: A JSON-serializable object.
output_dir: Destination directory.
"""
config_json = to_json(config)
tf.logging.info("config: %s", config_json)
tf.gfile.MakeDirs(output_dir)
with tf.gfile.Open(os.path.join(output_dir, "config.json"), "w") as f:
f.write(config_json)
def unflatten(flat_config):
"""Transforms a flat configuration dictionary into a nested dictionary.
Example:
{
"a": 1,
"b.c": 2,
"b.d.e": 3,
"b.d.f": 4,
}
would be transformed to:
{
"a": 1,
"b": {
"c": 2,
"d": {
"e": 3,
"f": 4,
}
}
}
Args:
flat_config: A dictionary with strings as keys where nested configuration
parameters are represented with period-separated names.
Returns:
A dictionary nested according to the keys of the input dictionary.
"""
config = {}
for path, value in flat_config.items():
path = path.split(".")
final_key = path.pop()
nested_config = config
for key in path:
nested_config = nested_config.setdefault(key, {})
nested_config[final_key] = value
return config
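# A minimal example of unflatten (values are illustrative):
#
#   unflatten({"a": 1, "b.c": 2, "b.d": 3})
#   # -> {"a": 1, "b": {"c": 2, "d": 3}}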
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for config_util.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tf_util import config_util
class ConfigUtilTest(tf.test.TestCase):
def testUnflatten(self):
# Empty dict.
self.assertDictEqual(config_util.unflatten({}), {})
# Already flat dict.
self.assertDictEqual(
config_util.unflatten({
"a": 1,
"b": 2
}), {
"a": 1,
"b": 2
})
# Nested dict.
self.assertDictEqual(
config_util.unflatten({
"a": 1,
"b.c": 2,
"b.d.e": 3,
"b.d.f": 4,
}), {
"a": 1,
"b": {
"c": 2,
"d": {
"e": 3,
"f": 4,
}
}
})
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration container for TensorFlow models.
A ConfigDict is simply a dict whose values can be accessed via both dot syntax
(config.key) and dict syntax (config['key']).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def _maybe_convert_dict(value):
if isinstance(value, dict):
return ConfigDict(value)
return value
class ConfigDict(dict):
"""Configuration container class."""
def __init__(self, initial_dictionary=None):
"""Creates an instance of ConfigDict.
Args:
initial_dictionary: Optional dictionary or ConfigDict containing initial
parameters.
"""
if initial_dictionary:
for field, value in initial_dictionary.items():
initial_dictionary[field] = _maybe_convert_dict(value)
# Passing None to dict.__init__ raises TypeError, so default to {}.
super(ConfigDict, self).__init__(initial_dictionary or {})
def __setattr__(self, attribute, value):
self[attribute] = _maybe_convert_dict(value)
def __getattr__(self, attribute):
try:
return self[attribute]
except KeyError as e:
raise AttributeError(e)
def __delattr__(self, attribute):
try:
del self[attribute]
except KeyError as e:
raise AttributeError(e)
def __setitem__(self, key, value):
super(ConfigDict, self).__setitem__(key, _maybe_convert_dict(value))
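# A minimal usage sketch for ConfigDict (values are illustrative):
#
#   config = ConfigDict({"model": {"layers": 2}})
#   config.model.layers        # -> 2 (dot syntax)
#   config["model"]["layers"]  # -> 2 (dict syntax)
#   config.model.layers = 3    # nested dicts are converted to ConfigDicts,
#                              # so dot-syntax assignment works at any depth.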
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for config_util.configdict."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
from tf_util import configdict
class ConfigDictTest(absltest.TestCase):
def setUp(self):
super(ConfigDictTest, self).setUp()
self._config = configdict.ConfigDict({
"int": 1,
"float": 2.0,
"bool": True,
"str": "hello",
"nested": {
"int": 3,
},
"double_nested": {
"a": {
"int": 3,
},
"b": {
"float": 4.0,
}
}
})
def testAccess(self):
# Simple types.
self.assertEqual(1, self._config.int)
self.assertEqual(1, self._config["int"])
self.assertEqual(2.0, self._config.float)
self.assertEqual(2.0, self._config["float"])
self.assertTrue(self._config.bool)
self.assertTrue(self._config["bool"])
self.assertEqual("hello", self._config.str)
self.assertEqual("hello", self._config["str"])
# Single nested config.
self.assertEqual(3, self._config.nested.int)
self.assertEqual(3, self._config["nested"].int)
self.assertEqual(3, self._config.nested["int"])
self.assertEqual(3, self._config["nested"]["int"])
# Double nested config.
self.assertEqual(3, self._config["double_nested"].a.int)
self.assertEqual(3, self._config["double_nested"]["a"].int)
self.assertEqual(3, self._config["double_nested"].a["int"])
self.assertEqual(3, self._config["double_nested"]["a"]["int"])
self.assertEqual(4.0, self._config.double_nested.b.float)
self.assertEqual(4.0, self._config.double_nested["b"].float)
self.assertEqual(4.0, self._config.double_nested.b["float"])
self.assertEqual(4.0, self._config.double_nested["b"]["float"])
# Nonexistent parameters.
with self.assertRaises(AttributeError):
_ = self._config.nonexistent
with self.assertRaises(KeyError):
_ = self._config["nonexistent"]
def testSetAttribute(self):
# Overwrite existing simple type.
self._config.int = 40
self.assertEqual(40, self._config.int)
# Overwrite existing nested simple type.
self._config.nested.int = 40
self.assertEqual(40, self._config.nested.int)
# Overwrite existing nested config.
self._config.double_nested.a = {"float": 50.0}
self.assertIsInstance(self._config.double_nested.a, configdict.ConfigDict)
self.assertEqual(50.0, self._config.double_nested.a.float)
self.assertNotIn("int", self._config.double_nested.a)
# Set new simple type.
self._config.int_2 = 10
self.assertEqual(10, self._config.int_2)
# Set new nested simple type.
self._config.nested.int_2 = 20
self.assertEqual(20, self._config.nested.int_2)
# Set new nested config.
self._config.double_nested.c = {"int": 30}
self.assertIsInstance(self._config.double_nested.c, configdict.ConfigDict)
self.assertEqual(30, self._config.double_nested.c.int)
def testSetItem(self):
# Overwrite existing simple type.
self._config["int"] = 40
self.assertEqual(40, self._config.int)
# Overwrite existing nested simple type.
self._config["nested"].int = 40
self.assertEqual(40, self._config.nested.int)
self._config.nested["int"] = 50
self.assertEqual(50, self._config.nested.int)
# Overwrite existing nested config.
self._config.double_nested["a"] = {"float": 50.0}
self.assertIsInstance(self._config.double_nested.a, configdict.ConfigDict)
self.assertEqual(50.0, self._config.double_nested.a.float)
self.assertNotIn("int", self._config.double_nested.a)
# Set new simple type.
self._config["int_2"] = 10
self.assertEqual(10, self._config.int_2)
# Set new nested simple type.
self._config.nested["int_2"] = 20
self.assertEqual(20, self._config.nested.int_2)
self._config.nested["int_3"] = 30
self.assertEqual(30, self._config.nested.int_3)
# Set new nested config.
self._config.double_nested["c"] = {"int": 30}
self.assertIsInstance(self._config.double_nested.c, configdict.ConfigDict)
self.assertEqual(30, self._config.double_nested.c.int)
def testDelete(self):
# Simple types.
self.assertEqual(1, self._config.int)
del self._config.int
with self.assertRaises(AttributeError):
_ = self._config.int
with self.assertRaises(KeyError):
_ = self._config["int"]
self.assertEqual(2.0, self._config["float"])
del self._config["float"]
with self.assertRaises(AttributeError):
_ = self._config.float
with self.assertRaises(KeyError):
_ = self._config["float"]
# Nested config.
self.assertEqual(3, self._config.nested.int)
del self._config.nested
with self.assertRaises(AttributeError):
_ = self._config.nested
with self.assertRaises(KeyError):
_ = self._config["nested"]
if __name__ == "__main__":
absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for training and evaluation using a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def evaluate(estimator, eval_args):
"""Runs evaluation on the latest model checkpoint.
Args:
estimator: Instance of tf.Estimator.
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
input function returning a tuple (features, labels), and eval_steps is the
number of steps for which to evaluate the model (if None, evaluates until
input_fn raises an end-of-input exception).
Returns:
global_step: The global step of the checkpoint evaluated.
values: A dict of metric values from the evaluation. May be empty, e.g. if
the training job has not yet saved a checkpoint or the checkpoint is
deleted by the time the TPU worker initializes.
"""
# Default return values if evaluation fails.
global_step = None
values = {}
latest_checkpoint = estimator.latest_checkpoint()
if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint.
tf.logging.info("No checkpoint in %s, skipping evaluation.",
estimator.model_dir)
return global_step, values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
try:
for eval_name, (input_fn, eval_steps) in eval_args.items():
values[eval_name] = estimator.evaluate(
input_fn, steps=eval_steps, name=eval_name)
if global_step is None:
global_step = values[eval_name].get("global_step")
except tf.errors.NotFoundError:
# Expected under some conditions, e.g. checkpoint is already deleted by the
# trainer process. Increasing RunConfig.keep_checkpoint_max may prevent this
# in some cases.
tf.logging.info("Checkpoint %s no longer exists, skipping evaluation.",
latest_checkpoint)
return global_step, values
def continuous_eval(estimator,
eval_args,
train_steps=None,
timeout_secs=None,
timeout_fn=None):
"""Runs evaluation whenever there's a new checkpoint.
Args:
estimator: Instance of tf.Estimator.
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
input function returning a tuple (features, labels), and eval_steps is the
number of steps for which to evaluate the model (if None, evaluates until
input_fn raises an end-of-input exception).
train_steps: The number of steps the model will train for. This function
will terminate once the model has finished training.
timeout_secs: Number of seconds to wait for new checkpoints. If None, wait
indefinitely.
timeout_fn: Optional function to call after timeout. The iterator will exit
if and only if the function returns True.
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
for _ in tf.contrib.training.checkpoints_iterator(
estimator.model_dir, timeout=timeout_secs, timeout_fn=timeout_fn):
global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps:
break
def continuous_train_and_eval(estimator,
train_input_fn,
eval_args,
local_eval_frequency=None,
train_hooks=None,
train_steps=None):
"""Alternates training and evaluation.
Args:
estimator: Instance of tf.Estimator.
train_input_fn: Input function returning a tuple (features, labels).
eval_args: Dictionary of {eval_name: (input_fn, eval_steps)} where eval_name
is the name of the evaluation set (e.g. "train" or "val"), input_fn is an
input function returning a tuple (features, labels), and eval_steps is the
number of steps for which to evaluate the model (if None, evaluates until
input_fn raises an end-of-input exception).
local_eval_frequency: The number of training steps between evaluations. If
None, trains until train_input_fn raises an end-of-input exception.
train_hooks: List of SessionRunHook subclass instances. Used for callbacks
inside the training call.
train_steps: The total number of steps to train the model for.
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
while True:
# We run evaluation before training in this loop to prevent evaluation from
# being skipped if the process is interrupted.
global_step, values = evaluate(estimator, eval_args)
yield global_step, values
global_step = global_step or 0 # Ensure global_step is not None.
if train_steps and global_step >= train_steps:
break
# Decide how many steps before the next evaluation.
steps = local_eval_frequency
if train_steps:
remaining_steps = train_steps - global_step
steps = min(steps, remaining_steps) if steps else remaining_steps
tf.logging.info("Starting training at global step %d", global_step)
estimator.train(train_input_fn, hooks=train_hooks, steps=steps)
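# A sketch of the expected eval_args structure and a typical driver loop (all
# names are illustrative, not part of this module):
#
#   eval_args = {
#       "val": (val_input_fn, None),          # evaluate until end of input
#       "train": (train_eval_input_fn, 100),  # evaluate for 100 steps
#   }
#   for global_step, values in continuous_train_and_eval(
#       estimator, train_input_fn, eval_args, train_steps=10000):
#     tf.logging.info("Step %s: %s", global_step, values)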
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for getting and setting values in tf.Example protocol buffers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def get_feature(ex, name, kind=None, strict=True):
"""Gets a feature value from a tf.train.Example.
Args:
ex: A tf.train.Example.
name: Name of the feature to look up.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
strict: Whether to raise a KeyError if there is no such feature.
Returns:
A numpy array containing the values of the specified feature.
Raises:
KeyError: If there is no feature with the specified name.
TypeError: If the feature has a different type than the one specified.
"""
if name not in ex.features.feature:
if strict:
raise KeyError(name)
return np.array([])
inferred_kind = ex.features.feature[name].WhichOneof("kind")
if not inferred_kind:
return np.array([]) # Feature exists, but it's empty.
if kind and kind != inferred_kind:
raise TypeError("Requested {}, but Feature has {}".format(
kind, inferred_kind))
return np.array(getattr(ex.features.feature[name], inferred_kind).value)
def get_bytes_feature(ex, name, strict=True):
"""Gets the value of a bytes feature from a tf.train.Example."""
return get_feature(ex, name, "bytes_list", strict)
def get_float_feature(ex, name, strict=True):
"""Gets the value of a float feature from a tf.train.Example."""
return get_feature(ex, name, "float_list", strict)
def get_int64_feature(ex, name, strict=True):
"""Gets the value of an int64 feature from a tf.train.Example."""
return get_feature(ex, name, "int64_list", strict)
def _infer_kind(value):
  """Infers the tf.train.Feature kind from a value."""
  if np.issubdtype(type(value[0]), np.integer):
    return "int64_list"
  try:
    float(value[0])
    return "float_list"
  except (TypeError, ValueError):
    # Not parseable as a number. In Python 3, float() raises TypeError for
    # bytes input, so both error types indicate a bytes feature.
    return "bytes_list"
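# Illustrative examples of the inference rules above (shown as comments so
# they do not run at import time):
#   _infer_kind([1, 2])     -> "int64_list"
#   _infer_kind([1.5])      -> "float_list"  (and so would ["1.5"])
#   _infer_kind(["a", "b"]) -> "bytes_list"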
def set_feature(ex,
name,
value,
kind=None,
allow_overwrite=False,
bytes_encoding="latin-1"):
"""Sets a feature value in a tf.train.Example.
Args:
ex: A tf.train.Example.
name: Name of the feature to set.
value: Feature value to set. Must be a sequence.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
allow_overwrite: Whether to overwrite the existing value of the feature.
bytes_encoding: Codec for encoding strings when kind = 'bytes_list'.
Raises:
ValueError: If `allow_overwrite` is False and the feature already exists, or
if `kind` is unrecognized.
"""
if name in ex.features.feature:
if allow_overwrite:
del ex.features.feature[name]
else:
raise ValueError(
"Attempting to overwrite feature with name: {}. "
"Set allow_overwrite=True if this is desired.".format(name))
if not kind:
kind = _infer_kind(value)
if kind == "bytes_list":
value = [str(v).encode(bytes_encoding) for v in value]
elif kind == "float_list":
value = [float(v) for v in value]
elif kind == "int64_list":
value = [int(v) for v in value]
else:
raise ValueError("Unrecognized kind: {}".format(kind))
getattr(ex.features.feature[name], kind).value.extend(value)
def set_float_feature(ex, name, value, allow_overwrite=False):
"""Sets the value of a float feature in a tf.train.Example."""
set_feature(ex, name, value, "float_list", allow_overwrite)
def set_bytes_feature(ex,
name,
value,
allow_overwrite=False,
bytes_encoding="latin-1"):
"""Sets the value of a bytes feature in a tf.train.Example."""
set_feature(ex, name, value, "bytes_list", allow_overwrite, bytes_encoding)
def set_int64_feature(ex, name, value, allow_overwrite=False):
"""Sets the value of an int64 feature in a tf.train.Example."""
set_feature(ex, name, value, "int64_list", allow_overwrite)
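# A short round-trip sketch (illustrative; assumes the caller also does
# `import tensorflow as tf`, since this module itself only imports numpy):
#
#   ex = tf.train.Example()
#   set_float_feature(ex, "flux", [1.0, 2.5])
#   assert list(get_float_feature(ex, "flux")) == [1.0, 2.5]
#   set_bytes_feature(ex, "label", ["PC"])  # Encoded with latin-1 by default.
#   assert get_bytes_feature(ex, "label")[0] == b"PC"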
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for example_util.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tf_util import example_util
class ExampleUtilTest(tf.test.TestCase):
def test_get_feature(self):
# Create Example.
bytes_list = tf.train.BytesList(
value=[v.encode("latin-1") for v in ["a", "b", "c"]])
float_list = tf.train.FloatList(value=[1.0, 2.0, 3.0])
int64_list = tf.train.Int64List(value=[11, 22, 33])
ex = tf.train.Example(
features=tf.train.Features(
feature={
"a_bytes": tf.train.Feature(bytes_list=bytes_list),
"b_float": tf.train.Feature(float_list=float_list),
"c_int64": tf.train.Feature(int64_list=int64_list),
"d_empty": tf.train.Feature(),
}))
# Get bytes feature.
np.testing.assert_array_equal(
example_util.get_feature(ex, "a_bytes").astype(str), ["a", "b", "c"])
np.testing.assert_array_equal(
example_util.get_feature(ex, "a_bytes", "bytes_list").astype(str),
["a", "b", "c"])
np.testing.assert_array_equal(
example_util.get_bytes_feature(ex, "a_bytes").astype(str),
["a", "b", "c"])
with self.assertRaises(TypeError):
example_util.get_feature(ex, "a_bytes", "float_list")
with self.assertRaises(TypeError):
example_util.get_float_feature(ex, "a_bytes")
with self.assertRaises(TypeError):
example_util.get_int64_feature(ex, "a_bytes")
# Get float feature.
np.testing.assert_array_almost_equal(
example_util.get_feature(ex, "b_float"), [1.0, 2.0, 3.0])
np.testing.assert_array_almost_equal(
example_util.get_feature(ex, "b_float", "float_list"), [1.0, 2.0, 3.0])
np.testing.assert_array_almost_equal(
example_util.get_float_feature(ex, "b_float"), [1.0, 2.0, 3.0])
with self.assertRaises(TypeError):
example_util.get_feature(ex, "b_float", "int64_list")
with self.assertRaises(TypeError):
example_util.get_bytes_feature(ex, "b_float")
with self.assertRaises(TypeError):
example_util.get_int64_feature(ex, "b_float")
# Get int64 feature.
np.testing.assert_array_equal(
example_util.get_feature(ex, "c_int64"), [11, 22, 33])
np.testing.assert_array_equal(
example_util.get_feature(ex, "c_int64", "int64_list"), [11, 22, 33])
np.testing.assert_array_equal(
example_util.get_int64_feature(ex, "c_int64"), [11, 22, 33])
with self.assertRaises(TypeError):
example_util.get_feature(ex, "c_int64", "bytes_list")
with self.assertRaises(TypeError):
example_util.get_bytes_feature(ex, "c_int64")
with self.assertRaises(TypeError):
example_util.get_float_feature(ex, "c_int64")
# Get empty feature.
np.testing.assert_array_equal(example_util.get_feature(ex, "d_empty"), [])
np.testing.assert_array_equal(
example_util.get_feature(ex, "d_empty", "float_list"), [])
np.testing.assert_array_equal(
example_util.get_bytes_feature(ex, "d_empty"), [])
np.testing.assert_array_equal(
example_util.get_float_feature(ex, "d_empty"), [])
np.testing.assert_array_equal(
example_util.get_int64_feature(ex, "d_empty"), [])
# Get nonexistent feature.
with self.assertRaises(KeyError):
example_util.get_feature(ex, "nonexistent")
with self.assertRaises(KeyError):
example_util.get_feature(ex, "nonexistent", "bytes_list")
with self.assertRaises(KeyError):
example_util.get_bytes_feature(ex, "nonexistent")
with self.assertRaises(KeyError):
example_util.get_float_feature(ex, "nonexistent")
with self.assertRaises(KeyError):
example_util.get_int64_feature(ex, "nonexistent")
np.testing.assert_array_equal(
example_util.get_feature(ex, "nonexistent", strict=False), [])
np.testing.assert_array_equal(
example_util.get_bytes_feature(ex, "nonexistent", strict=False), [])
np.testing.assert_array_equal(
example_util.get_float_feature(ex, "nonexistent", strict=False), [])
np.testing.assert_array_equal(
example_util.get_int64_feature(ex, "nonexistent", strict=False), [])
def test_set_feature(self):
ex = tf.train.Example()
# Set bytes features.
example_util.set_feature(ex, "a1_bytes", ["a", "b"])
example_util.set_feature(ex, "a2_bytes", ["A", "B"], kind="bytes_list")
example_util.set_bytes_feature(ex, "a3_bytes", ["x", "y"])
np.testing.assert_array_equal(
np.array(ex.features.feature["a1_bytes"].bytes_list.value).astype(str),
["a", "b"])
np.testing.assert_array_equal(
np.array(ex.features.feature["a2_bytes"].bytes_list.value).astype(str),
["A", "B"])
np.testing.assert_array_equal(
np.array(ex.features.feature["a3_bytes"].bytes_list.value).astype(str),
["x", "y"])
with self.assertRaises(ValueError):
example_util.set_feature(ex, "a3_bytes", ["xxx"]) # Duplicate.
# Set float features.
example_util.set_feature(ex, "b1_float", [1.0, 2.0])
example_util.set_feature(ex, "b2_float", [10.0, 20.0], kind="float_list")
example_util.set_float_feature(ex, "b3_float", [88.0, 99.0])
np.testing.assert_array_almost_equal(
ex.features.feature["b1_float"].float_list.value, [1.0, 2.0])
np.testing.assert_array_almost_equal(
ex.features.feature["b2_float"].float_list.value, [10.0, 20.0])
np.testing.assert_array_almost_equal(
ex.features.feature["b3_float"].float_list.value, [88.0, 99.0])
with self.assertRaises(ValueError):
example_util.set_feature(ex, "b3_float", [1234.0]) # Duplicate.
# Set int64 features.
example_util.set_feature(ex, "c1_int64", [1, 2, 3])
example_util.set_feature(ex, "c2_int64", [11, 22, 33], kind="int64_list")
example_util.set_int64_feature(ex, "c3_int64", [88, 99])
np.testing.assert_array_equal(
ex.features.feature["c1_int64"].int64_list.value, [1, 2, 3])
np.testing.assert_array_equal(
ex.features.feature["c2_int64"].int64_list.value, [11, 22, 33])
np.testing.assert_array_equal(
ex.features.feature["c3_int64"].int64_list.value, [88, 99])
with self.assertRaises(ValueError):
example_util.set_feature(ex, "c3_int64", [1234]) # Duplicate.
# Overwrite features.
example_util.set_feature(ex, "a3_bytes", ["xxx"], allow_overwrite=True)
np.testing.assert_array_equal(
np.array(ex.features.feature["a3_bytes"].bytes_list.value).astype(str),
["xxx"])
example_util.set_feature(ex, "b3_float", [1234.0], allow_overwrite=True)
np.testing.assert_array_almost_equal(
ex.features.feature["b3_float"].float_list.value, [1234.0])
example_util.set_feature(ex, "c3_int64", [1234], allow_overwrite=True)
np.testing.assert_array_equal(
ex.features.feature["c3_int64"].int64_list.value, [1234])
if __name__ == "__main__":
tf.test.main()
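# This test can be run directly, e.g. `python example_util_test.py`, assuming
# TensorFlow is installed and the tf_util package is on the Python path.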
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # MIT
exports_files(["LICENSE"])
py_library(
name = "kepler_spline",
srcs = ["kepler_spline.py"],
srcs_version = "PY2AND3",
deps = ["//third_party/robust_mean"],
)
py_test(
name = "kepler_spline_test",
size = "small",
srcs = ["kepler_spline_test.py"],
srcs_version = "PY2AND3",
deps = [":kepler_spline"],
)
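# To run the test with Bazel (the package path is a placeholder; substitute
# the directory containing this BUILD file):
#   bazel test //path/to/package:kepler_spline_test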