Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
252e2d2e
Commit
252e2d2e
authored
Sep 21, 2018
by
Chris Shallue
Committed by
Christopher Shallue
Oct 16, 2018
Browse files
Update formatting.
PiperOrigin-RevId: 213963765
parent
9d0f41b7
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
56 additions
and
65 deletions
+56
-65
research/astronet/astronet/astro_cnn_model/astro_cnn_model_test.py
...astronet/astronet/astro_cnn_model/astro_cnn_model_test.py
+1
-2
research/astronet/astronet/astro_fc_model/astro_fc_model_test.py
...h/astronet/astronet/astro_fc_model/astro_fc_model_test.py
+1
-2
research/astronet/astronet/astro_model/astro_model.py
research/astronet/astronet/astro_model/astro_model.py
+10
-9
research/astronet/astronet/astro_model/astro_model_test.py
research/astronet/astronet/astro_model/astro_model_test.py
+1
-2
research/astronet/astronet/data/generate_input_records.py
research/astronet/astronet/data/generate_input_records.py
+4
-5
research/astronet/astronet/data/preprocess.py
research/astronet/astronet/data/preprocess.py
+4
-4
research/astronet/astronet/models.py
research/astronet/astronet/models.py
+1
-1
research/astronet/astronet/ops/dataset_ops.py
research/astronet/astronet/ops/dataset_ops.py
+6
-6
research/astronet/astronet/ops/input_ops.py
research/astronet/astronet/ops/input_ops.py
+2
-3
research/astronet/astronet/ops/input_ops_test.py
research/astronet/astronet/ops/input_ops_test.py
+2
-2
research/astronet/astronet/ops/testing.py
research/astronet/astronet/ops/testing.py
+4
-9
research/astronet/astronet/ops/training.py
research/astronet/astronet/ops/training.py
+1
-1
research/astronet/astronet/util/config_util.py
research/astronet/astronet/util/config_util.py
+1
-1
research/astronet/astronet/util/estimator_util.py
research/astronet/astronet/util/estimator_util.py
+3
-3
research/astronet/astronet/util/example_util.py
research/astronet/astronet/util/example_util.py
+2
-2
research/astronet/light_curve_util/cc/python/median_filter_test.py
...astronet/light_curve_util/cc/python/median_filter_test.py
+1
-1
research/astronet/light_curve_util/cc/python/phase_fold_test.py
...ch/astronet/light_curve_util/cc/python/phase_fold_test.py
+1
-1
research/astronet/light_curve_util/cc/python/view_generator_test.py
...stronet/light_curve_util/cc/python/view_generator_test.py
+1
-1
research/astronet/light_curve_util/kepler_io.py
research/astronet/light_curve_util/kepler_io.py
+5
-5
research/astronet/light_curve_util/median_filter.py
research/astronet/light_curve_util/median_filter.py
+5
-5
No files found.
research/astronet/astronet/astro_cnn_model/astro_cnn_model_test.py
View file @
252e2d2e
...
...
@@ -35,8 +35,7 @@ class AstroCNNModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if
isinstance
(
tensor_or_array
,
(
np
.
ndarray
,
np
.
generic
)):
self
.
assertAllEqual
(
shape
,
tensor_or_array
.
shape
)
...
...
research/astronet/astronet/astro_fc_model/astro_fc_model_test.py
View file @
252e2d2e
...
...
@@ -35,8 +35,7 @@ class AstroFCModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if
isinstance
(
tensor_or_array
,
(
np
.
ndarray
,
np
.
generic
)):
self
.
assertAllEqual
(
shape
,
tensor_or_array
.
shape
)
...
...
research/astronet/astronet/astro_model/astro_model.py
View file @
252e2d2e
...
...
@@ -73,17 +73,19 @@ class AstroModel(object):
"""A TensorFlow model for classifying astrophysical light curves."""
def
__init__
(
self
,
features
,
labels
,
hparams
,
mode
):
"""Basic setup. The actual TensorFlow graph is constructed in build().
"""Basic setup.
The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "time_series_features" and
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
labels: An int64 Tensor with shape [batch_size]. May be None if mode is
tf.estimator.ModeKeys.PREDICT.
tf.estimator.ModeKeys.PREDICT.
hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction.
for training, evaluation or prediction.
Raises:
ValueError: If mode is invalid.
...
...
@@ -201,10 +203,9 @@ class AstroModel(object):
if
len
(
hidden_layers
)
==
1
:
pre_logits_concat
=
hidden_layers
[
0
][
1
]
else
:
pre_logits_concat
=
tf
.
concat
(
[
layer
[
1
]
for
layer
in
hidden_layers
],
axis
=
1
,
name
=
"pre_logits_concat"
)
pre_logits_concat
=
tf
.
concat
([
layer
[
1
]
for
layer
in
hidden_layers
],
axis
=
1
,
name
=
"pre_logits_concat"
)
net
=
pre_logits_concat
with
tf
.
variable_scope
(
"pre_logits_hidden"
):
...
...
research/astronet/astronet/astro_model/astro_model_test.py
View file @
252e2d2e
...
...
@@ -35,8 +35,7 @@ class AstroModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if
isinstance
(
tensor_or_array
,
(
np
.
ndarray
,
np
.
generic
)):
self
.
assertAllEqual
(
shape
,
tensor_or_array
.
shape
)
...
...
research/astronet/astronet/data/generate_input_records.py
View file @
252e2d2e
...
...
@@ -88,7 +88,6 @@ import tensorflow as tf
from
astronet.data
import
preprocess
parser
=
argparse
.
ArgumentParser
()
_DR24_TCE_URL
=
(
"https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/"
...
...
@@ -225,10 +224,10 @@ def main(argv):
file_shards
.
append
((
train_tces
[
start
:
end
],
filename
))
# Validation and test sets each have a single shard.
file_shards
.
append
((
val_tces
,
os
.
path
.
join
(
FLAGS
.
output_dir
,
"val-00000-of-00001"
)))
file_shards
.
append
((
test_tces
,
os
.
path
.
join
(
FLAGS
.
output_dir
,
"test-00000-of-00001"
)))
file_shards
.
append
((
val_tces
,
os
.
path
.
join
(
FLAGS
.
output_dir
,
"val-00000-of-00001"
)))
file_shards
.
append
((
test_tces
,
os
.
path
.
join
(
FLAGS
.
output_dir
,
"test-00000-of-00001"
)))
num_file_shards
=
len
(
file_shards
)
# Launch subprocesses for the file shards.
...
...
research/astronet/astronet/data/preprocess.py
View file @
252e2d2e
...
...
@@ -34,7 +34,7 @@ def read_light_curve(kepid, kepler_data_dir):
Args:
kepid: Kepler id of the target star.
kepler_data_dir: Base directory containing Kepler data. See
kepler_io.kepler_filenames().
kepler_io.kepler_filenames().
Returns:
all_time: A list of numpy arrays; the time values of the raw light curve.
...
...
@@ -59,7 +59,7 @@ def process_light_curve(all_time, all_flux):
Args:
all_time: A list of numpy arrays; the time values of the raw light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in
all_time.
all_time.
Returns:
time: 1D NumPy array; the time values of the light curve.
...
...
@@ -192,7 +192,7 @@ def local_view(time,
num_bins: The number of intervals to divide the time axis into.
bin_width_factor: Width of the bins, as a fraction of duration.
num_durations: The number of durations to consider on either side of 0 (the
event is assumed to be centered at 0).
event is assumed to be centered at 0).
Returns:
1D NumPy array of size num_bins containing the median flux values of
...
...
@@ -214,7 +214,7 @@ def generate_example_for_tce(time, flux, tce):
time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve.
tce: Dict-like object containing at least 'tce_period', 'tce_duration', and
'tce_time0bk'. Additional items are included as features in the output.
'tce_time0bk'. Additional items are included as features in the output.
Returns:
A tf.train.Example containing features 'global_view', 'local_view', and all
...
...
research/astronet/astronet/models.py
View file @
252e2d2e
...
...
@@ -57,7 +57,7 @@ def get_model_config(model_name, config_name):
Args:
model_name: Name of the model class.
config_name: Name of a configuration-builder function from the model's
configurations module.
configurations module.
Returns:
model_class: The requested model class.
...
...
research/astronet/astronet/ops/dataset_ops.py
View file @
252e2d2e
...
...
@@ -142,19 +142,19 @@ def build_dataset(file_pattern,
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
input_config: ConfigDict containing feature and label specifications.
batch_size: The number of examples per batch.
include_labels: Whether to read labels from the input files.
reverse_time_series_prob: If > 0, the time series features will be randomly
reversed with this probability. Within a given example, either all time
series features will be reversed, or none will be reversed.
reversed with this probability. Within a given example, either all time
series features will be reversed, or none will be reversed.
shuffle_filenames: Whether to shuffle the order of TFRecord files between
epochs.
epochs.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the dataset
will repeat indefinitely.
will repeat indefinitely.
use_tpu: Whether to build the dataset for TPU.
Raises:
...
...
research/astronet/astronet/ops/input_ops.py
View file @
252e2d2e
...
...
@@ -27,11 +27,10 @@ def prepare_feed_dict(model, features, labels=None, is_training=None):
Args:
model: An instance of AstroModel.
features: Dictionary containing "time_series_features" and "aux_features".
Each is a dictionary of named numpy arrays of shape
[batch_size, length].
Each is a dictionary of named numpy arrays of shape [batch_size, length].
labels: (Optional). Numpy array of shape [batch_size].
is_training: (Optional). Python boolean to feed to the model.is_training
Tensor (if None, no value is fed).
Tensor (if None, no value is fed).
Returns:
feed_dict: A dictionary of input Tensor to numpy array.
...
...
research/astronet/astronet/ops/input_ops_test.py
View file @
252e2d2e
...
...
@@ -31,9 +31,9 @@ class InputOpsTest(tf.test.TestCase):
Args:
expected_shapes: Dictionary of expected Tensor shapes, as lists,
corresponding to the structure of 'features'.
corresponding to the structure of 'features'.
features: Dictionary of feature placeholders of the format returned by
input_ops.build_feature_placeholders().
input_ops.build_feature_placeholders().
"""
actual_shapes
=
{}
for
feature_type
in
features
:
...
...
research/astronet/astronet/ops/testing.py
View file @
252e2d2e
...
...
@@ -47,15 +47,10 @@ def fake_features(feature_spec, batch_size):
Dictionary containing "time_series_features" and "aux_features". Each is a
dictionary of named numpy arrays of shape [batch_size, length].
"""
features
=
{}
features
[
"time_series_features"
]
=
{
name
:
np
.
random
.
random
([
batch_size
,
spec
[
"length"
]])
for
name
,
spec
in
feature_spec
.
items
()
if
spec
[
"is_time_series"
]
}
features
[
"aux_features"
]
=
{
name
:
np
.
random
.
random
([
batch_size
,
spec
[
"length"
]])
for
name
,
spec
in
feature_spec
.
items
()
if
not
spec
[
"is_time_series"
]
}
features
=
{
"time_series_features"
:
{},
"aux_features"
:
{}}
for
name
,
spec
in
feature_spec
.
items
():
ftype
=
"time_series_features"
if
spec
[
"is_time_series"
]
else
"aux_features"
features
[
ftype
][
name
]
=
np
.
random
.
random
([
batch_size
,
spec
[
"length"
]])
return
features
...
...
research/astronet/astronet/ops/training.py
View file @
252e2d2e
...
...
@@ -51,7 +51,7 @@ def create_optimizer(hparams, learning_rate, use_tpu=False):
hparams: ConfigDict containing the optimizer configuration.
learning_rate: A Python float or a scalar Tensor.
use_tpu: If True, the returned optimizer is wrapped in a
CrossShardOptimizer.
CrossShardOptimizer.
Returns:
A TensorFlow optimizer.
...
...
research/astronet/astronet/util/config_util.py
View file @
252e2d2e
...
...
@@ -105,7 +105,7 @@ def unflatten(flat_config):
Args:
flat_config: A dictionary with strings as keys where nested configuration
parameters are represented with period-separated names.
parameters are represented with period-separated names.
Returns:
A dictionary nested according to the keys of the input dictionary.
...
...
research/astronet/astronet/util/estimator_util.py
View file @
252e2d2e
...
...
@@ -203,10 +203,10 @@ def create_estimator(model_class,
hparams: ConfigDict of configuration parameters for building the model.
run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
model_dir: Optional directory for saving the model. If not passed
explicitly, it must be specified in run_config.
explicitly, it must be specified in run_config.
eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
Returns:
An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
...
...
research/astronet/astronet/util/example_util.py
View file @
252e2d2e
...
...
@@ -28,7 +28,7 @@ def get_feature(ex, name, kind=None, strict=True):
ex: A tf.train.Example.
name: Name of the feature to look up.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
not specified.
strict: Whether to raise a KeyError if there is no such feature.
Returns:
...
...
@@ -93,7 +93,7 @@ def set_feature(ex,
name: Name of the feature to set.
value: Feature value to set. Must be a sequence.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
not specified.
allow_overwrite: Whether to overwrite the existing value of the feature.
bytes_encoding: Codec for encoding strings when kind = 'bytes_list'.
...
...
research/astronet/light_curve_util/cc/python/median_filter_test.py
View file @
252e2d2e
...
...
@@ -44,5 +44,5 @@ class MedianFilterTest(absltest.TestCase):
np
.
testing
.
assert_almost_equal
(
result
,
expected
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
absltest
.
main
()
research/astronet/light_curve_util/cc/python/phase_fold_test.py
View file @
252e2d2e
...
...
@@ -66,5 +66,5 @@ class PhaseFoldAndSortLightCurveTest(absltest.TestCase):
np
.
testing
.
assert_almost_equal
(
folded_flux
,
expected_flux
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
absltest
.
main
()
research/astronet/light_curve_util/cc/python/view_generator_test.py
View file @
252e2d2e
...
...
@@ -76,5 +76,5 @@ class ViewGeneratorTest(absltest.TestCase):
np
.
testing
.
assert_almost_equal
(
result
,
expected
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
absltest
.
main
()
research/astronet/light_curve_util/kepler_io.py
View file @
252e2d2e
...
...
@@ -105,15 +105,15 @@ def kepler_filenames(base_dir,
Args:
base_dir: Base directory containing Kepler data.
kep_id: Id of the Kepler target star. May be an int or a possibly zero-
padded string.
padded string.
long_cadence: Whether to read a long cadence (~29.4 min / measurement) light
curve as opposed to a short cadence (~1 min / measurement) light curve.
curve as opposed to a short cadence (~1 min / measurement) light curve.
quarters: Optional list of integers in [0, 17]; the quarters of the Kepler
mission to return.
mission to return.
injected_group: Optional string indicating injected light curves. One of
"inj1", "inj2", "inj3".
"inj1", "inj2", "inj3".
check_existence: If True, only return filenames corresponding to files that
exist (not all stars have data for all quarters).
exist (not all stars have data for all quarters).
Returns:
A list of filenames.
...
...
research/astronet/light_curve_util/median_filter.py
View file @
252e2d2e
...
...
@@ -32,16 +32,16 @@ def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None):
Args:
x: 1D array of x-coordinates sorted in ascending order. Must have at least 2
elements, and all elements cannot be the same value.
elements, and all elements cannot be the same value.
y: 1D array of y-coordinates with the same size as x.
num_bins: The number of intervals to divide the x-axis into. Must be at
least 2.
least 2.
bin_width: The width of each bin on the x-axis. Must be positive, and less
than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
x_min: The inclusive leftmost value to consider on the x-axis. Must be less
than or equal to the largest value of x. Defaults to min(x).
than or equal to the largest value of x. Defaults to min(x).
x_max: The exclusive rightmost value to consider on the x-axis. Must be
greater than x_min. Defaults to max(x).
greater than x_min. Defaults to max(x).
Returns:
1D NumPy array of size num_bins containing the median y-values of uniformly
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment