Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
650f0a3d
Commit
650f0a3d
authored
Sep 21, 2018
by
Alex Tamkin
Committed by
Christopher Shallue
Oct 16, 2018
Browse files
Module for generating synthetic light curves with periodic transit-like dips.
PiperOrigin-RevId: 214069349
parent
252e2d2e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
248 additions
and
0 deletions
+248
-0
research/astronet/astrowavenet/data/synthetic_transit_maker.py
...rch/astronet/astrowavenet/data/synthetic_transit_maker.py
+138
-0
research/astronet/astrowavenet/data/synthetic_transit_maker_test.py
...stronet/astrowavenet/data/synthetic_transit_maker_test.py
+110
-0
No files found.
research/astronet/astrowavenet/data/synthetic_transit_maker.py
0 → 100644
View file @
650f0a3d
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates synthetic light curves with periodic transit-like dips.
See class docstring below for more information.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
class
SyntheticTransitMaker
(
object
):
"""Generates synthetic light curves with periodic transit-like dips.
These light curves are generated by thresholding noisy sine waves. Each time
random_light_curve is called, a thresholded sine wave is generated by sampling
parameters uniformly from the ranges specified below.
Attributes:
period_range: A tuple of positive values specifying the range of periods the
sine waves may take.
amplitude_range: A tuple of positive values specifying the range of
amplitudes the sine waves may take.
threshold_ratio_range: A tuple of values in [0, 1) specifying the range of
thresholds as a ratio of the sine wave amplitude.
phase_range: Tuple of values specifying the range of phases the sine wave
may take as a ratio of the sampled period. E.g. a sampled phase of 0.5
would translate the sine wave by half of the period. The most common
reason to override this would be to generate light curves
deterministically (with e.g. (0,0)).
noise_sd_range: A tuple of values in [0, 1) specifying the range of
standard deviations for the Gaussian noise applied to the sine wave.
"""
def
__init__
(
self
,
period_range
=
(
0.5
,
4
),
amplitude_range
=
(
1
,
1
),
threshold_ratio_range
=
(
0
,
0.99
),
phase_range
=
(
0
,
1
),
noise_sd_range
=
(
0.1
,
0.1
)):
if
threshold_ratio_range
[
0
]
<
0
or
threshold_ratio_range
[
1
]
>=
1
:
raise
ValueError
(
"Threshold ratio range must be in [0, 1). Got: {}."
.
format
(
threshold_ratio_range
))
if
amplitude_range
[
0
]
<=
0
:
raise
ValueError
(
"Amplitude range must only contain positive numbers. Got: {}."
.
format
(
amplitude_range
))
if
period_range
[
0
]
<=
0
:
raise
ValueError
(
"Period range must only contain positive numbers. Got: {}."
.
format
(
period_range
))
if
noise_sd_range
[
0
]
<
0
:
raise
ValueError
(
"Noise standard deviation range must be nonnegative. Got: {}."
.
format
(
noise_sd_range
))
for
(
start
,
end
),
name
in
[(
period_range
,
"period"
),
(
amplitude_range
,
"amplitude"
),
(
threshold_ratio_range
,
"threshold ratio"
),
(
phase_range
,
"phase range"
),
(
noise_sd_range
,
"noise standard deviation"
)]:
if
end
<
start
:
raise
ValueError
(
"End of {} range may not be less than start. Got: ({}, {})"
.
format
(
name
,
start
,
end
))
self
.
period_range
=
period_range
self
.
amplitude_range
=
amplitude_range
self
.
threshold_ratio_range
=
threshold_ratio_range
self
.
phase_range
=
phase_range
self
.
noise_sd_range
=
noise_sd_range
def
random_light_curve
(
self
,
time
,
mask_prob
=
0
):
"""Samples parameters and generates a light curve.
Args:
time: np.array, x-values to sample from the thresholded sine wave.
mask_prob: value in [0,1], probability an individual datapoint is set to
zero
Returns:
flux: np.array, values of the masked sampled light curve corresponding to
the provided time array.
mask: np.array of ones and zeros, with zeros indicating masking at the
respective position on the flux array.
"""
period
=
np
.
random
.
uniform
(
*
self
.
period_range
)
phase
=
np
.
random
.
uniform
(
*
self
.
phase_range
)
*
period
amplitude
=
np
.
random
.
uniform
(
*
self
.
amplitude_range
)
threshold
=
np
.
random
.
uniform
(
*
self
.
threshold_ratio_range
)
*
amplitude
sin_wave
=
np
.
sin
(
time
/
period
-
phase
)
*
amplitude
flux
=
np
.
minimum
(
sin_wave
,
-
threshold
)
+
threshold
noise_sd
=
np
.
random
.
uniform
(
*
self
.
noise_sd_range
)
noise
=
np
.
random
.
normal
(
scale
=
noise_sd
,
size
=
(
len
(
time
),))
flux
+=
noise
# Array of ones and zeros, where zeros indicate masking.
mask
=
np
.
random
.
random
(
len
(
time
))
>
mask_prob
mask
=
mask
.
astype
(
np
.
float
)
return
flux
*
mask
,
mask
def
random_light_curve_generator
(
self
,
time
,
mask_prob
=
0
):
"""Returns a generator function yielding random light curves.
Args:
time: An np.array of x-values to sample from the thresholded sine wave.
mask_prob: Value in [0,1], probability an individual datapoint is set to
zero.
Returns:
A generator yielding random light curves.
"""
def
generator_fn
():
while
True
:
yield
self
.
random_light_curve
(
time
,
mask_prob
)
return
generator_fn
research/astronet/astrowavenet/data/synthetic_transit_maker_test.py
0 → 100644
View file @
650f0a3d
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for synthetic_transit_maker."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl.testing
import
absltest
import
numpy
as
np
from
astrowavenet.data
import
synthetic_transit_maker
class
SyntheticTransitMakerTest
(
absltest
.
TestCase
):
def
testBadRangesRaiseExceptions
(
self
):
# Period range cannot contain negative values.
with
self
.
assertRaisesRegexp
(
ValueError
,
'Period'
):
synthetic_transit_maker
.
SyntheticTransitMaker
(
period_range
=
(
-
1
,
10
))
# Amplitude range cannot contain negative values.
with
self
.
assertRaisesRegexp
(
ValueError
,
'Amplitude'
):
synthetic_transit_maker
.
SyntheticTransitMaker
(
amplitude_range
=
(
-
10
,
-
1
))
# Threshold ratio range must be contained in the half-open interval [0, 1).
with
self
.
assertRaisesRegexp
(
ValueError
,
'Threshold ratio'
):
synthetic_transit_maker
.
SyntheticTransitMaker
(
threshold_ratio_range
=
(
0
,
1
))
# Noise standard deviation range must only contain nonnegative values.
with
self
.
assertRaisesRegexp
(
ValueError
,
'Noise standard deviation'
):
synthetic_transit_maker
.
SyntheticTransitMaker
(
noise_sd_range
=
(
-
1
,
1
))
# End of range may not be less than start.
invalid_range
=
(
0.2
,
0.1
)
range_args
=
[
'period_range'
,
'threshold_ratio_range'
,
'amplitude_range'
,
'noise_sd_range'
,
'phase_range'
]
for
range_arg
in
range_args
:
with
self
.
assertRaisesRegexp
(
ValueError
,
'may not be less'
):
synthetic_transit_maker
.
SyntheticTransitMaker
(
**
{
range_arg
:
invalid_range
})
def
testStochasticLightCurveGeneration
(
self
):
transit_maker
=
synthetic_transit_maker
.
SyntheticTransitMaker
()
time
=
np
.
arange
(
100
)
flux
,
mask
=
transit_maker
.
random_light_curve
(
time
,
mask_prob
=
0.4
)
self
.
assertEqual
(
len
(
flux
),
100
)
self
.
assertEqual
(
len
(
mask
),
100
)
def
testDeterministicLightCurveGeneration
(
self
):
gold_flux
=
np
.
array
([
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
-
0.85099258
,
-
2.04776251
,
-
2.65829632
,
-
2.53014378
,
-
1.69530454
,
-
0.36223792
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
-
0.2110405
,
-
1.57757635
,
-
2.47528153
,
-
2.67999913
,
-
2.14061117
,
-
0.9918028
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
-
1.01475559
,
-
2.15534176
,
-
2.68282928
,
-
2.46550457
,
-
1.55763357
,
-
0.18591162
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
-
0.3870683
,
-
1.71426199
,
-
2.53849461
,
-
2.65395535
,
-
2.03181367
,
-
0.82741829
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
-
1.17380391
,
-
2.2541162
,
-
2.69666588
,
-
2.39094831
,
-
1.41330116
,
-
0.00784284
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
-
0.56063229
,
-
1.84372452
,
-
2.59152891
,
-
2.61731875
,
-
1.91465433
,
-
0.65899089
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
-
1.3275672
,
-
2.34373163
,
-
2.69975648
,
-
2.30674237
,
-
1.26282489
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
0.
,
-
0.73111006
,
-
1.9654997
,
-
2.63419424
,
-
2.5702207
,
-
1.78955328
,
-
0.48712456
])
# Use ranges containing one value for determinism.
transit_maker
=
synthetic_transit_maker
.
SyntheticTransitMaker
(
period_range
=
(
2
,
2
),
amplitude_range
=
(
3
,
3
),
threshold_ratio_range
=
(.
1
,
.
1
),
phase_range
=
(
0
,
0
),
noise_sd_range
=
(
0
,
0
))
time
=
np
.
linspace
(
0
,
100
,
100
)
flux
,
mask
=
transit_maker
.
random_light_curve
(
time
)
self
.
assertAllClose
(
flux
,
gold_flux
)
self
.
assertAllClose
(
mask
,
np
.
ones
(
100
))
def
testRandomLightCurveGenerator
(
self
):
transit_maker
=
synthetic_transit_maker
.
SyntheticTransitMaker
()
time
=
np
.
linspace
(
0
,
100
,
100
)
generator
=
transit_maker
.
random_light_curve_generator
(
time
,
mask_prob
=
0.3
)()
for
_
in
range
(
5
):
flux
,
mask
=
next
(
generator
)
self
.
assertEqual
(
len
(
flux
),
100
)
self
.
assertEqual
(
len
(
mask
),
100
)
if
__name__
==
'__main__'
:
absltest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment