Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
120bc915
Commit
120bc915
authored
Jun 08, 2018
by
Chris Shallue
Committed by
Christopher Shallue
Jun 14, 2018
Browse files
Handle single segments and empty arrays more elegantly.
PiperOrigin-RevId: 199862472
parent
7f6313ce
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
21 deletions
+22
-21
research/astronet/light_curve_util/util.py
research/astronet/light_curve_util/util.py
+18
-20
research/astronet/light_curve_util/util_test.py
research/astronet/light_curve_util/util_test.py
+4
-1
No files found.
research/astronet/light_curve_util/util.py
View file @
120bc915
...
@@ -18,8 +18,6 @@ from __future__ import absolute_import
...
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
collections
import
numpy
as
np
import
numpy
as
np
from
six.moves
import
range
# pylint:disable=redefined-builtin
from
six.moves
import
range
# pylint:disable=redefined-builtin
...
@@ -51,21 +49,21 @@ def split(all_time, all_flux, gap_width=0.75):
...
@@ -51,21 +49,21 @@ def split(all_time, all_flux, gap_width=0.75):
piecewise defined (e.g. split by quarter breaks or gaps in the in the data).
piecewise defined (e.g. split by quarter breaks or gaps in the in the data).
Args:
Args:
all_time: Numpy array or
list
of numpy arrays; each is a sequence of
time
all_time: Numpy array or
sequence
of numpy arrays; each is a sequence of
values.
time
values.
all_flux: Numpy array or
list
of numpy arrays; each is a sequence of
flux
all_flux: Numpy array or
sequence
of numpy arrays; each is a sequence of
values of the corresponding time array.
flux
values of the corresponding time array.
gap_width: Minimum gap size (in time units) for a split.
gap_width: Minimum gap size (in time units) for a split.
Returns:
Returns:
out_time: List of numpy arrays; the split time arrays.
out_time: List of numpy arrays; the split time arrays.
out_flux: List of numpy arrays; the split flux arrays.
out_flux: List of numpy arrays; the split flux arrays.
"""
"""
all_time
=
np
.
array
(
all_time
)
all_flux
=
np
.
array
(
all_flux
)
# Handle single-segment inputs.
# Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion
if
all_time
.
ndim
==
1
:
# to bool fails if all_time is a numpy array, and all_time.size is not defined
# if all_time is a list of numpy arrays.
if
len
(
all_time
)
>
0
and
not
isinstance
(
all_time
[
0
],
collections
.
Iterable
):
# pylint:disable=g-explicit-length-test
all_time
=
[
all_time
]
all_time
=
[
all_time
]
all_flux
=
[
all_flux
]
all_flux
=
[
all_flux
]
...
@@ -90,10 +88,10 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
...
@@ -90,10 +88,10 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
curve (e.g. one that is split by quarter breaks or gaps in the in the data).
curve (e.g. one that is split by quarter breaks or gaps in the in the data).
Args:
Args:
all_time: Numpy array or
list
of numpy arrays; each is a sequence of
time
all_time: Numpy array or
sequence
of numpy arrays; each is a sequence of
values.
time
values.
all_flux: Numpy array or
list
of numpy arrays; each is a sequence of
flux
all_flux: Numpy array or
sequence
of numpy arrays; each is a sequence of
values of the corresponding time array.
flux
values of the corresponding time array.
events: List of Event objects to remove.
events: List of Event objects to remove.
width_factor: Fractional multiplier of the duration of each event to remove.
width_factor: Fractional multiplier of the duration of each event to remove.
...
@@ -103,11 +101,11 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
...
@@ -103,11 +101,11 @@ def remove_events(all_time, all_flux, events, width_factor=1.0):
output_flux: Numpy array or list of numpy arrays; the flux arrays with
output_flux: Numpy array or list of numpy arrays; the flux arrays with
events removed.
events removed.
"""
"""
all_time
=
np
.
array
(
all_time
)
all_flux
=
np
.
array
(
all_flux
)
# Handle single-segment inputs.
# Handle single-segment inputs.
# We must use an explicit length test on all_time because implicit conversion
if
all_time
.
ndim
==
1
:
# to bool fails if all_time is a numpy array and all_time.size is not defined
# if all_time is a list of numpy arrays.
if
len
(
all_time
)
>
0
and
not
isinstance
(
all_time
[
0
],
collections
.
Iterable
):
# pylint:disable=g-explicit-length-test
all_time
=
[
all_time
]
all_time
=
[
all_time
]
all_flux
=
[
all_flux
]
all_flux
=
[
all_flux
]
single_segment
=
True
single_segment
=
True
...
@@ -150,10 +148,10 @@ def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
...
@@ -150,10 +148,10 @@ def interpolate_masked_spline(all_time, all_masked_time, all_masked_spline):
interp_spline
=
[]
interp_spline
=
[]
for
time
,
masked_time
,
masked_spline
in
zip
(
for
time
,
masked_time
,
masked_spline
in
zip
(
all_time
,
all_masked_time
,
all_masked_spline
):
all_time
,
all_masked_time
,
all_masked_spline
):
if
len
(
masked_time
)
>
0
:
# pylint:disable=g-explicit-length-test
if
masked_time
.
size
:
interp_spline
.
append
(
np
.
interp
(
time
,
masked_time
,
masked_spline
))
interp_spline
.
append
(
np
.
interp
(
time
,
masked_time
,
masked_spline
))
else
:
else
:
interp_spline
.
append
(
np
.
full_like
(
time
,
np
.
nan
))
interp_spline
.
append
(
np
.
array
([
np
.
nan
]
*
len
(
time
)
))
return
interp_spline
return
interp_spline
...
...
research/astronet/light_curve_util/util_test.py
View file @
120bc915
...
@@ -136,20 +136,23 @@ class LightCurveUtilTest(absltest.TestCase):
...
@@ -136,20 +136,23 @@ class LightCurveUtilTest(absltest.TestCase):
all_time
=
[
all_time
=
[
np
.
arange
(
0
,
10
,
dtype
=
np
.
float
),
np
.
arange
(
0
,
10
,
dtype
=
np
.
float
),
np
.
arange
(
10
,
20
,
dtype
=
np
.
float
),
np
.
arange
(
10
,
20
,
dtype
=
np
.
float
),
np
.
arange
(
20
,
30
,
dtype
=
np
.
float
),
]
]
all_masked_time
=
[
all_masked_time
=
[
np
.
array
([
0
,
1
,
2
,
3
,
8
,
9
],
dtype
=
np
.
float
),
# No 4, 5, 6, 7
np
.
array
([
0
,
1
,
2
,
3
,
8
,
9
],
dtype
=
np
.
float
),
# No 4, 5, 6, 7
np
.
array
([
10
,
11
,
12
,
13
,
14
,
15
,
16
],
dtype
=
np
.
float
),
# No 17, 18, 19
np
.
array
([
10
,
11
,
12
,
13
,
14
,
15
,
16
],
dtype
=
np
.
float
),
# No 17, 18, 19
np
.
array
([],
dtype
=
np
.
float
)
]
]
all_masked_spline
=
[
2
*
t
+
100
for
t
in
all_masked_time
]
all_masked_spline
=
[
2
*
t
+
100
for
t
in
all_masked_time
]
interp_spline
=
util
.
interpolate_masked_spline
(
all_time
,
all_masked_time
,
interp_spline
=
util
.
interpolate_masked_spline
(
all_time
,
all_masked_time
,
all_masked_spline
)
all_masked_spline
)
self
.
assertLen
(
interp_spline
,
2
)
self
.
assertLen
(
interp_spline
,
3
)
self
.
assertSequenceAlmostEqual
(
self
.
assertSequenceAlmostEqual
(
[
100
,
102
,
104
,
106
,
108
,
110
,
112
,
114
,
116
,
118
],
interp_spline
[
0
])
[
100
,
102
,
104
,
106
,
108
,
110
,
112
,
114
,
116
,
118
],
interp_spline
[
0
])
self
.
assertSequenceAlmostEqual
(
self
.
assertSequenceAlmostEqual
(
[
120
,
122
,
124
,
126
,
128
,
130
,
132
,
132
,
132
,
132
],
interp_spline
[
1
])
[
120
,
122
,
124
,
126
,
128
,
130
,
132
,
132
,
132
,
132
],
interp_spline
[
1
])
self
.
assertTrue
(
np
.
all
(
np
.
isnan
(
interp_spline
[
2
])))
def
testCountTransitPoints
(
self
):
def
testCountTransitPoints
(
self
):
time
=
np
.
concatenate
([
time
=
np
.
concatenate
([
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment