ModelZoo / ResNet50_tensorflow · Commit 9d0f41b7

Replace '%' string formatting with .format().

PiperOrigin-RevId: 213353962

Authored Sep 17, 2018 by Chris Shallue; committed by Christopher Shallue on Oct 16, 2018.
Parent: 313d0c41
Showing 17 changed files with 76 additions and 101 deletions (+76, -101).
research/astronet/astronet/astro_cnn_model/astro_cnn_model.py    +2   -20
research/astronet/astronet/astro_fc_model/astro_fc_model.py      +3   -21
research/astronet/astronet/astro_model/astro_model.py            +2   -2
research/astronet/astronet/data/generate_input_records.py        +5   -3
research/astronet/astronet/data/preprocess.py                    +2   -2
research/astronet/astronet/models.py                             +5   -4
research/astronet/astronet/ops/dataset_ops.py                    +5   -5
research/astronet/astronet/ops/training.py                       +1   -1
research/astronet/astronet/util/config_util.py                   +6   -5
research/astronet/astronet/util/estimator_util.py                +2   -2
research/astronet/astronet/util/example_util.py                  +5   -3
research/astronet/light_curve_util/cc/python/postproc.py         +2   -1
research/astronet/light_curve_util/kepler_io.py                  +5   -6
research/astronet/light_curve_util/kepler_io_test.py             +12  -8
research/astronet/light_curve_util/median_filter.py              +11  -11
research/astronet/light_curve_util/util.py                       +2   -2
research/astronet/third_party/kepler_spline/kepler_spline.py     +6   -5
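All 17 files follow the same mechanical conversion: printf-style '%' interpolation is replaced by the equivalent str.format() call. As a quick illustration (not part of the commit), the two styles produce identical strings for the simple placeholders used here:

    # Illustration only (not from the commit): '%' interpolation vs. str.format().
    model_name = "AstroCNNModel"
    old_style = "Unrecognized model name: %s" % model_name
    new_style = "Unrecognized model name: {}".format(model_name)
    assert old_style == new_style  # both yield 'Unrecognized model name: AstroCNNModel'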
research/astronet/astronet/astro_cnn_model/astro_cnn_model.py  (+2, -20)

@@ -54,24 +54,6 @@ from astronet.astro_model import astro_model
 class AstroCNNModel(astro_model.AstroModel):
   """A model for classifying light curves using a convolutional neural net."""
 
-  def __init__(self, features, labels, hparams, mode):
-    """Basic setup. The actual TensorFlow graph is constructed in build().
-
-    Args:
-      features: A dictionary containing "time_series_features" and
-        "aux_features", each of which is a dictionary of named input Tensors.
-        All features have dtype float32 and shape [batch_size, length].
-      labels: An int64 Tensor with shape [batch_size]. May be None if mode is
-        tf.estimator.ModeKeys.PREDICT.
-      hparams: A ConfigDict of hyperparameters for building the model.
-      mode: A tf.estimator.ModeKeys to specify whether the graph should be built
-        for training, evaluation or prediction.
-
-    Raises:
-      ValueError: If mode is invalid.
-    """
-    super(AstroCNNModel, self).__init__(features, labels, hparams, mode)
-
   def _build_cnn_layers(self, inputs, hparams, scope="cnn"):
     """Builds convolutional layers.

@@ -95,7 +77,7 @@ class AstroCNNModel(astro_model.AstroModel):
       for i in range(hparams.cnn_num_blocks):
         num_filters = int(hparams.cnn_initial_num_filters *
                           hparams.cnn_block_filter_factor**i)
-        with tf.variable_scope("block_%d" % (i + 1)):
+        with tf.variable_scope("block_{}".format(i + 1)):
           for j in range(hparams.cnn_block_size):
             net = tf.layers.conv1d(
                 inputs=net,

@@ -103,7 +85,7 @@ class AstroCNNModel(astro_model.AstroModel):
                 kernel_size=int(hparams.cnn_kernel_size),
                 padding=hparams.convolution_padding,
                 activation=tf.nn.relu,
-                name="conv_%d" % (j + 1))
+                name="conv_{}".format(j + 1))
           if hparams.pool_size > 1:  # pool_size 0 or 1 denotes no pooling
             net = tf.layers.max_pooling1d(
research/astronet/astronet/astro_fc_model/astro_fc_model.py  (+3, -21)

@@ -58,24 +58,6 @@ from astronet.astro_model import astro_model
 class AstroFCModel(astro_model.AstroModel):
   """A model for classifying light curves using fully connected layers."""
 
-  def __init__(self, features, labels, hparams, mode):
-    """Basic setup. The actual TensorFlow graph is constructed in build().
-
-    Args:
-      features: A dictionary containing "time_series_features" and
-        "aux_features", each of which is a dictionary of named input Tensors.
-        All features have dtype float32 and shape [batch_size, length].
-      labels: An int64 Tensor with shape [batch_size]. May be None if mode is
-        tf.estimator.ModeKeys.PREDICT.
-      hparams: A ConfigDict of hyperparameters for building the model.
-      mode: A tf.estimator.ModeKeys to specify whether the graph should be built
-        for training, evaluation or prediction.
-
-    Raises:
-      ValueError: If mode is invalid.
-    """
-    super(AstroFCModel, self).__init__(features, labels, hparams, mode)
-
   def _build_local_fc_layers(self, inputs, hparams, scope):
     """Builds locally fully connected layers.

@@ -120,8 +102,8 @@ class AstroFCModel(astro_model.AstroModel):
       elif hparams.pooling_type == "avg":
         net = tf.reduce_mean(net, axis=1, name="avg_pool")
       else:
-        raise ValueError("Unrecognized pooling_type: %s" %
-                         hparams.pooling_type)
+        raise ValueError(
+            "Unrecognized pooling_type: {}".format(hparams.pooling_type))
       remaining_layers = hparams.num_local_layers - 1
     else:

@@ -133,7 +115,7 @@ class AstroFCModel(astro_model.AstroModel):
           inputs=net,
           num_outputs=hparams.local_layer_size,
           activation_fn=tf.nn.relu,
-          scope="fully_connected_%d" % (i + 1))
+          scope="fully_connected_{}".format(i + 1))
 
       if hparams.dropout_rate > 0:
         net = tf.layers.dropout(
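Besides the string-formatting changes, the two model files above also drop AstroCNNModel.__init__ and AstroFCModel.__init__, which did nothing but forward their arguments to the base class. A minimal sketch (hypothetical classes, not from the repo) of why such a pass-through override is redundant:

    # Sketch with hypothetical classes: a subclass with no __init__ of its own
    # automatically uses the parent's __init__, so a forwarding override adds nothing.
    class Base(object):
      def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    class Child(Base):  # no __init__ needed; Base.__init__ is inherited
      pass

    c = Child(features={"x": 1}, labels=[0])
    assert c.labels == [0]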
research/astronet/astronet/astro_model/astro_model.py  (+2, -2)

@@ -93,7 +93,7 @@ class AstroModel(object):
         tf.estimator.ModeKeys.PREDICT
     ]
     if mode not in valid_modes:
-      raise ValueError("Expected mode in %s. Got: %s" % (valid_modes, mode))
+      raise ValueError("Expected mode in {}. Got: {}".format(valid_modes, mode))
 
     self.hparams = hparams
     self.mode = mode

@@ -213,7 +213,7 @@ class AstroModel(object):
           inputs=net,
           units=self.hparams.pre_logits_hidden_layer_size,
           activation=tf.nn.relu,
-          name="fully_connected_%s" % (i + 1))
+          name="fully_connected_{}".format(i + 1))
 
       if self.hparams.pre_logits_dropout_rate > 0:
         net = tf.layers.dropout(
research/astronet/astronet/data/generate_input_records.py  (+5, -3)

@@ -100,7 +100,7 @@ parser.add_argument(
     required=True,
     help="CSV file containing the Q1-Q17 DR24 Kepler TCE table. Must contain "
     "columns: rowid, kepid, tce_plnt_num, tce_period, tce_duration, "
-    "tce_time0bk. Download from: %s" % _DR24_TCE_URL)
+    "tce_time0bk. Download from: {}".format(_DR24_TCE_URL))
 
 parser.add_argument(
     "--kepler_data_dir",

@@ -219,8 +219,10 @@ def main(argv):
   for i in range(FLAGS.num_train_shards):
     start = boundaries[i]
     end = boundaries[i + 1]
-    file_shards.append((train_tces[start:end], os.path.join(
-        FLAGS.output_dir, "train-%.5d-of-%.5d" % (i, FLAGS.num_train_shards))))
+    filename = os.path.join(
+        FLAGS.output_dir,
+        "train-{:05d}-of-{:05d}".format(i, FLAGS.num_train_shards))
+    file_shards.append((train_tces[start:end], filename))
 
   # Validation and test sets each have a single shard.
   file_shards.append((val_tces, os.path.join(FLAGS.output_dir,
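The shard-naming hunk above also switches the width specifier: '%.5d' pads an integer with zeros to five digits, and '{:05d}' is the str.format() equivalent. A standalone illustration (values invented):

    # Illustration only: zero-padded shard names with '%' vs. str.format().
    i, num_shards = 3, 8
    assert "train-%.5d-of-%.5d" % (i, num_shards) == "train-00003-of-00008"
    assert "train-{:05d}-of-{:05d}".format(i, num_shards) == "train-00003-of-00008"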
research/astronet/astronet/data/preprocess.py  (+2, -2)

@@ -47,8 +47,8 @@ def read_light_curve(kepid, kepler_data_dir):
   # Read the Kepler light curve.
   file_names = kepler_io.kepler_filenames(kepler_data_dir, kepid)
   if not file_names:
-    raise IOError("Failed to find .fits files in %s for Kepler ID %s" %
-                  (kepler_data_dir, kepid))
+    raise IOError("Failed to find .fits files in {} for Kepler ID {}".format(
+        kepler_data_dir, kepid))
 
   return kepler_io.read_kepler_light_curve(file_names)
research/astronet/astronet/models.py  (+5, -4)

@@ -46,7 +46,7 @@ def get_model_class(model_name):
     ValueError: If model_name is unrecognized.
   """
   if model_name not in _MODELS:
-    raise ValueError("Unrecognized model name: %s" % model_name)
+    raise ValueError("Unrecognized model name: {}".format(model_name))
 
   return _MODELS[model_name][0]

@@ -67,11 +67,12 @@ def get_model_config(model_name, config_name):
     ValueError: If model_name or config_name is unrecognized.
   """
   if model_name not in _MODELS:
-    raise ValueError("Unrecognized model name: %s" % model_name)
+    raise ValueError("Unrecognized model name: {}".format(model_name))
 
   config_module = _MODELS[model_name][1]
   try:
     return getattr(config_module, config_name)()
   except AttributeError:
-    raise ValueError("Config name '%s' not found in configuration module: %s"
-                     % (config_name, config_module.__name__))
+    raise ValueError(
+        "Config name '{}' not found in configuration module: {}".format(
+            config_name, config_module.__name__))
research/astronet/astronet/ops/dataset_ops.py  (+5, -5)

@@ -69,7 +69,7 @@ def _recursive_pad_to_batch_size(tensor_or_collection, batch_size):
         for t in tensor_or_collection]
 
-  raise ValueError("Unknown input type: %s" % tensor_or_collection)
+  raise ValueError("Unknown input type: {}".format(tensor_or_collection))
 
 
 def pad_dataset_to_batch_size(dataset, batch_size):

@@ -119,7 +119,7 @@ def _recursive_set_batch_size(tensor_or_collection, batch_size):
     for t in tensor_or_collection:
       _recursive_set_batch_size(t, batch_size)
   else:
-    raise ValueError("Unknown input type: %s" % tensor_or_collection)
+    raise ValueError("Unknown input type: {}".format(tensor_or_collection))
 
   return tensor_or_collection

@@ -170,7 +170,7 @@ def build_dataset(file_pattern,
   for p in file_patterns:
     matches = tf.gfile.Glob(p)
     if not matches:
-      raise ValueError("Found no input files matching %s" % p)
+      raise ValueError("Found no input files matching {}".format(p))
     filenames.extend(matches)
 
   tf.logging.info("Building input pipeline from %d files matching patterns: %s",
                   len(filenames), file_patterns)

@@ -180,8 +180,8 @@ def build_dataset(file_pattern,
   label_ids = set(input_config.label_map.values())
   if label_ids != set(range(len(label_ids))):
     raise ValueError(
-        "Label IDs must be contiguous integers starting at 0. Got: %s" %
-        label_ids)
+        "Label IDs must be contiguous integers starting at 0. Got: {}".format(
+            label_ids))
 
   # Create a HashTable mapping label strings to integer ids.
   table_initializer = tf.contrib.lookup.KeyValueTensorInitializer(
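Note that the tf.logging.info call in the third hunk keeps its %d/%s placeholders: logging-style APIs take the template and the arguments separately and interpolate only if the message is actually emitted, so they are typically left in %-style rather than converted to .format(). A hedged sketch of the distinction, shown with the standard logging module for illustration:

    # Sketch of the distinction (standard logging shown for illustration).
    import logging

    n_files, patterns = 12, ["/data/*.tfrecord"]

    # Deferred interpolation: the logger formats only if the record is emitted.
    logging.info("Building input pipeline from %d files matching patterns: %s",
                 n_files, patterns)

    # Eager interpolation: the string is built before the call, even if filtered out.
    logging.info("Building input pipeline from {} files matching patterns: {}".format(
        n_files, patterns))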
research/astronet/astronet/ops/training.py  (+1, -1)

@@ -74,7 +74,7 @@ def create_optimizer(hparams, learning_rate, use_tpu=False):
   elif optimizer_name == "rmsprop":
     optimizer = tf.RMSPropOptimizer(learning_rate)
   else:
-    raise ValueError("Unknown optimizer: %s" % hparams.optimizer)
+    raise ValueError("Unknown optimizer: {}".format(hparams.optimizer))
 
   if use_tpu:
     optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
research/astronet/astronet/util/config_util.py  (+6, -5)

@@ -49,14 +49,15 @@ def parse_json(json_string_or_file):
     with tf.gfile.Open(json_string_or_file) as f:
       json_dict = json.load(f)
   except ValueError as json_file_parsing_error:
-    raise ValueError("Unable to parse the content of the json file %s. "
-                     "Parsing error: %s." % (json_string_or_file,
-                                             json_file_parsing_error.message))
+    raise ValueError("Unable to parse the content of the json file {}. "
+                     "Parsing error: {}.".format(
+                         json_string_or_file, json_file_parsing_error.message))
   except tf.gfile.FileError:
     message = ("Unable to parse the input parameter neither as literal "
                "JSON nor as the name of a file that exists.\n"
-               "JSON parsing error: %s\n\n Input parameter:\n %s." %
-               (literal_json_parsing_error.message, json_string_or_file))
+               "JSON parsing error: {}\n\n Input parameter:\n {}.".format(
+                   literal_json_parsing_error.message,
+                   json_string_or_file))
     raise ValueError(message)
 
   return json_dict
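In multi-line messages like the ones above, adjacent string literals are concatenated at compile time, so a single .format() call attached to the last fragment formats the whole concatenated message. A small illustration (values invented):

    # Illustration only: implicit concatenation happens first, then .format()
    # applies to the single combined string.
    path, err = "config.json", "Expecting ',' delimiter"
    msg = ("Unable to parse the content of the json file {}. "
           "Parsing error: {}.".format(path, err))
    assert msg == ("Unable to parse the content of the json file config.json. "
                   "Parsing error: Expecting ',' delimiter.")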
research/astronet/astronet/util/estimator_util.py  (+2, -2)

@@ -132,8 +132,8 @@ class _ModelFn(object):
     if "labels" in features:
       if labels is not None and labels is not features["labels"]:
         raise ValueError(
-            "Conflicting labels: features['labels'] = %s, labels = %s" %
-            (features["labels"], labels))
+            "Conflicting labels: features['labels'] = {}, labels = {}".format(
+                features["labels"], labels))
       labels = features.pop("labels")
 
     model = self._model_class(features, labels, hparams, mode)
research/astronet/astronet/util/example_util.py  (+5, -3)

@@ -48,7 +48,8 @@ def get_feature(ex, name, kind=None, strict=True):
     return np.array([])  # Feature exists, but it's empty.
 
   if kind and kind != inferred_kind:
-    raise TypeError("Requested %s, but Feature has %s" % (kind, inferred_kind))
+    raise TypeError(
+        "Requested {}, but Feature has {}".format(kind, inferred_kind))
 
   return np.array(getattr(ex.features.feature[name], inferred_kind).value)

@@ -105,7 +106,8 @@ def set_feature(ex,
       del ex.features.feature[name]
     else:
       raise ValueError(
-          "Attempting to set duplicate feature with name: %s" % name)
+          "Attempting to overwrite feature with name: {}. "
+          "Set allow_overwrite=True if this is desired.".format(name))
 
   if not kind:
     kind = _infer_kind(value)

@@ -117,7 +119,7 @@ def set_feature(ex,
   elif kind == "int64_list":
     value = [int(v) for v in value]
   else:
-    raise ValueError("Unrecognized kind: %s" % kind)
+    raise ValueError("Unrecognized kind: {}".format(kind))
 
   getattr(ex.features.feature[name], kind).value.extend(value)
research/astronet/light_curve_util/cc/python/postproc.py  (+2, -1)

@@ -24,7 +24,8 @@ def ValueErrorOnFalse(ok, *output_args):
   """Raises ValueError if not ok, otherwise returns the output arguments."""
   n_outputs = len(output_args)
   if n_outputs < 2:
-    raise ValueError("Expected 2 or more output_args. Got: %d" % n_outputs)
+    raise ValueError(
+        "Expected 2 or more output_args. Got: {}".format(n_outputs))
 
   if not ok:
     error = output_args[-1]
research/astronet/light_curve_util/kepler_io.py  (+5, -6)

@@ -119,7 +119,7 @@ def kepler_filenames(base_dir,
     A list of filenames.
   """
   # Pad the Kepler id with zeros to length 9.
-  kep_id = "%.9d" % int(kep_id)
+  kep_id = "{:09d}".format(int(kep_id))
 
   quarter_prefixes, cadence_suffix = ((LONG_CADENCE_QUARTER_PREFIXES, "llc")
                                       if long_cadence else

@@ -135,12 +135,11 @@ def kepler_filenames(base_dir,
   for quarter in quarters:
     for quarter_prefix in quarter_prefixes[quarter]:
       if injected_group:
-        base_name = "kplr%s-%s_INJECTED-%s_%s.fits" % (kep_id, quarter_prefix,
-                                                       injected_group,
-                                                       cadence_suffix)
+        base_name = "kplr{}-{}_INJECTED-{}_{}.fits".format(
+            kep_id, quarter_prefix, injected_group, cadence_suffix)
       else:
-        base_name = "kplr%s-%s_%s.fits" % (kep_id, quarter_prefix,
-                                           cadence_suffix)
+        base_name = "kplr{}-{}_{}.fits".format(kep_id, quarter_prefix,
+                                               cadence_suffix)
       filename = os.path.join(base_dir, base_name)
       # Not all stars have data for all quarters.
       if not check_existence or gfile.Exists(filename):
research/astronet/light_curve_util/kepler_io_test.py  (+12, -8)

@@ -122,15 +122,17 @@ class KeplerIoTest(absltest.TestCase):
     filenames = kepler_io.kepler_filenames(
         self.data_dir, 11442793, check_existence=True)
     expected_filenames = [
-        os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
-        % q for q in ["2009350155506", "2010009091648", "2010174085026"]
+        os.path.join(self.data_dir,
+                     "0114/011442793/kplr011442793-{}_llc.fits".format(q))
+        for q in ["2009350155506", "2010009091648", "2010174085026"]
     ]
     self.assertItemsEqual(expected_filenames, filenames)
 
   def testReadKeplerLightCurve(self):
     filenames = [
-        os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
-        % q for q in ["2009350155506", "2010009091648", "2010174085026"]
+        os.path.join(self.data_dir,
+                     "0114/011442793/kplr011442793-{}_llc.fits".format(q))
+        for q in ["2009350155506", "2010009091648", "2010174085026"]
     ]
     all_time, all_flux = kepler_io.read_kepler_light_curve(filenames)
     self.assertLen(all_time, 3)

@@ -148,8 +150,9 @@ class KeplerIoTest(absltest.TestCase):
   def testReadKeplerLightCurveScrambled(self):
     filenames = [
-        os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
-        % q for q in ["2009350155506", "2010009091648", "2010174085026"]
+        os.path.join(self.data_dir,
+                     "0114/011442793/kplr011442793-{}_llc.fits".format(q))
+        for q in ["2009350155506", "2010009091648", "2010174085026"]
     ]
     all_time, all_flux = kepler_io.read_kepler_light_curve(
         filenames, scramble_type="SCR1")

@@ -170,8 +173,9 @@ class KeplerIoTest(absltest.TestCase):
   def testReadKeplerLightCurveScrambledInterpolateMissingTime(self):
     filenames = [
-        os.path.join(self.data_dir, "0114/011442793/kplr011442793-%s_llc.fits")
-        % q for q in ["2009350155506", "2010009091648", "2010174085026"]
+        os.path.join(self.data_dir,
+                     "0114/011442793/kplr011442793-{}_llc.fits".format(q))
+        for q in ["2009350155506", "2010009091648", "2010174085026"]
     ]
     all_time, all_flux = kepler_io.read_kepler_light_curve(
         filenames, scramble_type="SCR1", interpolate_missing_time=True)
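In the test fixtures above, the substitution now happens inside os.path.join (the quarter stamp is formatted into the relative template before joining) instead of applying '%' to the joined path; since the directory prefix contains no placeholders, the resulting paths are unchanged. An isolated sketch (directory name invented):

    # Sketch only: formatting the relative template before joining vs. formatting
    # the joined path afterwards; both give the same result here.
    import os

    data_dir = "/tmp/kepler_test_data"
    q = "2009350155506"
    old = os.path.join(data_dir, "0114/011442793/kplr011442793-%s_llc.fits") % q
    new = os.path.join(data_dir, "0114/011442793/kplr011442793-{}_llc.fits".format(q))
    assert old == new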
research/astronet/light_curve_util/median_filter.py  (+11, -11)

@@ -51,35 +51,35 @@ def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None):
     ValueError: If an argument has an inappropriate value.
   """
   if num_bins < 2:
-    raise ValueError("num_bins must be at least 2. Got: %d" % num_bins)
+    raise ValueError("num_bins must be at least 2. Got: {}".format(num_bins))
 
   # Validate the lengths of x and y.
   x_len = len(x)
   if x_len < 2:
-    raise ValueError("len(x) must be at least 2. Got: %s" % x_len)
+    raise ValueError("len(x) must be at least 2. Got: {}".format(x_len))
 
   if x_len != len(y):
-    raise ValueError("len(x) (got: %d) must equal len(y) (got: %d)" %
-                     (x_len, len(y)))
+    raise ValueError("len(x) (got: {}) must equal len(y) (got: {})".format(
+        x_len, len(y)))
 
   # Validate x_min and x_max.
   x_min = x_min if x_min is not None else x[0]
   x_max = x_max if x_max is not None else x[-1]
   if x_min >= x_max:
-    raise ValueError("x_min (got: %d) must be less than x_max (got: %d)" %
-                     (x_min, x_max))
+    raise ValueError("x_min (got: {}) must be less than x_max (got: {})".format(
+        x_min, x_max))
   if x_min > x[-1]:
     raise ValueError(
-        "x_min (got: %d) must be less than or equal to the largest value of x "
-        "(got: %d)" % (x_min, x[-1]))
+        "x_min (got: {}) must be less than or equal to the largest value of x "
+        "(got: {})".format(x_min, x[-1]))
 
   # Validate bin_width.
   bin_width = bin_width if bin_width is not None else (x_max - x_min) / num_bins
   if bin_width <= 0:
-    raise ValueError("bin_width must be positive. Got: %d" % bin_width)
+    raise ValueError("bin_width must be positive. Got: {}".format(bin_width))
   if bin_width >= x_max - x_min:
     raise ValueError(
-        "bin_width (got: %d) must be less than x_max - x_min (got: %d)" %
-        (bin_width, x_max - x_min))
+        "bin_width (got: {}) must be less than x_max - x_min (got: {})".format(
+            bin_width, x_max - x_min))
 
   bin_spacing = (x_max - x_min - bin_width) / (num_bins - 1)
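One behavioral note on the hunk above: several of the old messages used %d on values such as x_min and bin_width that can be floats, and %d silently truncates them; the bare {} placeholder reports the value as-is. A quick illustration (numbers invented):

    # Illustration only: '%d' truncates float arguments, '{}' shows them unchanged.
    bin_width = 0.25
    assert "bin_width must be positive. Got: %d" % bin_width == \
        "bin_width must be positive. Got: 0"
    assert "bin_width must be positive. Got: {}".format(bin_width) == \
        "bin_width must be positive. Got: 0.25"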
research/astronet/light_curve_util/util.py  (+2, -2)

@@ -287,8 +287,8 @@ def count_transit_points(time, event):
   # Tiny periods or erroneous time values could make this loop take forever.
   if (t_max - t_min) / event.period > 10**6:
     raise ValueError(
-        "Too many transits! Time range is [%.2f, %.2f] and period is %.2e." %
-        (t_min, t_max, event.period))
+        "Too many transits! Time range is [{:.4f}, {:.4f}] and period is "
+        "{:.4e}.".format(t_min, t_max, event.period))
 
   # Make sure t0 is in [t_min, t_min + period).
   t0 = np.mod(event.t0 - t_min, event.period) + t_min
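The util.py hunk also tightens the precision: the old %.2f/%.2e specifiers become {:.4f}/{:.4e}, i.e. four digits after the decimal point in fixed and scientific notation. A standalone example (values invented):

    # Illustration only: fixed-point and scientific precision in str.format().
    t_min, t_max, period = 120.5389, 1591.0013, 3.5225e-07
    msg = ("Too many transits! Time range is [{:.4f}, {:.4f}] and period is "
           "{:.4e}.".format(t_min, t_max, period))
    assert msg == ("Too many transits! Time range is [120.5389, 1591.0013] and "
                   "period is 3.5225e-07.")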
research/astronet/third_party/kepler_spline/kepler_spline.py  (+6, -5)

@@ -54,7 +54,8 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
   """
   if len(time) < 4:
     raise InsufficientPointsError(
-        "Cannot fit a spline on less than 4 points. Got %d points." % len(time))
+        "Cannot fit a spline on less than 4 points. Got {} points.".format(
+            len(time)))
 
   # Rescale time into [0, 1].
   t_min = np.min(time)

@@ -91,7 +92,7 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
       # and we consider this a fatal error.
       raise InsufficientPointsError(
           "Cannot fit a spline on less than 4 points. After removing "
-          "outliers, got %d points." % np.sum(mask))
+          "outliers, got {} points.".format(np.sum(mask)))
 
   try:
     with warnings.catch_warnings():

@@ -106,9 +107,9 @@ def kepler_spline(time, flux, bkspace=1.5, maxiter=5, outlier_cut=3):
       spline = curve.value(time)[0]
   except (IndexError, TypeError) as e:
     raise SplineError(
-        "Fitting spline failed with error: '%s'. This might be caused by the "
+        "Fitting spline failed with error: '{}'. This might be caused by the "
         "breakpoint spacing being too small, and/or there being insufficient "
-        "points to fit the spline in one of the intervals." % e)
+        "points to fit the spline in one of the intervals.".format(e))
 
   return spline, mask

@@ -227,7 +228,7 @@ def choose_kepler_spline(all_time,
       # It's expected to get a SplineError occasionally for small values of
       # bkspace. Skip this bkspace.
       if verbose:
-        warnings.warn("Bad bkspace %.4f: %s" % (bkspace, e))
+        warnings.warn("Bad bkspace {}: {}".format(bkspace, e))
       metadata.bad_bkspaces.append(bkspace)
       bad_bkspace = True
       break