Unverified Commit 6571d16d authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #3544 from cshallue/master

Add AstroNet to tensorflow/models
parents 92083555 6c891bc3
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_H_
#include <algorithm>
#include <iterator>
#include <vector>
namespace astronet {
// Computes the median value in the range [first, last).
//
// After calling this function, the elements in [first, last) will be rearranged
// such that, if middle = first + distance(first, last) / 2:
// 1. The element pointed at by middle is changed to whatever element would
// occur in that position if [first, last) was sorted.
// 2. All of the elements before this new middle element are less than or
// equal to the elements after the new nth element.
template <class RandomIt>
typename std::iterator_traits<RandomIt>::value_type InPlaceMedian(
RandomIt first, RandomIt last) {
// If n is odd, 'middle' points to the middle element. If n is even, 'middle'
// points to the upper middle element.
const auto n = std::distance(first, last);
const auto middle = first + (n / 2);
// Partially sort such that 'middle' in its place.
std::nth_element(first, middle, last);
// n is odd: the median is simply the middle element.
if (n & 1) {
return *middle;
}
// The maximum value lower than *middle is located in [first, middle) as a
// a post condition of nth_element.
const auto lower_middle = std::max_element(first, middle);
// Prevent overflow. We know that *lower_middle <= *middle. If both are on
// opposite sides of zero, the sum won't overflow, otherwise the difference
// won't overflow.
if (*lower_middle <= 0 && *middle >= 0) {
return (*lower_middle + *middle) / 2;
}
return *lower_middle + (*middle - *lower_middle) / 2;
}
// Computes the median value in the range [first, last) without modifying the
// input.
template <class ForwardIterator>
typename std::iterator_traits<ForwardIterator>::value_type Median(
ForwardIterator first, ForwardIterator last) {
std::vector<typename std::iterator_traits<ForwardIterator>::value_type>
values(first, last);
return InPlaceMedian(values.begin(), values.end());
}
} // namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_H_
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/median_filter.h"
#include "absl/strings/substitute.h"
#include "light_curve_util/cc/median.h"
using absl::Substitute;
using std::min;
using std::vector;
namespace astronet {
bool MedianFilter(const vector<double>& x, const vector<double>& y,
int num_bins, double bin_width, double x_min, double x_max,
vector<double>* result, std::string* error) {
const std::size_t x_size = x.size();
if (x_size < 2) {
*error = Substitute("x.size() must be greater than 1. Got: $0", x_size);
return false;
}
if (x_size != y.size()) {
*error = Substitute("x.size() (got: $0) must equal y.size() (got: $1)",
x_size, y.size());
return false;
}
const double x_first = x[0];
const double x_last = x[x_size - 1];
if (x_first >= x_last) {
*error = Substitute(
"The first element of x (got: $0) must be less than the last "
"element (got: $1). Either x is not sorted or all elements are "
"equal.",
x_first, x_last);
return false;
}
if (x_min >= x_max) {
*error = Substitute("x_min (got: $0) must be less than x_max (got: $1)",
x_min, x_max);
return false;
}
if (x_min > x_last) {
*error = Substitute(
"x_min (got: $0) must be less than or equal to the largest value of x "
"(got: $1)",
x_min, x_last);
return false;
}
if (bin_width <= 0) {
*error = Substitute("bin_width must be positive. Got: $0", bin_width);
return false;
}
if (bin_width >= x_max - x_min) {
*error = Substitute(
"bin_width (got: $0) must be less than x_max - x_min (got: $1)",
bin_width, x_max - x_min);
return false;
}
if (num_bins < 2) {
*error = Substitute("num_bins must be greater than 1. Got: $0", num_bins);
return false;
}
result->resize(num_bins);
// Compute the spacing between midpoints of adjacent bins.
double bin_spacing = (x_max - x_min - bin_width) / (num_bins - 1);
// Create a vector to hold the values of the current bin on each iteration.
// Its initial size is twice the expected number of points per bin if x
// values are uniformly spaced. It will be expanded as necessary.
int points_per_bin =
1 + static_cast<int>(x_size * min(1.0, bin_width / (x_last - x_first)));
vector<double> bin_values(2 * points_per_bin);
// Create a vector to hold the indices of any empty bins.
vector<int> empty_bins;
// Find the first element of x >= x_min. This loop is guaranteed to produce
// a valid index because we know that x_min <= x_last.
int x_start = 0;
while (x[x_start] < x_min) ++x_start;
// The bin at index i is the median of all elements y[j] such that
// bin_min <= x[j] < bin_max, where bin_min and bin_max are the endpoints of
// bin i.
double bin_min = x_min; // Left endpoint of the current bin.
double bin_max = x_min + bin_width; // Right endpoint of the current bin.
int j_start = x_start; // Index of the first element in the current bin.
int j = x_start; // Index of the current element in the current bin.
for (int i = 0; i < num_bins; ++i) {
// Move j_start to the first index of x >= bin_min.
while (j_start < x_size && x[j_start] < bin_min) ++j_start;
// Accumulate values y[j] such that bin_min <= x[j] < bin_max. After this
// loop, j is the exclusive end index of the current bin.
j = j_start;
while (j < x_size && x[j] < bin_max) {
if (j - j_start >= bin_values.size()) {
bin_values.resize(2 * bin_values.size()); // Expand if necessary.
}
bin_values[j - j_start] = y[j];
++j;
}
int n = j - j_start; // Number of points in the bin.
if (n == 0) {
empty_bins.push_back(i); // Empty bin.
} else {
// Compute and insert the median bin value.
(*result)[i] = InPlaceMedian(bin_values.begin(), bin_values.begin() + n);
}
// Advance the bin.
bin_min += bin_spacing;
bin_max += bin_spacing;
}
// For empty bins, fall back to the median y value between x_min and x_max.
if (!empty_bins.empty()) {
double median = Median(y.begin() + x_start, y.begin() + j);
for (int i : empty_bins) {
(*result)[i] = median;
}
}
return true;
}
} // namespace astronet
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_FILTER_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_FILTER_H_
#include <iostream>
#include <string>
#include <vector>
namespace astronet {
// Computes the median y-value in uniform intervals (bins) along the x-axis.
//
// The interval [x_min, x_max) is divided into num_bins uniformly spaced
// intervals of width bin_width. The value computed for each bin is the median
// of all y-values whose corresponding x-value is in the interval.
//
// NOTE: x must be sorted in ascending order or the results will be incorrect.
//
// Input args:
// x: Vector of x-coordinates sorted in ascending order. Must have at least 2
// elements, and all elements cannot be the same value.
// y: Vector of y-coordinates with the same size as x.
// num_bins: The number of intervals to divide the x-axis into. Must be at
// least 2.
// bin_width: The width of each bin on the x-axis. Must be positive, and less
// than x_max - x_min.
// x_min: The inclusive leftmost value to consider on the x-axis. Must be less
// than or equal to the largest value of x.
// x_max: The exclusive rightmost value to consider on the x-axis. Must be
// greater than x_min.
//
// Output args:
// result: Vector of size num_bins containing the median y-values of uniformly
// spaced bins on the x-axis.
// error: String indicating an error (e.g. an invalid argument).
//
// Returns:
// true if the algorithm succeeded. If false, see "error".
bool MedianFilter(const std::vector<double>& x, const std::vector<double>& y,
int num_bins, double bin_width, double x_min, double x_max,
std::vector<double>* result, std::string* error);
} // namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_MEDIAN_FILTER_H_
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/median_filter.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "light_curve_util/cc/test_util.h"
using std::vector;
using testing::Pointwise;
namespace astronet {
namespace {
TEST(MedianFilter, Errors) {
vector<double> x;
vector<double> y;
vector<double> result;
std::string error;
// x size less than 2.
x = {1};
y = {2};
EXPECT_FALSE(MedianFilter(x, y, 2, 1, 0, 2, &result, &error));
EXPECT_EQ(error, "x.size() must be greater than 1. Got: 1");
// x and y not the same size.
x = {1, 2};
y = {4, 5, 6};
EXPECT_FALSE(MedianFilter(x, y, 2, 1, 0, 2, &result, &error));
EXPECT_EQ(error, "x.size() (got: 2) must equal y.size() (got: 3)");
// x out of order.
x = {2, 0, 1};
EXPECT_FALSE(MedianFilter(x, y, 2, 1, 0, 2, &result, &error));
EXPECT_EQ(error,
"The first element of x (got: 2) must be less than the last element"
" (got: 1). Either x is not sorted or all elements are equal.");
// x all equal.
x = {1, 1, 1};
EXPECT_FALSE(MedianFilter(x, y, 2, 1, 0, 2, &result, &error));
EXPECT_EQ(error,
"The first element of x (got: 1) must be less than the last element"
" (got: 1). Either x is not sorted or all elements are equal.");
// x_min not less than x_max
x = {1, 2, 3};
EXPECT_FALSE(MedianFilter(x, y, 2, 1, -1, -1, &result, &error));
EXPECT_EQ(error, "x_min (got: -1) must be less than x_max (got: -1)");
// x_min greater than the last element of x.
x = {1, 2, 3};
EXPECT_FALSE(MedianFilter(x, y, 2, 0.25, 3.5, 4, &result, &error));
EXPECT_EQ(error,
"x_min (got: 3.5) must be less than or equal to the largest value "
"of x (got: 3)");
// bin_width nonpositive.
x = {1, 2, 3};
EXPECT_FALSE(MedianFilter(x, y, 2, 0, 1, 3, &result, &error));
EXPECT_EQ(error, "bin_width must be positive. Got: 0");
// bin_width greater than or equal to x_max - x_min.
x = {1, 2, 3};
EXPECT_FALSE(MedianFilter(x, y, 2, 1, 1.5, 2.5, &result, &error));
EXPECT_EQ(error,
"bin_width (got: 1) must be less than x_max - x_min (got: 1)");
// num_bins less than 2.
x = {1, 2, 3};
EXPECT_FALSE(MedianFilter(x, y, 1, 1, 0, 2, &result, &error));
EXPECT_EQ(error, "num_bins must be greater than 1. Got: 1");
}
TEST(MedianFilter, BucketBoundaries) {
vector<double> x = {-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6};
vector<double> y = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13};
vector<double> result;
std::string error;
EXPECT_TRUE(MedianFilter(x, y, 5, 2, -5, 5, &result, &error));
EXPECT_TRUE(error.empty());
vector<double> expected = {2.5, 4.5, 6.5, 8.5, 10.5};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(MedianFilter, MultiSizeBins) {
// Construct bins with size 0, 1, 2, 3, 4, 5, 10, respectively.
vector<double> x = {1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5,
5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6};
vector<double> y = {0, -1, 1, 4, 5, 6, 2, 2, 4, 4, 1, 1, 1,
1, -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
vector<double> result;
std::string error;
EXPECT_TRUE(MedianFilter(x, y, 7, 1, 0, 7, &result, &error));
EXPECT_TRUE(error.empty());
// expected[0] = 3 is the median of y.
vector<double> expected = {3, 0, 0, 5, 3, 1, 5.5};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(MedianFilter, EmptyBins) {
vector<double> x = {-1, 0, 1};
vector<double> y = {2, 3, 1};
vector<double> result;
std::string error;
EXPECT_TRUE(MedianFilter(x, y, 5, 1, -5, 5, &result, &error));
EXPECT_TRUE(error.empty());
// The center bin is the only nonempty bin.
vector<double> expected = {2, 2, 3, 2, 2};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(MedianFilter, WideBins) {
vector<double> x = {-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6};
vector<double> y = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13};
vector<double> result;
std::string error;
EXPECT_TRUE(MedianFilter(x, y, 7, 5, -10, 10, &result, &error));
EXPECT_TRUE(error.empty());
vector<double> expected = {1, 2.5, 4, 7, 9, 11.5, 12.5};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(MedianFilter, NarrowBins) {
vector<double> x = {-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6};
vector<double> y = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13};
vector<double> result;
std::string error;
EXPECT_TRUE(MedianFilter(x, y, 9, 0.5, -2.25, 2.25, &result, &error));
EXPECT_TRUE(error.empty());
// Bins 1, 3, 5, 7 are empty.
vector<double> expected = {5, 7, 6, 7, 7, 7, 8, 7, 9};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
} // namespace
} // namespace astronet
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/median.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
using testing::ElementsAreArray;
namespace astronet {
namespace {
TEST(InPlaceMedian, SingleFloat) {
std::vector<double> v = {1.0};
EXPECT_FLOAT_EQ(1.0, InPlaceMedian(v.begin(), v.end()));
EXPECT_THAT(v, ElementsAreArray({1.0}));
}
TEST(InPlaceMedian, TwoInts) {
std::vector<int> v = {3, 2};
// Note that integer division is used, so the median is (2 + 3) / 2 = 2.
EXPECT_EQ(2, InPlaceMedian(v.begin(), v.end()));
EXPECT_THAT(v, ElementsAreArray({2, 3}));
}
TEST(InPlaceMedian, OddElements) {
std::vector<double> v = {1.0, 0.0, 2.0};
EXPECT_FLOAT_EQ(1.0, InPlaceMedian(v.begin(), v.end()));
EXPECT_THAT(v, ElementsAreArray({0.0, 1.0, 2.0}));
}
TEST(InPlaceMedian, EvenElements) {
std::vector<double> v = {1.0, 0.0, 4.0, 3.0};
EXPECT_FLOAT_EQ(2.0, InPlaceMedian(v.begin(), v.end()));
EXPECT_FLOAT_EQ(3.0, v[2]);
EXPECT_FLOAT_EQ(4.0, v[3]);
}
TEST(InPlaceMedian, SubRanges) {
std::vector<double> v = {1.0, 4.0, 0.0, 3.0, -1.0, 6.0, 9.0, -10.0};
// [0, 1)
EXPECT_FLOAT_EQ(1.0, InPlaceMedian(v.begin(), v.begin() + 1));
EXPECT_FLOAT_EQ(1.0, v[0]);
// [1, 4)
EXPECT_FLOAT_EQ(3.0, InPlaceMedian(v.begin() + 1, v.begin() + 4));
EXPECT_FLOAT_EQ(0.0, v[1]);
EXPECT_FLOAT_EQ(3.0, v[2]);
EXPECT_FLOAT_EQ(4.0, v[3]);
// [4, 8)
EXPECT_FLOAT_EQ(2.5, InPlaceMedian(v.begin() + 4, v.end()));
EXPECT_FLOAT_EQ(6.0, v[6]);
EXPECT_FLOAT_EQ(9.0, v[7]);
}
TEST(Median, SingleFloat) {
std::vector<double> v = {-5.0};
EXPECT_FLOAT_EQ(-5.0, Median(v.begin(), v.end()));
EXPECT_THAT(v, ElementsAreArray({-5.0}));
}
TEST(Median, TwoInts) {
std::vector<int> v = {3, 2};
// Note that integer division is used, so the median is (2 + 3) / 2 = 2.
EXPECT_EQ(2, Median(v.begin(), v.end()));
EXPECT_THAT(v, ElementsAreArray({3, 2})); // Unmodified.
}
TEST(Median, SubRanges) {
std::vector<double> v = {1.0, 4.0, 0.0, 3.0, -1.0, 6.0, 9.0, -10.0};
// [0, 1)
EXPECT_FLOAT_EQ(1.0, Median(v.begin(), v.begin() + 1));
EXPECT_THAT(v, ElementsAreArray({1.0, 4.0, 0.0, 3.0, -1.0, 6.0, 9.0, -10.0}));
// [1, 4)
EXPECT_FLOAT_EQ(3.0, Median(v.begin() + 1, v.begin() + 4));
EXPECT_THAT(v, ElementsAreArray({1.0, 4.0, 0.0, 3.0, -1.0, 6.0, 9.0, -10.0}));
// [4, 8)
EXPECT_FLOAT_EQ(2.5, Median(v.begin() + 4, v.end()));
EXPECT_THAT(v, ElementsAreArray({1.0, 4.0, 0.0, 3.0, -1.0, 6.0, 9.0, -10.0}));
}
} // namespace
} // namespace astronet
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/normalize.h"
#include <algorithm>
#include "absl/strings/substitute.h"
#include "light_curve_util/cc/median.h"
using absl::Substitute;
using std::vector;
namespace astronet {
bool NormalizeMedianAndMinimum(const vector<double>& x, vector<double>* result,
std::string* error) {
if (x.size() < 2) {
*error = Substitute("x.size() must be greater than 1. Got: $0", x.size());
return false;
}
// Find the median of x.
vector<double> x_copy(x);
const double median = InPlaceMedian(x_copy.begin(), x_copy.end());
// Find the min element of x. As a post condition of InPlaceMedian, we only
// need to search elements lower than the middle.
const auto x_copy_middle = x_copy.begin() + x_copy.size() / 2;
const auto minimum = std::min_element(x_copy.begin(), x_copy_middle);
// Guaranteed to be positive, unless the median exactly equals the minimum.
double normalizer = median - *minimum;
if (normalizer <= 0) {
*error = Substitute("Minimum and median have the same value: $0", median);
return false;
}
result->resize(x.size());
std::transform(
x.begin(), x.end(), result->begin(),
[median, normalizer](double v) { return (v - median) / normalizer; });
return true;
}
} // namespace astronet
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_NORMALIZE_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_NORMALIZE_H_
#include <iostream>
#include <string>
#include <vector>
namespace astronet {
// Normalizes a vector with an affine transformation such that its median is
// mapped to 0 and its minimum is mapped to -1.
//
// Input args:
// x: Vector to normalize. Must have at least 2 elements and all elements
// cannot be the same value.
//
// Output args:
// result: Output normalized vector. Can be a pointer to the input vector to
// perform the normalization in-place.
// error: String indicating an error (e.g. an invalid argument).
//
// Returns:
// true if the algorithm succeeded. If false, see "error".
bool NormalizeMedianAndMinimum(const std::vector<double>& x,
std::vector<double>* result, std::string* error);
} // namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_NORMALIZE_H_
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/normalize.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "light_curve_util/cc/test_util.h"
using std::vector;
using testing::Pointwise;
namespace astronet {
namespace {
TEST(NormalizeMedianAndMinimum, Error) {
vector<double> x = {-1, -1, -1, -1, -1, -1};
vector<double> result;
std::string error;
EXPECT_FALSE(NormalizeMedianAndMinimum(x, &result, &error));
EXPECT_EQ(error, "Minimum and median have the same value: -1");
}
TEST(NormalizeMedianAndMinimum, TooFewElements) {
vector<double> x = {1};
vector<double> result;
std::string error;
EXPECT_FALSE(NormalizeMedianAndMinimum(x, &result, &error));
EXPECT_EQ(error, "x.size() must be greater than 1. Got: 1");
}
TEST(NormalizeMedianAndMinimum, NonNegative) {
vector<double> x = {0, 1, 2, 3, 4, 5, 6, 7, 8}; // Median 4, Min 0.
vector<double> result;
std::string error;
EXPECT_TRUE(NormalizeMedianAndMinimum(x, &result, &error));
EXPECT_TRUE(error.empty());
vector<double> expected = {-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(NormalizeMedianAndMinimum, NonPositive) {
vector<double> x = {0, -1, -2, -3, -4, -5, -6, -7, -8}; // Median -4, Min -8.
vector<double> result;
std::string error;
EXPECT_TRUE(NormalizeMedianAndMinimum(x, &result, &error));
EXPECT_TRUE(error.empty());
vector<double> expected = {1, 0.75, 0.5, 0.25, 0, -0.25, -0.5, -0.75, -1};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(NormalizeMedianAndMinimum, PositiveNegative) {
vector<double> x = {-4, -3, -2, -1, 0, 1, 2, 3, 4}; // Median 0, Min -4.
vector<double> result;
std::string error;
EXPECT_TRUE(NormalizeMedianAndMinimum(x, &result, &error));
EXPECT_TRUE(error.empty());
vector<double> expected = {-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(NormalizeMedianAndMinimum, InPlace) {
vector<double> x = {-4, -3, -2, -1, 0, 1, 2, 3, 4}; // Median 0, Min -4.
std::string error;
EXPECT_TRUE(NormalizeMedianAndMinimum(x, &x, &error));
EXPECT_TRUE(error.empty());
vector<double> expected = {-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1};
EXPECT_THAT(x, Pointwise(DoubleNear(), expected));
}
} // namespace
} // namespace astronet
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/phase_fold.h"
#include <math.h>
#include <algorithm>
#include <numeric>
#include "absl/strings/substitute.h"
using absl::Substitute;
using std::vector;
namespace astronet {
void PhaseFoldTime(const vector<double>& time, double period, double t0,
vector<double>* result) {
result->resize(time.size());
double half_period = period / 2;
// Compute a constant offset to subtract from each time value before taking
// the remainder modulo the period. This offset ensures that t0 will be
// centered at +/- period / 2 after the remainder operation.
double offset = t0 - half_period;
std::transform(time.begin(), time.end(), result->begin(),
[period, offset, half_period](double t) {
// If t > offset, then rem is in [0, period) with t0 at
// period / 2. Otherwise rem is in (-period, 0] with t0 at
// -period / 2. We shift appropriately to return a value in
// [-period / 2, period / 2) with t0 centered at 0.
double rem = fmod(t - offset, period);
return rem < 0 ? rem + half_period : rem - half_period;
});
}
// Accept time as a value, because we will phase fold in place.
bool PhaseFoldAndSortLightCurve(vector<double> time, const vector<double>& flux,
double period, double t0,
vector<double>* folded_time,
vector<double>* folded_flux,
std::string* error) {
const std::size_t length = time.size();
if (flux.size() != length) {
*error =
Substitute("time.size() (got: $0) must equal flux.size() (got: $1)",
length, flux.size());
return false;
}
// Phase fold time in place.
PhaseFoldTime(time, period, t0, &time);
// Sort the indices of time by ascending value.
vector<std::size_t> sorted_i(length);
std::iota(sorted_i.begin(), sorted_i.end(), 0);
std::sort(
sorted_i.begin(), sorted_i.end(),
[&time](std::size_t i, std::size_t j) { return time[i] < time[j]; });
// Copy phase folded and sorted time and flux into the output.
folded_time->resize(length);
folded_flux->resize(length);
for (int i = 0; i < length; ++i) {
(*folded_time)[i] = time[sorted_i[i]];
(*folded_flux)[i] = flux[sorted_i[i]];
}
return true;
}
} // namespace astronet
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_PHASE_FOLD_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_PHASE_FOLD_H_
#include <iostream>
#include <string>
#include <vector>
namespace astronet {
// Creates a phase-folded time vector.
//
// Specifically, result[i] is the unique number in [-period / 2, period / 2)
// such that result[i] = time[i] - t0 + k_i * period, for some integer k_i.
//
// Input args:
// time: Input vector of time values.
// period: The period to fold over.
// t0: The center of the resulting folded vector; this value is mapped to 0.
//
// Output args:
// result: Output phase folded vector. Can be a pointer to the input vector to
// perform the phase-folding in-place.
void PhaseFoldTime(const std::vector<double>& time, double period, double t0,
std::vector<double>* result);
// Phase folds a light curve and sorts by ascending phase-folded time.
//
// See the comment on PhaseFoldTime for a description of the phase folding
// technique for the time values. The flux values are not modified; they are
// simply permuted to correspond to the sorted phase folded time values.
//
// Input args:
// time: Vector of time values.
// flux: Vector of flux values with the same size as time.
// period: The period to fold over.
// t0: The center of the resulting folded vector; this value is mapped to 0.
//
// Output args:
// folded_time: Output phase folded time values, sorted in ascending order.
// folded_flux: Output flux values corresponding pointwise to folded_time.
// error: String indicating an error (e.g. time and flux are different sizes).
//
// Returns:
// true if the algorithm succeeded. If false, see "error".
bool PhaseFoldAndSortLightCurve(std::vector<double> time,
const std::vector<double>& flux, double period,
double t0, std::vector<double>* folded_time,
std::vector<double>* folded_flux,
std::string* error);
} // namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_PHASE_FOLD_H_
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/phase_fold.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "light_curve_util/cc/test_util.h"
using std::vector;
using testing::Pointwise;
namespace astronet {
namespace {
TEST(PhaseFoldTime, Empty) {
vector<double> time = {};
vector<double> result;
PhaseFoldTime(time, 1, 0.45, &result);
EXPECT_TRUE(result.empty());
}
TEST(PhaseFoldTime, Simple) {
vector<double> time = range(0, 2, 0.1);
vector<double> result;
PhaseFoldTime(time, 1, 0.45, &result);
vector<double> expected = {
-0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45,
-0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45,
};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(PhaseFoldTime, LargeT0) {
vector<double> time = range(0, 2, 0.1);
vector<double> result;
PhaseFoldTime(time, 1, 1.25, &result);
vector<double> expected = {
-0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35,
-0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45, -0.35,
};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(PhaseFoldTime, NegativeT0) {
vector<double> time = range(0, 2, 0.1);
vector<double> result;
PhaseFoldTime(time, 1, -1.65, &result);
vector<double> expected = {
-0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45,
-0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45,
};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(PhaseFoldTime, NegativeTime) {
vector<double> time = range(-3, -1, 0.1);
vector<double> result;
PhaseFoldTime(time, 1, 0.55, &result);
vector<double> expected = {
0.45, -0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35,
0.45, -0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35,
};
EXPECT_THAT(result, Pointwise(DoubleNear(), expected));
}
TEST(PhaseFoldTime, InPlace) {
vector<double> time = range(0, 2, 0.1);
PhaseFoldTime(time, 0.5, 1.15, &time);
vector<double> expected = {
-0.15, -0.05, 0.05, 0.15, -0.25, -0.15, -0.05, 0.05, 0.15, -0.25,
-0.15, -0.05, 0.05, 0.15, -0.25, -0.15, -0.05, 0.05, 0.15, -0.25,
};
EXPECT_THAT(time, Pointwise(DoubleNear(), time));
}
TEST(PhaseFoldAndSortLightCurve, Error) {
vector<double> time = {1.0, 2.0, 3.0};
vector<double> flux = {7.5, 8.6};
vector<double> folded_time;
vector<double> folded_flux;
std::string error;
EXPECT_FALSE(PhaseFoldAndSortLightCurve(time, flux, 1.0, 0.5, &folded_time,
&folded_flux, &error));
EXPECT_EQ(error, "time.size() (got: 3) must equal flux.size() (got: 2)");
}
TEST(PhaseFoldAndSortLightCurve, Empty) {
vector<double> time = {};
vector<double> flux = {};
vector<double> folded_time;
vector<double> folded_flux;
std::string error;
EXPECT_TRUE(PhaseFoldAndSortLightCurve(time, flux, 1.0, 0.5, &folded_time,
&folded_flux, &error));
EXPECT_TRUE(error.empty());
EXPECT_TRUE(folded_time.empty());
EXPECT_TRUE(folded_flux.empty());
}
TEST(PhaseFoldAndSortLightCurve, FoldAndSort) {
vector<double> time = range(0, 2, 0.1);
vector<double> flux = range(0, 20, 1);
vector<double> folded_time;
vector<double> folded_flux;
std::string error;
EXPECT_TRUE(PhaseFoldAndSortLightCurve(time, flux, 2.0, 0.15, &folded_time,
&folded_flux, &error));
EXPECT_TRUE(error.empty());
vector<double> expected_time = {
-0.95, -0.85, -0.75, -0.65, -0.55, -0.45, -0.35, -0.25, -0.15, -0.05,
0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95};
EXPECT_THAT(folded_time, Pointwise(DoubleNear(), expected_time));
vector<double> expected_flux = {12, 13, 14, 15, 16, 17, 18, 19, 0, 1,
2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
EXPECT_THAT(folded_flux, Pointwise(DoubleNear(), expected_flux));
}
} // namespace
} // namespace astronet
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# CLIF Python extension module for median_filter.h.
#
# See https://github.com/google/clif
from light_curve_util.cc.python.postproc import ValueErrorOnFalse
from "third_party/tensorflow_models/astronet/light_curve_util/cc/median_filter.h":
namespace `astronet`:
def `MedianFilter` as median_filter (x: list<float>,
y: list<float>,
num_bins: int,
bin_width: float,
x_min: float,
x_max: float) -> (ok: bool,
result: list<float>,
error: bytes):
return ValueErrorOnFalse(...)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests the Python wrapping of the median_filter library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve_util.cc.python import median_filter
class MedianFilterTest(absltest.TestCase):
def testError(self):
x = [2, 0, 1]
y = [1, 2, 3]
with self.assertRaises(ValueError):
median_filter.median_filter(
x, y, num_bins=2, bin_width=1, x_min=0, x_max=2)
def testMedianFilter(self):
x = np.arange(-6, 7)
y = np.arange(1, 14)
result = median_filter.median_filter(
x, y, num_bins=5, bin_width=2, x_min=-5, x_max=5)
expected = [2.5, 4.5, 6.5, 8.5, 10.5]
np.testing.assert_almost_equal(result, expected)
if __name__ == '__main__':
absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# CLIF Python extension module for phase_fold.h.
#
# See https://github.com/google/clif
from light_curve_util.cc.python.postproc import ValueErrorOnFalse
from "third_party/tensorflow_models/astronet/light_curve_util/cc/phase_fold.h":
namespace `astronet`:
def `PhaseFoldTime` as phase_fold_time (time: list<float>,
period: float,
t0: float) -> list<float>
def `PhaseFoldAndSortLightCurve` as phase_fold_and_sort_light_curve (
time: list<float>,
flux: list<float>,
period: float,
t0: float) -> (ok: bool,
folded_time: list<float>,
folded_flux: list<float>,
error: bytes):
return ValueErrorOnFalse(...)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests the Python wrapping of the phase_fold library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve_util.cc.python import phase_fold
class PhaseFoldTimeTest(absltest.TestCase):
def testEmpty(self):
result = phase_fold.phase_fold_time(time=[], period=1, t0=0.45)
self.assertEmpty(result)
def testSimple(self):
time = np.arange(0, 2, 0.1)
result = phase_fold.phase_fold_time(time, period=1, t0=0.45)
expected = [
-0.45, -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45, -0.45,
-0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35, 0.45
]
np.testing.assert_almost_equal(result, expected)
class PhaseFoldAndSortLightCurveTest(absltest.TestCase):
def testError(self):
with self.assertRaises(ValueError):
phase_fold.phase_fold_and_sort_light_curve(
time=[1, 2, 3], flux=[7.5, 8.6], period=1, t0=0.5)
def testFoldAndSort(self):
time = np.arange(0, 2, 0.1)
flux = np.arange(0, 20, 1)
folded_time, folded_flux = phase_fold.phase_fold_and_sort_light_curve(
time, flux, period=2, t0=0.15)
expected_time = [
-0.95, -0.85, -0.75, -0.65, -0.55, -0.45, -0.35, -0.25, -0.15, -0.05,
0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95
]
np.testing.assert_almost_equal(folded_time, expected_time)
expected_flux = [
12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11
]
np.testing.assert_almost_equal(folded_flux, expected_flux)
if __name__ == '__main__':
absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Postprocessing utility function for CLIF."""
# CLIF postprocessor for a C++ function with signature:
# bool MyFunc(input_arg1, ..., *output_arg1, *output_arg2, ..., *error)
#
# If MyFunc returns True, returns (output_arg1, output_arg2, ...)
# If MyFunc returns False, raises ValueError(error).
def ValueErrorOnFalse(ok, *output_args):
"""Raises ValueError if not ok, otherwise returns the output arguments."""
n_outputs = len(output_args)
if n_outputs < 2:
raise ValueError("Expected 2 or more output_args. Got: %d" % n_outputs)
if not ok:
error = output_args[-1]
raise ValueError(error)
if n_outputs == 2:
output = output_args[0]
else:
output = output_args[0:-1]
return output
# CLIF postprocessor for a C++ function with signature:
# *result MyFactory(input_arg1, ..., *error)
#
# If result is not null, returns result.
# If result is null, raises ValueError(error).
def ValueErrorOnNull(result, error):
"""Raises ValueError(error) if result is None, otherwise returns result."""
if result is None:
raise ValueError(error)
return result
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# CLIF Python extension module for view_generator.h.
#
# See https://github.com/google/clif
from light_curve_util.cc.python.postproc import ValueErrorOnFalse
from light_curve_util.cc.python.postproc import ValueErrorOnNull
from "third_party/tensorflow_models/astronet/light_curve_util/cc/view_generator.h":
namespace `astronet`:
class ViewGenerator:
def `GenerateView` as generate_view (self,
num_bins: int,
bin_width: float,
t_min: float,
t_max: float,
normalize: bool) -> (
ok: bool,
result: list<float>,
error: bytes):
return ValueErrorOnFalse(...)
staticmethods from `ViewGenerator`:
def `Create` as create_view_generator (
time: list<float>,
flux: list<float>,
period: float,
t0: float) -> (vg: ViewGenerator, error: bytes):
return ValueErrorOnNull(...)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests the Python wrapping of the view_generator library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from light_curve_util.cc.python import view_generator
class ViewGeneratorTest(absltest.TestCase):
def testPrivateConstructorNotVisible(self):
time = [1, 2, 3]
flux = [2, 3]
with self.assertRaises(ValueError):
view_generator.ViewGenerator(time, flux)
def testCreationError(self):
time = [1, 2, 3]
flux = [2, 3]
with self.assertRaises(ValueError):
view_generator.create_view_generator(time, flux, period=1, t0=0.5)
def testGenerateViews(self):
time = np.arange(0, 2, 0.1)
flux = np.arange(0, 20, 1)
vg = view_generator.create_view_generator(time, flux, period=2, t0=0.15)
with self.assertRaises(ValueError):
vg.generate_view(
num_bins=10, bin_width=0.2, t_min=-1, t_max=-1, normalize=False)
# Global view, unnormalized.
result = vg.generate_view(
num_bins=10, bin_width=0.2, t_min=-1, t_max=1, normalize=False)
expected = [12.5, 14.5, 16.5, 18.5, 0.5, 2.5, 4.5, 6.5, 8.5, 10.5]
np.testing.assert_almost_equal(result, expected)
# Global view, normalized.
result = vg.generate_view(
num_bins=10, bin_width=0.2, t_min=-1, t_max=1, normalize=True)
expected = [
3.0 / 9, 5.0 / 9, 7.0 / 9, 9.0 / 9, -9.0 / 9, -7.0 / 9, -5.0 / 9,
-3.0 / 9, -1.0 / 9, 1.0 / 9
]
np.testing.assert_almost_equal(result, expected)
# Local view, unnormalized.
result = vg.generate_view(
num_bins=5, bin_width=0.2, t_min=-0.5, t_max=0.5, normalize=False)
expected = [17.5, 9.5, 1.5, 3.5, 5.5]
np.testing.assert_almost_equal(result, expected)
# Local view, normalized.
result = vg.generate_view(
num_bins=5, bin_width=0.2, t_min=-0.5, t_max=0.5, normalize=True)
expected = [3, 1, -1, -0.5, 0]
np.testing.assert_almost_equal(result, expected)
if __name__ == '__main__':
absltest.main()
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_TEST_UTIL_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_TEST_UTIL_H_
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace astronet {
// Like testing::DoubleNear, but operates on pairs and can therefore be used in
// testing::Pointwise.
MATCHER(DoubleNear, "") {
return testing::Value(std::get<0>(arg),
testing::DoubleNear(std::get<1>(arg), 1e-12));
}
// Returns the range {start, start + step, start + 2 * step, ...} up to the
// exclusive end value, stop.
inline std::vector<double> range(double start, double stop, double step) {
std::vector<double> result;
while (start < stop) {
result.push_back(start);
start += step;
}
return result;
}
} // namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_UTIL_CC_TEST_UTIL_H_
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve_util/cc/view_generator.h"
#include "absl/memory/memory.h"
#include "light_curve_util/cc/median_filter.h"
#include "light_curve_util/cc/normalize.h"
#include "light_curve_util/cc/phase_fold.h"
using std::vector;
namespace astronet {
// Accept time as a value, because we will phase fold in place.
std::unique_ptr<ViewGenerator> ViewGenerator::Create(const vector<double>& time,
const vector<double>& flux,
double period, double t0,
std::string* error) {
vector<double> folded_time(time.size());
vector<double> folded_flux(flux.size());
if (!PhaseFoldAndSortLightCurve(time, flux, period, t0, &folded_time,
&folded_flux, error)) {
return nullptr;
}
return absl::WrapUnique(
new ViewGenerator(std::move(folded_time), std::move(folded_flux)));
}
bool ViewGenerator::GenerateView(int num_bins, double bin_width, double t_min,
double t_max, bool normalize,
vector<double>* result, std::string* error) {
result->resize(num_bins);
if (!MedianFilter(time_, flux_, num_bins, bin_width, t_min, t_max, result,
error)) {
return false;
}
if (normalize) {
return NormalizeMedianAndMinimum(*result, result, error);
}
return true;
}
ViewGenerator::ViewGenerator(vector<double> time, vector<double> flux)
: time_(std::move(time)), flux_(std::move(flux)) {}
} // namespace astronet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment