Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
42da7864
Commit
42da7864
authored
Dec 04, 2018
by
Christopher Shallue
Browse files
Move tensorflow_models/research/astronet to google-research/exoplanet-ml
parent
17c2f0cc
Changes
130
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
1682 deletions
+0
-1682
research/astronet/light_curve/fast_ops/normalize.cc
research/astronet/light_curve/fast_ops/normalize.cc
+0
-57
research/astronet/light_curve/fast_ops/normalize.h
research/astronet/light_curve/fast_ops/normalize.h
+0
-44
research/astronet/light_curve/fast_ops/normalize_test.cc
research/astronet/light_curve/fast_ops/normalize_test.cc
+0
-93
research/astronet/light_curve/fast_ops/phase_fold.cc
research/astronet/light_curve/fast_ops/phase_fold.cc
+0
-83
research/astronet/light_curve/fast_ops/phase_fold.h
research/astronet/light_curve/fast_ops/phase_fold.h
+0
-68
research/astronet/light_curve/fast_ops/phase_fold_test.cc
research/astronet/light_curve/fast_ops/phase_fold_test.cc
+0
-136
research/astronet/light_curve/fast_ops/python/median_filter.clif
...h/astronet/light_curve/fast_ops/python/median_filter.clif
+0
-31
research/astronet/light_curve/fast_ops/python/median_filter_test.py
...stronet/light_curve/fast_ops/python/median_filter_test.py
+0
-48
research/astronet/light_curve/fast_ops/python/phase_fold.clif
...arch/astronet/light_curve/fast_ops/python/phase_fold.clif
+0
-35
research/astronet/light_curve/fast_ops/python/phase_fold_test.py
...h/astronet/light_curve/fast_ops/python/phase_fold_test.py
+0
-70
research/astronet/light_curve/fast_ops/python/postproc.py
research/astronet/light_curve/fast_ops/python/postproc.py
+0
-52
research/astronet/light_curve/fast_ops/python/view_generator.clif
.../astronet/light_curve/fast_ops/python/view_generator.clif
+0
-42
research/astronet/light_curve/fast_ops/python/view_generator_test.py
...tronet/light_curve/fast_ops/python/view_generator_test.py
+0
-80
research/astronet/light_curve/fast_ops/test_util.h
research/astronet/light_curve/fast_ops/test_util.h
+0
-45
research/astronet/light_curve/fast_ops/view_generator.cc
research/astronet/light_curve/fast_ops/view_generator.cc
+0
-58
research/astronet/light_curve/fast_ops/view_generator.h
research/astronet/light_curve/fast_ops/view_generator.h
+0
-100
research/astronet/light_curve/fast_ops/view_generator_test.cc
...arch/astronet/light_curve/fast_ops/view_generator_test.cc
+0
-87
research/astronet/light_curve/kepler_io.py
research/astronet/light_curve/kepler_io.py
+0
-233
research/astronet/light_curve/kepler_io_test.py
research/astronet/light_curve/kepler_io_test.py
+0
-200
research/astronet/light_curve/median_filter.py
research/astronet/light_curve/median_filter.py
+0
-120
No files found.
research/astronet/light_curve/fast_ops/normalize.cc
deleted
100644 → 0
View file @
17c2f0cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve/fast_ops/normalize.h"
#include <algorithm>
#include "absl/strings/substitute.h"
#include "light_curve/fast_ops/median.h"
using
absl
::
Substitute
;
using
std
::
vector
;
namespace
astronet
{
bool
NormalizeMedianAndMinimum
(
const
vector
<
double
>&
x
,
vector
<
double
>*
result
,
std
::
string
*
error
)
{
if
(
x
.
size
()
<
2
)
{
*
error
=
Substitute
(
"x.size() must be greater than 1. Got: $0"
,
x
.
size
());
return
false
;
}
// Find the median of x.
vector
<
double
>
x_copy
(
x
);
const
double
median
=
InPlaceMedian
(
x_copy
.
begin
(),
x_copy
.
end
());
// Find the min element of x. As a post condition of InPlaceMedian, we only
// need to search elements lower than the middle.
const
auto
x_copy_middle
=
x_copy
.
begin
()
+
x_copy
.
size
()
/
2
;
const
auto
minimum
=
std
::
min_element
(
x_copy
.
begin
(),
x_copy_middle
);
// Guaranteed to be positive, unless the median exactly equals the minimum.
double
normalizer
=
median
-
*
minimum
;
if
(
normalizer
<=
0
)
{
*
error
=
Substitute
(
"Minimum and median have the same value: $0"
,
median
);
return
false
;
}
result
->
resize
(
x
.
size
());
std
::
transform
(
x
.
begin
(),
x
.
end
(),
result
->
begin
(),
[
median
,
normalizer
](
double
v
)
{
return
(
v
-
median
)
/
normalizer
;
});
return
true
;
}
}
// namespace astronet
research/astronet/light_curve/fast_ops/normalize.h
deleted
100644 → 0
View file @
17c2f0cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_NORMALIZE_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_NORMALIZE_H_
#include <iostream>
#include <string>
#include <vector>
namespace
astronet
{
// Normalizes a vector with an affine transformation such that its median is
// mapped to 0 and its minimum is mapped to -1.
//
// Input args:
// x: Vector to normalize. Must have at least 2 elements and all elements
// cannot be the same value.
//
// Output args:
// result: Output normalized vector. Can be a pointer to the input vector to
// perform the normalization in-place.
// error: String indicating an error (e.g. an invalid argument).
//
// Returns:
// true if the algorithm succeeded. If false, see "error".
bool
NormalizeMedianAndMinimum
(
const
std
::
vector
<
double
>&
x
,
std
::
vector
<
double
>*
result
,
std
::
string
*
error
);
}
// namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_NORMALIZE_H_
research/astronet/light_curve/fast_ops/normalize_test.cc
deleted
100644 → 0
View file @
17c2f0cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve/fast_ops/normalize.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "light_curve/fast_ops/test_util.h"
using
std
::
vector
;
using
testing
::
Pointwise
;
namespace
astronet
{
namespace
{
TEST
(
NormalizeMedianAndMinimum
,
Error
)
{
vector
<
double
>
x
=
{
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
};
vector
<
double
>
result
;
std
::
string
error
;
EXPECT_FALSE
(
NormalizeMedianAndMinimum
(
x
,
&
result
,
&
error
));
EXPECT_EQ
(
error
,
"Minimum and median have the same value: -1"
);
}
TEST
(
NormalizeMedianAndMinimum
,
TooFewElements
)
{
vector
<
double
>
x
=
{
1
};
vector
<
double
>
result
;
std
::
string
error
;
EXPECT_FALSE
(
NormalizeMedianAndMinimum
(
x
,
&
result
,
&
error
));
EXPECT_EQ
(
error
,
"x.size() must be greater than 1. Got: 1"
);
}
TEST
(
NormalizeMedianAndMinimum
,
NonNegative
)
{
vector
<
double
>
x
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
// Median 4, Min 0.
vector
<
double
>
result
;
std
::
string
error
;
EXPECT_TRUE
(
NormalizeMedianAndMinimum
(
x
,
&
result
,
&
error
));
EXPECT_TRUE
(
error
.
empty
());
vector
<
double
>
expected
=
{
-
1
,
-
0.75
,
-
0.5
,
-
0.25
,
0
,
0.25
,
0.5
,
0.75
,
1
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
}
TEST
(
NormalizeMedianAndMinimum
,
NonPositive
)
{
vector
<
double
>
x
=
{
0
,
-
1
,
-
2
,
-
3
,
-
4
,
-
5
,
-
6
,
-
7
,
-
8
};
// Median -4, Min -8.
vector
<
double
>
result
;
std
::
string
error
;
EXPECT_TRUE
(
NormalizeMedianAndMinimum
(
x
,
&
result
,
&
error
));
EXPECT_TRUE
(
error
.
empty
());
vector
<
double
>
expected
=
{
1
,
0.75
,
0.5
,
0.25
,
0
,
-
0.25
,
-
0.5
,
-
0.75
,
-
1
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
}
TEST
(
NormalizeMedianAndMinimum
,
PositiveNegative
)
{
vector
<
double
>
x
=
{
-
4
,
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
3
,
4
};
// Median 0, Min -4.
vector
<
double
>
result
;
std
::
string
error
;
EXPECT_TRUE
(
NormalizeMedianAndMinimum
(
x
,
&
result
,
&
error
));
EXPECT_TRUE
(
error
.
empty
());
vector
<
double
>
expected
=
{
-
1
,
-
0.75
,
-
0.5
,
-
0.25
,
0
,
0.25
,
0.5
,
0.75
,
1
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
}
TEST
(
NormalizeMedianAndMinimum
,
InPlace
)
{
vector
<
double
>
x
=
{
-
4
,
-
3
,
-
2
,
-
1
,
0
,
1
,
2
,
3
,
4
};
// Median 0, Min -4.
std
::
string
error
;
EXPECT_TRUE
(
NormalizeMedianAndMinimum
(
x
,
&
x
,
&
error
));
EXPECT_TRUE
(
error
.
empty
());
vector
<
double
>
expected
=
{
-
1
,
-
0.75
,
-
0.5
,
-
0.25
,
0
,
0.25
,
0.5
,
0.75
,
1
};
EXPECT_THAT
(
x
,
Pointwise
(
DoubleNear
(),
expected
));
}
}
// namespace
}
// namespace astronet
research/astronet/light_curve/fast_ops/phase_fold.cc
deleted
100644 → 0
View file @
17c2f0cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve/fast_ops/phase_fold.h"
#include <math.h>
#include <algorithm>
#include <numeric>
#include "absl/strings/substitute.h"
using
absl
::
Substitute
;
using
std
::
vector
;
namespace
astronet
{
void
PhaseFoldTime
(
const
vector
<
double
>&
time
,
double
period
,
double
t0
,
vector
<
double
>*
result
)
{
result
->
resize
(
time
.
size
());
double
half_period
=
period
/
2
;
// Compute a constant offset to subtract from each time value before taking
// the remainder modulo the period. This offset ensures that t0 will be
// centered at +/- period / 2 after the remainder operation.
double
offset
=
t0
-
half_period
;
std
::
transform
(
time
.
begin
(),
time
.
end
(),
result
->
begin
(),
[
period
,
offset
,
half_period
](
double
t
)
{
// If t > offset, then rem is in [0, period) with t0 at
// period / 2. Otherwise rem is in (-period, 0] with t0 at
// -period / 2. We shift appropriately to return a value in
// [-period / 2, period / 2) with t0 centered at 0.
double
rem
=
fmod
(
t
-
offset
,
period
);
return
rem
<
0
?
rem
+
half_period
:
rem
-
half_period
;
});
}
// Accept time as a value, because we will phase fold in place.
bool
PhaseFoldAndSortLightCurve
(
vector
<
double
>
time
,
const
vector
<
double
>&
flux
,
double
period
,
double
t0
,
vector
<
double
>*
folded_time
,
vector
<
double
>*
folded_flux
,
std
::
string
*
error
)
{
const
std
::
size_t
length
=
time
.
size
();
if
(
flux
.
size
()
!=
length
)
{
*
error
=
Substitute
(
"time.size() (got: $0) must equal flux.size() (got: $1)"
,
length
,
flux
.
size
());
return
false
;
}
// Phase fold time in place.
PhaseFoldTime
(
time
,
period
,
t0
,
&
time
);
// Sort the indices of time by ascending value.
vector
<
std
::
size_t
>
sorted_i
(
length
);
std
::
iota
(
sorted_i
.
begin
(),
sorted_i
.
end
(),
0
);
std
::
sort
(
sorted_i
.
begin
(),
sorted_i
.
end
(),
[
&
time
](
std
::
size_t
i
,
std
::
size_t
j
)
{
return
time
[
i
]
<
time
[
j
];
});
// Copy phase folded and sorted time and flux into the output.
folded_time
->
resize
(
length
);
folded_flux
->
resize
(
length
);
for
(
int
i
=
0
;
i
<
length
;
++
i
)
{
(
*
folded_time
)[
i
]
=
time
[
sorted_i
[
i
]];
(
*
folded_flux
)[
i
]
=
flux
[
sorted_i
[
i
]];
}
return
true
;
}
}
// namespace astronet
research/astronet/light_curve/fast_ops/phase_fold.h
deleted
100644 → 0
View file @
17c2f0cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_PHASE_FOLD_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_PHASE_FOLD_H_
#include <iostream>
#include <string>
#include <vector>
namespace
astronet
{
// Creates a phase-folded time vector.
//
// Specifically, result[i] is the unique number in [-period / 2, period / 2)
// such that result[i] = time[i] - t0 + k_i * period, for some integer k_i.
//
// Input args:
// time: Input vector of time values.
// period: The period to fold over.
// t0: The center of the resulting folded vector; this value is mapped to 0.
//
// Output args:
// result: Output phase folded vector. Can be a pointer to the input vector to
// perform the phase-folding in-place.
void
PhaseFoldTime
(
const
std
::
vector
<
double
>&
time
,
double
period
,
double
t0
,
std
::
vector
<
double
>*
result
);
// Phase folds a light curve and sorts by ascending phase-folded time.
//
// See the comment on PhaseFoldTime for a description of the phase folding
// technique for the time values. The flux values are not modified; they are
// simply permuted to correspond to the sorted phase folded time values.
//
// Input args:
// time: Vector of time values.
// flux: Vector of flux values with the same size as time.
// period: The period to fold over.
// t0: The center of the resulting folded vector; this value is mapped to 0.
//
// Output args:
// folded_time: Output phase folded time values, sorted in ascending order.
// folded_flux: Output flux values corresponding pointwise to folded_time.
// error: String indicating an error (e.g. time and flux are different sizes).
//
// Returns:
// true if the algorithm succeeded. If false, see "error".
bool
PhaseFoldAndSortLightCurve
(
std
::
vector
<
double
>
time
,
const
std
::
vector
<
double
>&
flux
,
double
period
,
double
t0
,
std
::
vector
<
double
>*
folded_time
,
std
::
vector
<
double
>*
folded_flux
,
std
::
string
*
error
);
}
// namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_PHASE_FOLD_H_
research/astronet/light_curve/fast_ops/phase_fold_test.cc
deleted
100644 → 0
View file @
17c2f0cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve/fast_ops/phase_fold.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "light_curve/fast_ops/test_util.h"
using
std
::
vector
;
using
testing
::
Pointwise
;
namespace
astronet
{
namespace
{
TEST
(
PhaseFoldTime
,
Empty
)
{
vector
<
double
>
time
=
{};
vector
<
double
>
result
;
PhaseFoldTime
(
time
,
1
,
0.45
,
&
result
);
EXPECT_TRUE
(
result
.
empty
());
}
TEST
(
PhaseFoldTime
,
Simple
)
{
vector
<
double
>
time
=
range
(
0
,
2
,
0.1
);
vector
<
double
>
result
;
PhaseFoldTime
(
time
,
1
,
0.45
,
&
result
);
vector
<
double
>
expected
=
{
-
0.45
,
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
,
-
0.45
,
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
,
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
}
TEST
(
PhaseFoldTime
,
LargeT0
)
{
vector
<
double
>
time
=
range
(
0
,
2
,
0.1
);
vector
<
double
>
result
;
PhaseFoldTime
(
time
,
1
,
1.25
,
&
result
);
vector
<
double
>
expected
=
{
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
,
-
0.45
,
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
,
-
0.45
,
-
0.35
,
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
}
TEST
(
PhaseFoldTime
,
NegativeT0
)
{
vector
<
double
>
time
=
range
(
0
,
2
,
0.1
);
vector
<
double
>
result
;
PhaseFoldTime
(
time
,
1
,
-
1.65
,
&
result
);
vector
<
double
>
expected
=
{
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
,
-
0.45
,
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
,
-
0.45
,
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
}
TEST
(
PhaseFoldTime
,
NegativeTime
)
{
vector
<
double
>
time
=
range
(
-
3
,
-
1
,
0.1
);
vector
<
double
>
result
;
PhaseFoldTime
(
time
,
1
,
0.55
,
&
result
);
vector
<
double
>
expected
=
{
0.45
,
-
0.45
,
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
,
-
0.45
,
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
}
TEST
(
PhaseFoldTime
,
InPlace
)
{
vector
<
double
>
time
=
range
(
0
,
2
,
0.1
);
PhaseFoldTime
(
time
,
0.5
,
1.15
,
&
time
);
vector
<
double
>
expected
=
{
-
0.15
,
-
0.05
,
0.05
,
0.15
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
-
0.25
,
};
EXPECT_THAT
(
time
,
Pointwise
(
DoubleNear
(),
time
));
}
TEST
(
PhaseFoldAndSortLightCurve
,
Error
)
{
vector
<
double
>
time
=
{
1.0
,
2.0
,
3.0
};
vector
<
double
>
flux
=
{
7.5
,
8.6
};
vector
<
double
>
folded_time
;
vector
<
double
>
folded_flux
;
std
::
string
error
;
EXPECT_FALSE
(
PhaseFoldAndSortLightCurve
(
time
,
flux
,
1.0
,
0.5
,
&
folded_time
,
&
folded_flux
,
&
error
));
EXPECT_EQ
(
error
,
"time.size() (got: 3) must equal flux.size() (got: 2)"
);
}
TEST
(
PhaseFoldAndSortLightCurve
,
Empty
)
{
vector
<
double
>
time
=
{};
vector
<
double
>
flux
=
{};
vector
<
double
>
folded_time
;
vector
<
double
>
folded_flux
;
std
::
string
error
;
EXPECT_TRUE
(
PhaseFoldAndSortLightCurve
(
time
,
flux
,
1.0
,
0.5
,
&
folded_time
,
&
folded_flux
,
&
error
));
EXPECT_TRUE
(
error
.
empty
());
EXPECT_TRUE
(
folded_time
.
empty
());
EXPECT_TRUE
(
folded_flux
.
empty
());
}
TEST
(
PhaseFoldAndSortLightCurve
,
FoldAndSort
)
{
vector
<
double
>
time
=
range
(
0
,
2
,
0.1
);
vector
<
double
>
flux
=
range
(
0
,
20
,
1
);
vector
<
double
>
folded_time
;
vector
<
double
>
folded_flux
;
std
::
string
error
;
EXPECT_TRUE
(
PhaseFoldAndSortLightCurve
(
time
,
flux
,
2.0
,
0.15
,
&
folded_time
,
&
folded_flux
,
&
error
));
EXPECT_TRUE
(
error
.
empty
());
vector
<
double
>
expected_time
=
{
-
0.95
,
-
0.85
,
-
0.75
,
-
0.65
,
-
0.55
,
-
0.45
,
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
,
0.55
,
0.65
,
0.75
,
0.85
,
0.95
};
EXPECT_THAT
(
folded_time
,
Pointwise
(
DoubleNear
(),
expected_time
));
vector
<
double
>
expected_flux
=
{
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
};
EXPECT_THAT
(
folded_flux
,
Pointwise
(
DoubleNear
(),
expected_flux
));
}
}
// namespace
}
// namespace astronet
research/astronet/light_curve/fast_ops/python/median_filter.clif
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# CLIF Python extension module for median_filter.h.
#
# See https://github.com/google/clif
from light_curve.fast_ops.python.postproc import ValueErrorOnFalse
from "third_party/tensorflow_models/astronet/light_curve/fast_ops/median_filter.h":
namespace `astronet`:
def `MedianFilter` as median_filter (x: list<float>,
y: list<float>,
num_bins: int,
bin_width: float,
x_min: float,
x_max: float) -> (ok: bool,
result: list<float>,
error: bytes):
return ValueErrorOnFalse(...)
research/astronet/light_curve/fast_ops/python/median_filter_test.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests the Python wrapping of the median_filter library."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl.testing
import
absltest
import
numpy
as
np
from
light_curve.fast_ops.python
import
median_filter
class
MedianFilterTest
(
absltest
.
TestCase
):
def
testError
(
self
):
x
=
[
2
,
0
,
1
]
y
=
[
1
,
2
,
3
]
with
self
.
assertRaises
(
ValueError
):
median_filter
.
median_filter
(
x
,
y
,
num_bins
=
2
,
bin_width
=
1
,
x_min
=
0
,
x_max
=
2
)
def
testMedianFilter
(
self
):
x
=
np
.
arange
(
-
6
,
7
)
y
=
np
.
arange
(
1
,
14
)
result
=
median_filter
.
median_filter
(
x
,
y
,
num_bins
=
5
,
bin_width
=
2
,
x_min
=-
5
,
x_max
=
5
)
expected
=
[
2.5
,
4.5
,
6.5
,
8.5
,
10.5
]
np
.
testing
.
assert_almost_equal
(
result
,
expected
)
if
__name__
==
"__main__"
:
absltest
.
main
()
research/astronet/light_curve/fast_ops/python/phase_fold.clif
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# CLIF Python extension module for phase_fold.h.
#
# See https://github.com/google/clif
from light_curve.fast_ops.python.postproc import ValueErrorOnFalse
from "third_party/tensorflow_models/astronet/light_curve/fast_ops/phase_fold.h":
namespace `astronet`:
def `PhaseFoldTime` as phase_fold_time (time: list<float>,
period: float,
t0: float) -> list<float>
def `PhaseFoldAndSortLightCurve` as phase_fold_and_sort_light_curve (
time: list<float>,
flux: list<float>,
period: float,
t0: float) -> (ok: bool,
folded_time: list<float>,
folded_flux: list<float>,
error: bytes):
return ValueErrorOnFalse(...)
research/astronet/light_curve/fast_ops/python/phase_fold_test.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests the Python wrapping of the phase_fold library."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl.testing
import
absltest
import
numpy
as
np
from
light_curve.fast_ops.python
import
phase_fold
class
PhaseFoldTimeTest
(
absltest
.
TestCase
):
def
testEmpty
(
self
):
result
=
phase_fold
.
phase_fold_time
(
time
=
[],
period
=
1
,
t0
=
0.45
)
self
.
assertEmpty
(
result
)
def
testSimple
(
self
):
time
=
np
.
arange
(
0
,
2
,
0.1
)
result
=
phase_fold
.
phase_fold_time
(
time
,
period
=
1
,
t0
=
0.45
)
expected
=
[
-
0.45
,
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
,
-
0.45
,
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
]
np
.
testing
.
assert_almost_equal
(
result
,
expected
)
class
PhaseFoldAndSortLightCurveTest
(
absltest
.
TestCase
):
def
testError
(
self
):
with
self
.
assertRaises
(
ValueError
):
phase_fold
.
phase_fold_and_sort_light_curve
(
time
=
[
1
,
2
,
3
],
flux
=
[
7.5
,
8.6
],
period
=
1
,
t0
=
0.5
)
def
testFoldAndSort
(
self
):
time
=
np
.
arange
(
0
,
2
,
0.1
)
flux
=
np
.
arange
(
0
,
20
,
1
)
folded_time
,
folded_flux
=
phase_fold
.
phase_fold_and_sort_light_curve
(
time
,
flux
,
period
=
2
,
t0
=
0.15
)
expected_time
=
[
-
0.95
,
-
0.85
,
-
0.75
,
-
0.65
,
-
0.55
,
-
0.45
,
-
0.35
,
-
0.25
,
-
0.15
,
-
0.05
,
0.05
,
0.15
,
0.25
,
0.35
,
0.45
,
0.55
,
0.65
,
0.75
,
0.85
,
0.95
]
np
.
testing
.
assert_almost_equal
(
folded_time
,
expected_time
)
expected_flux
=
[
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
]
np
.
testing
.
assert_almost_equal
(
folded_flux
,
expected_flux
)
if
__name__
==
"__main__"
:
absltest
.
main
()
research/astronet/light_curve/fast_ops/python/postproc.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Postprocessing utility functions for CLIF."""
# CLIF postprocessor for a C++ function with signature:
# bool MyFunc(input_arg1, ..., *output_arg1, *output_arg2, ..., *error)
#
# If MyFunc returns True, returns (output_arg1, output_arg2, ...)
# If MyFunc returns False, raises ValueError(error).
def
ValueErrorOnFalse
(
ok
,
*
output_args
):
"""Raises ValueError if not ok, otherwise returns the output arguments."""
n_outputs
=
len
(
output_args
)
if
n_outputs
<
2
:
raise
ValueError
(
"Expected 2 or more output_args. Got: {}"
.
format
(
n_outputs
))
if
not
ok
:
error
=
output_args
[
-
1
]
raise
ValueError
(
error
)
if
n_outputs
==
2
:
output
=
output_args
[
0
]
else
:
output
=
output_args
[
0
:
-
1
]
return
output
# CLIF postprocessor for a C++ function with signature:
# *result MyFactory(input_arg1, ..., *error)
#
# If result is not null, returns result.
# If result is null, raises ValueError(error).
def
ValueErrorOnNull
(
result
,
error
):
"""Raises ValueError(error) if result is None, otherwise returns result."""
if
result
is
None
:
raise
ValueError
(
error
)
return
result
research/astronet/light_curve/fast_ops/python/view_generator.clif
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# CLIF Python extension module for view_generator.h.
#
# See https://github.com/google/clif
from light_curve.fast_ops.python.postproc import ValueErrorOnFalse
from light_curve.fast_ops.python.postproc import ValueErrorOnNull
from "third_party/tensorflow_models/astronet/light_curve/fast_ops/view_generator.h":
namespace `astronet`:
class ViewGenerator:
def `GenerateView` as generate_view (self,
num_bins: int,
bin_width: float,
t_min: float,
t_max: float,
normalize: bool) -> (
ok: bool,
result: list<float>,
error: bytes):
return ValueErrorOnFalse(...)
staticmethods from `ViewGenerator`:
def `Create` as create_view_generator (
time: list<float>,
flux: list<float>,
period: float,
t0: float) -> (vg: ViewGenerator, error: bytes):
return ValueErrorOnNull(...)
research/astronet/light_curve/fast_ops/python/view_generator_test.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests the Python wrapping of the view_generator library."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl.testing
import
absltest
import
numpy
as
np
from
light_curve.fast_ops.python
import
view_generator
class
ViewGeneratorTest
(
absltest
.
TestCase
):
def
testPrivateConstructorNotVisible
(
self
):
time
=
[
1
,
2
,
3
]
flux
=
[
2
,
3
]
with
self
.
assertRaises
(
ValueError
):
view_generator
.
ViewGenerator
(
time
,
flux
)
def
testCreationError
(
self
):
time
=
[
1
,
2
,
3
]
flux
=
[
2
,
3
]
with
self
.
assertRaises
(
ValueError
):
view_generator
.
create_view_generator
(
time
,
flux
,
period
=
1
,
t0
=
0.5
)
def
testGenerateViews
(
self
):
time
=
np
.
arange
(
0
,
2
,
0.1
)
flux
=
np
.
arange
(
0
,
20
,
1
)
vg
=
view_generator
.
create_view_generator
(
time
,
flux
,
period
=
2
,
t0
=
0.15
)
with
self
.
assertRaises
(
ValueError
):
vg
.
generate_view
(
num_bins
=
10
,
bin_width
=
0.2
,
t_min
=-
1
,
t_max
=-
1
,
normalize
=
False
)
# Global view, unnormalized.
result
=
vg
.
generate_view
(
num_bins
=
10
,
bin_width
=
0.2
,
t_min
=-
1
,
t_max
=
1
,
normalize
=
False
)
expected
=
[
12.5
,
14.5
,
16.5
,
18.5
,
0.5
,
2.5
,
4.5
,
6.5
,
8.5
,
10.5
]
np
.
testing
.
assert_almost_equal
(
result
,
expected
)
# Global view, normalized.
result
=
vg
.
generate_view
(
num_bins
=
10
,
bin_width
=
0.2
,
t_min
=-
1
,
t_max
=
1
,
normalize
=
True
)
expected
=
[
3.0
/
9
,
5.0
/
9
,
7.0
/
9
,
9.0
/
9
,
-
9.0
/
9
,
-
7.0
/
9
,
-
5.0
/
9
,
-
3.0
/
9
,
-
1.0
/
9
,
1.0
/
9
]
np
.
testing
.
assert_almost_equal
(
result
,
expected
)
# Local view, unnormalized.
result
=
vg
.
generate_view
(
num_bins
=
5
,
bin_width
=
0.2
,
t_min
=-
0.5
,
t_max
=
0.5
,
normalize
=
False
)
expected
=
[
17.5
,
9.5
,
1.5
,
3.5
,
5.5
]
np
.
testing
.
assert_almost_equal
(
result
,
expected
)
# Local view, normalized.
result
=
vg
.
generate_view
(
num_bins
=
5
,
bin_width
=
0.2
,
t_min
=-
0.5
,
t_max
=
0.5
,
normalize
=
True
)
expected
=
[
3
,
1
,
-
1
,
-
0.5
,
0
]
np
.
testing
.
assert_almost_equal
(
result
,
expected
)
if
__name__
==
"__main__"
:
absltest
.
main
()
research/astronet/light_curve/fast_ops/test_util.h
deleted
100644 → 0
View file @
17c2f0cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_TEST_UTIL_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_TEST_UTIL_H_
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace
astronet
{
// Like testing::DoubleNear, but operates on pairs and can therefore be used in
// testing::Pointwise.
MATCHER
(
DoubleNear
,
""
)
{
return
testing
::
Value
(
std
::
get
<
0
>
(
arg
),
testing
::
DoubleNear
(
std
::
get
<
1
>
(
arg
),
1e-12
));
}
// Returns the range {start, start + step, start + 2 * step, ...} up to the
// exclusive end value, stop.
inline
std
::
vector
<
double
>
range
(
double
start
,
double
stop
,
double
step
)
{
std
::
vector
<
double
>
result
;
while
(
start
<
stop
)
{
result
.
push_back
(
start
);
start
+=
step
;
}
return
result
;
}
}
// namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_TEST_UTIL_H_
research/astronet/light_curve/fast_ops/view_generator.cc
deleted
100644 → 0
View file @
17c2f0cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve/fast_ops/view_generator.h"
#include "absl/memory/memory.h"
#include "light_curve/fast_ops/median_filter.h"
#include "light_curve/fast_ops/normalize.h"
#include "light_curve/fast_ops/phase_fold.h"
using
std
::
vector
;
namespace
astronet
{
// Accept time as a value, because we will phase fold in place.
std
::
unique_ptr
<
ViewGenerator
>
ViewGenerator
::
Create
(
const
vector
<
double
>&
time
,
const
vector
<
double
>&
flux
,
double
period
,
double
t0
,
std
::
string
*
error
)
{
vector
<
double
>
folded_time
(
time
.
size
());
vector
<
double
>
folded_flux
(
flux
.
size
());
if
(
!
PhaseFoldAndSortLightCurve
(
time
,
flux
,
period
,
t0
,
&
folded_time
,
&
folded_flux
,
error
))
{
return
nullptr
;
}
return
absl
::
WrapUnique
(
new
ViewGenerator
(
std
::
move
(
folded_time
),
std
::
move
(
folded_flux
)));
}
bool
ViewGenerator
::
GenerateView
(
int
num_bins
,
double
bin_width
,
double
t_min
,
double
t_max
,
bool
normalize
,
vector
<
double
>*
result
,
std
::
string
*
error
)
{
result
->
resize
(
num_bins
);
if
(
!
MedianFilter
(
time_
,
flux_
,
num_bins
,
bin_width
,
t_min
,
t_max
,
result
,
error
))
{
return
false
;
}
if
(
normalize
)
{
return
NormalizeMedianAndMinimum
(
*
result
,
result
,
error
);
}
return
true
;
}
ViewGenerator
::
ViewGenerator
(
vector
<
double
>
time
,
vector
<
double
>
flux
)
:
time_
(
std
::
move
(
time
)),
flux_
(
std
::
move
(
flux
))
{}
}
// namespace astronet
research/astronet/light_curve/fast_ops/view_generator.h
deleted
100644 → 0
View file @
17c2f0cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_VIEW_GENERATOR_H_
#define TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_VIEW_GENERATOR_H_
#include <memory>
#include <string>
#include <vector>
namespace
astronet
{
// Helper class for phase-folding a light curve and then generating "views" of
// the light curve using a median filter.
//
// This class wraps functions for phase folding, median filtering, and
// normalizing for efficient use as a Python extension. It keeps the
// phase-folded light curve in the class state to minimize expensive copies
// between the language barrier.
class
ViewGenerator
{
public:
// Factory function to create a new ViewGenerator.
//
// Input args:
// time: Vector of time values, not phase-folded.
// flux: Vector of flux values with the same size as time.
// period: The period to fold over.
// t0: The center of the resulting folded vector; this value is mapped to 0.
//
// Output args:
// error: String indicating an error (e.g. time and flux are different
// sizes).
//
// Returns:
// A ViewGenerator. May be a nullptr in the case of an error; see the
// "error" string if so.
static
std
::
unique_ptr
<
ViewGenerator
>
Create
(
const
std
::
vector
<
double
>&
time
,
const
std
::
vector
<
double
>&
flux
,
double
period
,
double
t0
,
std
::
string
*
error
);
// Generates a "view" of the phase-folded light curve using a median filter.
//
// Note that the time values of the phase-folded light curve are in the range
// [-period / 2, period / 2).
//
// This function applies astronet::MedianFilter() to the phase-folded and
// sorted light curve, followed optionally by
// astronet::NormalizeMedianAndMinimum(). See the comments on those
// functions for more details.
//
// Input args:
// num_bins: The number of intervals to divide the time axis into. Must be
// at least 2.
// bin_width: The width of each bin on the time axis. Must be positive, and
// less than t_max - t_min.
// t_min: The inclusive leftmost value to consider on the time axis. This
// should probably be at least -period / 2, which is the minimum
// possible value of the phase-folded light curve. Must be less than the
// largest value of the phase-folded time axis.
// t_max: The exclusive rightmost value to consider on the time axis. This
// should probably be at most period / 2, which is the maximum possible
// value of the phase-folded light curve. Must be greater than t_min.
// normalize: Whether to normalize the output vector to have median 0 and
// minimum -1.
//
// Output args:
// result: Vector of size num_bins containing the median flux values of
// uniformly spaced bins on the phase-folded time axis.
// error: String indicating an error (e.g. an invalid argument).
//
// Returns:
// true if the algorithm succeeded. If false, see "error".
bool
GenerateView
(
int
num_bins
,
double
bin_width
,
double
t_min
,
double
t_max
,
bool
normalize
,
std
::
vector
<
double
>*
result
,
std
::
string
*
error
);
protected:
// This class can only be constructed by Create().
ViewGenerator
(
std
::
vector
<
double
>
time
,
std
::
vector
<
double
>
flux
);
// phase-folded light curve, sorted by time in ascending order.
std
::
vector
<
double
>
time_
;
std
::
vector
<
double
>
flux_
;
};
}
// namespace astronet
#endif // TENSORFLOW_MODELS_ASTRONET_LIGHT_CURVE_FAST_OPS_VIEW_GENERATOR_H_
research/astronet/light_curve/fast_ops/view_generator_test.cc
deleted
100644 → 0
View file @
17c2f0cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "light_curve/fast_ops/view_generator.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "light_curve/fast_ops/test_util.h"
using
std
::
vector
;
using
testing
::
Pointwise
;
namespace
astronet
{
namespace
{
TEST
(
ViewGenerator
,
CreationError
)
{
vector
<
double
>
time
=
{
1
,
2
,
3
};
vector
<
double
>
flux
=
{
2
,
3
};
std
::
string
error
;
std
::
unique_ptr
<
ViewGenerator
>
generator
=
ViewGenerator
::
Create
(
time
,
flux
,
1
,
0.5
,
&
error
);
EXPECT_EQ
(
nullptr
,
generator
);
EXPECT_FALSE
(
error
.
empty
());
}
TEST
(
ViewGenerator
,
GenerateViews
)
{
vector
<
double
>
time
=
range
(
0
,
2
,
0.1
);
vector
<
double
>
flux
=
range
(
0
,
20
,
1
);
std
::
string
error
;
// Create the ViewGenerator.
std
::
unique_ptr
<
ViewGenerator
>
generator
=
ViewGenerator
::
Create
(
time
,
flux
,
2.0
,
0.15
,
&
error
);
EXPECT_NE
(
nullptr
,
generator
);
EXPECT_TRUE
(
error
.
empty
());
vector
<
double
>
result
;
// Error: t_max <= t_min. We do not test all failure cases here since they
// are covered by the median filter's tests.
EXPECT_FALSE
(
generator
->
GenerateView
(
10
,
1
,
-
1
,
-
1
,
false
,
&
result
,
&
error
));
EXPECT_FALSE
(
error
.
empty
());
error
.
clear
();
// Global view, unnormalized.
EXPECT_TRUE
(
generator
->
GenerateView
(
10
,
0.2
,
-
1
,
1
,
false
,
&
result
,
&
error
));
EXPECT_TRUE
(
error
.
empty
());
vector
<
double
>
expected
=
{
12.5
,
14.5
,
16.5
,
18.5
,
0.5
,
2.5
,
4.5
,
6.5
,
8.5
,
10.5
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
// Global view, normalized.
EXPECT_TRUE
(
generator
->
GenerateView
(
10
,
0.2
,
-
1
,
1
,
true
,
&
result
,
&
error
));
EXPECT_TRUE
(
error
.
empty
());
expected
=
{
3.0
/
9
,
5.0
/
9
,
7.0
/
9
,
9.0
/
9
,
-
9.0
/
9
,
-
7.0
/
9
,
-
5.0
/
9
,
-
3.0
/
9
,
-
1.0
/
9
,
1.0
/
9
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
// Local view, unnormalized.
EXPECT_TRUE
(
generator
->
GenerateView
(
5
,
0.2
,
-
0.5
,
0.5
,
false
,
&
result
,
&
error
));
EXPECT_TRUE
(
error
.
empty
());
expected
=
{
17.5
,
9.5
,
1.5
,
3.5
,
5.5
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
// Local view, normalized.
EXPECT_TRUE
(
generator
->
GenerateView
(
5
,
0.2
,
-
0.5
,
0.5
,
true
,
&
result
,
&
error
));
EXPECT_TRUE
(
error
.
empty
());
expected
=
{
3
,
1
,
-
1
,
-
0.5
,
0
};
EXPECT_THAT
(
result
,
Pointwise
(
DoubleNear
(),
expected
));
}
}
// namespace
}
// namespace astronet
research/astronet/light_curve/kepler_io.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for reading Kepler data."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os.path
from
astropy.io
import
fits
import
numpy
as
np
from
light_curve
import
util
from
tensorflow
import
gfile
# Quarter index to filename prefix for long cadence Kepler data.
# Reference: https://archive.stsci.edu/kepler/software/get_kepler.py
LONG_CADENCE_QUARTER_PREFIXES
=
{
0
:
[
"2009131105131"
],
1
:
[
"2009166043257"
],
2
:
[
"2009259160929"
],
3
:
[
"2009350155506"
],
4
:
[
"2010078095331"
,
"2010009091648"
],
5
:
[
"2010174085026"
],
6
:
[
"2010265121752"
],
7
:
[
"2010355172524"
],
8
:
[
"2011073133259"
],
9
:
[
"2011177032512"
],
10
:
[
"2011271113734"
],
11
:
[
"2012004120508"
],
12
:
[
"2012088054726"
],
13
:
[
"2012179063303"
],
14
:
[
"2012277125453"
],
15
:
[
"2013011073258"
],
16
:
[
"2013098041711"
],
17
:
[
"2013131215648"
]
}
# Quarter index to filename prefix for short cadence Kepler data.
# Reference: https://archive.stsci.edu/kepler/software/get_kepler.py
SHORT_CADENCE_QUARTER_PREFIXES
=
{
0
:
[
"2009131110544"
],
1
:
[
"2009166044711"
],
2
:
[
"2009201121230"
,
"2009231120729"
,
"2009259162342"
],
3
:
[
"2009291181958"
,
"2009322144938"
,
"2009350160919"
],
4
:
[
"2010009094841"
,
"2010019161129"
,
"2010049094358"
,
"2010078100744"
],
5
:
[
"2010111051353"
,
"2010140023957"
,
"2010174090439"
],
6
:
[
"2010203174610"
,
"2010234115140"
,
"2010265121752"
],
7
:
[
"2010296114515"
,
"2010326094124"
,
"2010355172524"
],
8
:
[
"2011024051157"
,
"2011053090032"
,
"2011073133259"
],
9
:
[
"2011116030358"
,
"2011145075126"
,
"2011177032512"
],
10
:
[
"2011208035123"
,
"2011240104155"
,
"2011271113734"
],
11
:
[
"2011303113607"
,
"2011334093404"
,
"2012004120508"
],
12
:
[
"2012032013838"
,
"2012060035710"
,
"2012088054726"
],
13
:
[
"2012121044856"
,
"2012151031540"
,
"2012179063303"
],
14
:
[
"2012211050319"
,
"2012242122129"
,
"2012277125453"
],
15
:
[
"2012310112549"
,
"2012341132017"
,
"2013011073258"
],
16
:
[
"2013017113907"
,
"2013065031647"
,
"2013098041711"
],
17
:
[
"2013121191144"
,
"2013131215648"
]
}
# Quarter order for different scrambling procedures.
# Page 9: https://ntrs.nasa.gov/archive/nasa/casi.ntrs.nasa.gov/20170009549.pdf.
SIMULATED_DATA_SCRAMBLE_ORDERS
=
{
"SCR1"
:
[
0
,
13
,
14
,
15
,
16
,
9
,
10
,
11
,
12
,
5
,
6
,
7
,
8
,
1
,
2
,
3
,
4
,
17
],
"SCR2"
:
[
0
,
1
,
2
,
3
,
4
,
13
,
14
,
15
,
16
,
9
,
10
,
11
,
12
,
5
,
6
,
7
,
8
,
17
],
"SCR3"
:
[
0
,
16
,
15
,
14
,
13
,
12
,
11
,
10
,
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
17
],
}
def
kepler_filenames
(
base_dir
,
kep_id
,
long_cadence
=
True
,
quarters
=
None
,
injected_group
=
None
,
check_existence
=
True
):
"""Returns the light curve filenames for a Kepler target star.
This function assumes the directory structure of the Mikulski Archive for
Space Telescopes (http://archive.stsci.edu/pub/kepler/lightcurves).
Specifically, the filenames for a particular Kepler target star have the
following format:
${kep_id:0:4}/${kep_id}/kplr${kep_id}-${quarter_prefix}_${type}.fits,
where:
kep_id is the Kepler id left-padded with zeros to length 9;
quarter_prefix is the filename quarter prefix;
type is one of "llc" (long cadence light curve) or "slc" (short cadence
light curve).
Args:
base_dir: Base directory containing Kepler data.
kep_id: Id of the Kepler target star. May be an int or a possibly zero-
padded string.
long_cadence: Whether to read a long cadence (~29.4 min / measurement) light
curve as opposed to a short cadence (~1 min / measurement) light curve.
quarters: Optional list of integers in [0, 17]; the quarters of the Kepler
mission to return.
injected_group: Optional string indicating injected light curves. One of
"inj1", "inj2", "inj3".
check_existence: If True, only return filenames corresponding to files that
exist (not all stars have data for all quarters).
Returns:
A list of filenames.
"""
# Pad the Kepler id with zeros to length 9.
kep_id
=
"{:09d}"
.
format
(
int
(
kep_id
))
quarter_prefixes
,
cadence_suffix
=
((
LONG_CADENCE_QUARTER_PREFIXES
,
"llc"
)
if
long_cadence
else
(
SHORT_CADENCE_QUARTER_PREFIXES
,
"slc"
))
if
quarters
is
None
:
quarters
=
quarter_prefixes
.
keys
()
quarters
=
sorted
(
quarters
)
# Sort quarters chronologically.
filenames
=
[]
base_dir
=
os
.
path
.
join
(
base_dir
,
kep_id
[
0
:
4
],
kep_id
)
for
quarter
in
quarters
:
for
quarter_prefix
in
quarter_prefixes
[
quarter
]:
if
injected_group
:
base_name
=
"kplr{}-{}_INJECTED-{}_{}.fits"
.
format
(
kep_id
,
quarter_prefix
,
injected_group
,
cadence_suffix
)
else
:
base_name
=
"kplr{}-{}_{}.fits"
.
format
(
kep_id
,
quarter_prefix
,
cadence_suffix
)
filename
=
os
.
path
.
join
(
base_dir
,
base_name
)
# Not all stars have data for all quarters.
if
not
check_existence
or
gfile
.
Exists
(
filename
):
filenames
.
append
(
filename
)
return
filenames
def
scramble_light_curve
(
all_time
,
all_flux
,
all_quarters
,
scramble_type
):
"""Scrambles a light curve according to a given scrambling procedure.
Args:
all_time: List holding arrays of time values, each containing a quarter of
time data.
all_flux: List holding arrays of flux values, each containing a quarter of
flux data.
all_quarters: List of integers specifying which quarters are present in
the light curve (max is 18: Q0...Q17).
scramble_type: String specifying the scramble order, one of {'SCR1', 'SCR2',
'SCR3'}.
Returns:
scr_flux: Scrambled flux values; the same list as the input flux in another
order.
scr_time: Time values, re-partitioned to match sizes of the scr_flux lists.
"""
order
=
SIMULATED_DATA_SCRAMBLE_ORDERS
[
scramble_type
]
scr_flux
=
[]
for
quarter
in
order
:
# Ignore missing quarters in the scramble order.
if
quarter
in
all_quarters
:
scr_flux
.
append
(
all_flux
[
all_quarters
.
index
(
quarter
)])
scr_time
=
util
.
reshard_arrays
(
all_time
,
scr_flux
)
return
scr_time
,
scr_flux
def
read_kepler_light_curve
(
filenames
,
light_curve_extension
=
"LIGHTCURVE"
,
scramble_type
=
None
,
interpolate_missing_time
=
False
):
"""Reads time and flux measurements for a Kepler target star.
Args:
filenames: A list of .fits files containing time and flux measurements.
light_curve_extension: Name of the HDU 1 extension containing light curves.
scramble_type: What scrambling procedure to use: 'SCR1', 'SCR2', or 'SCR3'
(pg 9: https://exoplanetarchive.ipac.caltech.edu/docs/KSCI-19114-002.pdf).
interpolate_missing_time: Whether to interpolate missing (NaN) time values.
This should only affect the output if scramble_type is specified (NaN time
values typically come with NaN flux values, which are removed anyway, but
scrambing decouples NaN time values from NaN flux values).
Returns:
all_time: A list of numpy arrays; the time values of the light curve.
all_flux: A list of numpy arrays; the flux values of the light curve.
"""
all_time
=
[]
all_flux
=
[]
all_quarters
=
[]
for
filename
in
filenames
:
with
fits
.
open
(
gfile
.
Open
(
filename
,
"rb"
))
as
hdu_list
:
quarter
=
hdu_list
[
"PRIMARY"
].
header
[
"QUARTER"
]
light_curve
=
hdu_list
[
light_curve_extension
].
data
time
=
light_curve
.
TIME
flux
=
light_curve
.
PDCSAP_FLUX
if
not
time
.
size
:
continue
# No data.
# Possibly interpolate missing time values.
if
interpolate_missing_time
:
time
=
util
.
interpolate_missing_time
(
time
,
light_curve
.
CADENCENO
)
all_time
.
append
(
time
)
all_flux
.
append
(
flux
)
all_quarters
.
append
(
quarter
)
if
scramble_type
:
all_time
,
all_flux
=
scramble_light_curve
(
all_time
,
all_flux
,
all_quarters
,
scramble_type
)
# Remove timestamps with NaN time or flux values.
for
i
,
(
time
,
flux
)
in
enumerate
(
zip
(
all_time
,
all_flux
)):
flux_and_time_finite
=
np
.
logical_and
(
np
.
isfinite
(
flux
),
np
.
isfinite
(
time
))
all_time
[
i
]
=
time
[
flux_and_time_finite
]
all_flux
[
i
]
=
flux
[
flux_and_time_finite
]
return
all_time
,
all_flux
research/astronet/light_curve/kepler_io_test.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kepler_io.py."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os.path
from
absl
import
flags
from
absl.testing
import
absltest
import
numpy
as
np
from
light_curve
import
kepler_io
FLAGS
=
flags
.
FLAGS
_DATA_DIR
=
"light_curve/test_data/"
class
KeplerIoTest
(
absltest
.
TestCase
):
def
setUp
(
self
):
super
(
KeplerIoTest
,
self
).
setUp
()
self
.
data_dir
=
os
.
path
.
join
(
FLAGS
.
test_srcdir
,
_DATA_DIR
)
def
testScrambleLightCurve
(
self
):
all_flux
=
[[
11
,
12
],
[
21
],
[
np
.
nan
,
np
.
nan
,
33
],
[
41
,
42
]]
all_time
=
[[
101
,
102
],
[
201
],
[
301
,
302
,
303
],
[
401
,
402
]]
all_quarters
=
[
3
,
4
,
7
,
14
]
scramble_type
=
"SCR1"
# New quarters order will be [14,7,3,4].
scr_time
,
scr_flux
=
kepler_io
.
scramble_light_curve
(
all_time
,
all_flux
,
all_quarters
,
scramble_type
)
# NaNs are not removed in this function.
gold_flux
=
[[
41
,
42
],
[
np
.
nan
,
np
.
nan
,
33
],
[
11
,
12
],
[
21
]]
gold_time
=
[[
101
,
102
],
[
201
,
301
,
302
],
[
303
,
401
],
[
402
]]
self
.
assertLen
(
gold_flux
,
len
(
scr_flux
))
self
.
assertLen
(
gold_time
,
len
(
scr_time
))
for
i
in
range
(
len
(
gold_flux
)):
np
.
testing
.
assert_array_equal
(
gold_flux
[
i
],
scr_flux
[
i
])
np
.
testing
.
assert_array_equal
(
gold_time
[
i
],
scr_time
[
i
])
def
testKeplerFilenames
(
self
):
# All quarters.
filenames
=
kepler_io
.
kepler_filenames
(
"/my/dir/"
,
1234567
,
check_existence
=
False
)
self
.
assertCountEqual
([
"/my/dir/0012/001234567/kplr001234567-2009131105131_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2009166043257_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2009259160929_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2009350155506_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2010078095331_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2010009091648_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2010174085026_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2010265121752_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2010355172524_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2011073133259_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2011177032512_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2011271113734_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2012004120508_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2012088054726_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2012179063303_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2012277125453_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2013011073258_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2013098041711_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2013131215648_llc.fits"
],
filenames
)
# Subset of quarters.
filenames
=
kepler_io
.
kepler_filenames
(
"/my/dir/"
,
1234567
,
quarters
=
[
3
,
4
],
check_existence
=
False
)
self
.
assertCountEqual
([
"/my/dir/0012/001234567/kplr001234567-2009350155506_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2010078095331_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2010009091648_llc.fits"
],
filenames
)
# Injected group.
filenames
=
kepler_io
.
kepler_filenames
(
"/my/dir/"
,
1234567
,
quarters
=
[
3
,
4
],
injected_group
=
"inj1"
,
check_existence
=
False
)
# pylint:disable=line-too-long
self
.
assertCountEqual
([
"/my/dir/0012/001234567/kplr001234567-2009350155506_INJECTED-inj1_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2010078095331_INJECTED-inj1_llc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2010009091648_INJECTED-inj1_llc.fits"
],
filenames
)
# pylint:enable=line-too-long
# Short cadence.
filenames
=
kepler_io
.
kepler_filenames
(
"/my/dir/"
,
1234567
,
long_cadence
=
False
,
quarters
=
[
0
,
1
],
check_existence
=
False
)
self
.
assertCountEqual
([
"/my/dir/0012/001234567/kplr001234567-2009131110544_slc.fits"
,
"/my/dir/0012/001234567/kplr001234567-2009166044711_slc.fits"
],
filenames
)
# Check existence.
filenames
=
kepler_io
.
kepler_filenames
(
self
.
data_dir
,
11442793
,
check_existence
=
True
)
expected_filenames
=
[
os
.
path
.
join
(
self
.
data_dir
,
"0114/011442793/kplr011442793-{}_llc.fits"
.
format
(
q
))
for
q
in
[
"2009350155506"
,
"2010009091648"
,
"2010174085026"
]
]
self
.
assertCountEqual
(
expected_filenames
,
filenames
)
def
testReadKeplerLightCurve
(
self
):
filenames
=
[
os
.
path
.
join
(
self
.
data_dir
,
"0114/011442793/kplr011442793-{}_llc.fits"
.
format
(
q
))
for
q
in
[
"2009350155506"
,
"2010009091648"
,
"2010174085026"
]
]
all_time
,
all_flux
=
kepler_io
.
read_kepler_light_curve
(
filenames
)
self
.
assertLen
(
all_time
,
3
)
self
.
assertLen
(
all_flux
,
3
)
self
.
assertLen
(
all_time
[
0
],
4134
)
self
.
assertLen
(
all_flux
[
0
],
4134
)
self
.
assertLen
(
all_time
[
1
],
1008
)
self
.
assertLen
(
all_flux
[
1
],
1008
)
self
.
assertLen
(
all_time
[
2
],
4486
)
self
.
assertLen
(
all_flux
[
2
],
4486
)
for
time
,
flux
in
zip
(
all_time
,
all_flux
):
self
.
assertTrue
(
np
.
isfinite
(
time
).
all
())
self
.
assertTrue
(
np
.
isfinite
(
flux
).
all
())
def
testReadKeplerLightCurveScrambled
(
self
):
filenames
=
[
os
.
path
.
join
(
self
.
data_dir
,
"0114/011442793/kplr011442793-{}_llc.fits"
.
format
(
q
))
for
q
in
[
"2009350155506"
,
"2010009091648"
,
"2010174085026"
]
]
all_time
,
all_flux
=
kepler_io
.
read_kepler_light_curve
(
filenames
,
scramble_type
=
"SCR1"
)
self
.
assertLen
(
all_time
,
3
)
self
.
assertLen
(
all_flux
,
3
)
# Arrays are shorter than above due to separation of time and flux NaNs.
self
.
assertLen
(
all_time
[
0
],
4344
)
self
.
assertLen
(
all_flux
[
0
],
4344
)
self
.
assertLen
(
all_time
[
1
],
4041
)
self
.
assertLen
(
all_flux
[
1
],
4041
)
self
.
assertLen
(
all_time
[
2
],
1008
)
self
.
assertLen
(
all_flux
[
2
],
1008
)
for
time
,
flux
in
zip
(
all_time
,
all_flux
):
self
.
assertTrue
(
np
.
isfinite
(
time
).
all
())
self
.
assertTrue
(
np
.
isfinite
(
flux
).
all
())
def
testReadKeplerLightCurveScrambledInterpolateMissingTime
(
self
):
filenames
=
[
os
.
path
.
join
(
self
.
data_dir
,
"0114/011442793/kplr011442793-{}_llc.fits"
.
format
(
q
))
for
q
in
[
"2009350155506"
,
"2010009091648"
,
"2010174085026"
]
]
all_time
,
all_flux
=
kepler_io
.
read_kepler_light_curve
(
filenames
,
scramble_type
=
"SCR1"
,
interpolate_missing_time
=
True
)
self
.
assertLen
(
all_time
,
3
)
self
.
assertLen
(
all_flux
,
3
)
self
.
assertLen
(
all_time
[
0
],
4486
)
self
.
assertLen
(
all_flux
[
0
],
4486
)
self
.
assertLen
(
all_time
[
1
],
4134
)
self
.
assertLen
(
all_flux
[
1
],
4134
)
self
.
assertLen
(
all_time
[
2
],
1008
)
self
.
assertLen
(
all_flux
[
2
],
1008
)
for
time
,
flux
in
zip
(
all_time
,
all_flux
):
self
.
assertTrue
(
np
.
isfinite
(
time
).
all
())
self
.
assertTrue
(
np
.
isfinite
(
flux
).
all
())
if
__name__
==
"__main__"
:
FLAGS
.
test_srcdir
=
""
absltest
.
main
()
research/astronet/light_curve/median_filter.py
deleted
100644 → 0
View file @
17c2f0cc
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility function for smoothing data using a median filter."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
def
median_filter
(
x
,
y
,
num_bins
,
bin_width
=
None
,
x_min
=
None
,
x_max
=
None
):
"""Computes the median y-value in uniform intervals (bins) along the x-axis.
The interval [x_min, x_max) is divided into num_bins uniformly spaced
intervals of width bin_width. The value computed for each bin is the median
of all y-values whose corresponding x-value is in the interval.
NOTE: x must be sorted in ascending order or the results will be incorrect.
Args:
x: 1D array of x-coordinates sorted in ascending order. Must have at least 2
elements, and all elements cannot be the same value.
y: 1D array of y-coordinates with the same size as x.
num_bins: The number of intervals to divide the x-axis into. Must be at
least 2.
bin_width: The width of each bin on the x-axis. Must be positive, and less
than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
x_min: The inclusive leftmost value to consider on the x-axis. Must be less
than or equal to the largest value of x. Defaults to min(x).
x_max: The exclusive rightmost value to consider on the x-axis. Must be
greater than x_min. Defaults to max(x).
Returns:
1D NumPy array of size num_bins containing the median y-values of uniformly
spaced bins on the x-axis.
Raises:
ValueError: If an argument has an inappropriate value.
"""
if
num_bins
<
2
:
raise
ValueError
(
"num_bins must be at least 2. Got: {}"
.
format
(
num_bins
))
# Validate the lengths of x and y.
x_len
=
len
(
x
)
if
x_len
<
2
:
raise
ValueError
(
"len(x) must be at least 2. Got: {}"
.
format
(
x_len
))
if
x_len
!=
len
(
y
):
raise
ValueError
(
"len(x) (got: {}) must equal len(y) (got: {})"
.
format
(
x_len
,
len
(
y
)))
# Validate x_min and x_max.
x_min
=
x_min
if
x_min
is
not
None
else
x
[
0
]
x_max
=
x_max
if
x_max
is
not
None
else
x
[
-
1
]
if
x_min
>=
x_max
:
raise
ValueError
(
"x_min (got: {}) must be less than x_max (got: {})"
.
format
(
x_min
,
x_max
))
if
x_min
>
x
[
-
1
]:
raise
ValueError
(
"x_min (got: {}) must be less than or equal to the largest value of x "
"(got: {})"
.
format
(
x_min
,
x
[
-
1
]))
# Validate bin_width.
bin_width
=
bin_width
if
bin_width
is
not
None
else
(
x_max
-
x_min
)
/
num_bins
if
bin_width
<=
0
:
raise
ValueError
(
"bin_width must be positive. Got: {}"
.
format
(
bin_width
))
if
bin_width
>=
x_max
-
x_min
:
raise
ValueError
(
"bin_width (got: {}) must be less than x_max - x_min (got: {})"
.
format
(
bin_width
,
x_max
-
x_min
))
bin_spacing
=
(
x_max
-
x_min
-
bin_width
)
/
(
num_bins
-
1
)
# Bins with no y-values will fall back to the global median.
result
=
np
.
repeat
(
np
.
median
(
y
),
num_bins
)
# Find the first element of x >= x_min. This loop is guaranteed to produce
# a valid index because we know that x_min <= x[-1].
x_start
=
0
while
x
[
x_start
]
<
x_min
:
x_start
+=
1
# The bin at index i is the median of all elements y[j] such that
# bin_min <= x[j] < bin_max, where bin_min and bin_max are the endpoints of
# bin i.
bin_min
=
x_min
# Left endpoint of the current bin.
bin_max
=
x_min
+
bin_width
# Right endpoint of the current bin.
j_start
=
x_start
# Inclusive left index of the current bin.
j_end
=
x_start
# Exclusive end index of the current bin.
for
i
in
range
(
num_bins
):
# Move j_start to the first index of x >= bin_min.
while
j_start
<
x_len
and
x
[
j_start
]
<
bin_min
:
j_start
+=
1
# Move j_end to the first index of x >= bin_max (exclusive end index).
while
j_end
<
x_len
and
x
[
j_end
]
<
bin_max
:
j_end
+=
1
if
j_end
>
j_start
:
# Compute and insert the median bin value.
result
[
i
]
=
np
.
median
(
y
[
j_start
:
j_end
])
# Advance the bin.
bin_min
+=
bin_spacing
bin_max
+=
bin_spacing
return
result
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment