Commit 8b200ebb authored Feb 28, 2018 by Christopher Shallue

Initial AstroNet commit.

parent 0b74e527

Showing 20 of 106 changed files with 1000 additions and 0 deletions (+1000, -0)
research/astronet/astronet/util/BUILD                 +44   -0
research/astronet/astronet/util/__init__.py           +14   -0
research/astronet/astronet/util/config_util.py        +120  -0
research/astronet/astronet/util/config_util_test.py   +62   -0
research/astronet/astronet/util/configdict.py         +64   -0
research/astronet/astronet/util/configdict_test.py    +180  -0
research/astronet/astronet/util/estimator_util.py     +318  -0
research/astronet/docs/kep90-all.png                  +0    -0
research/astronet/docs/kep90-q4-normalized.png        +0    -0
research/astronet/docs/kep90-q4-raw.png               +0    -0
research/astronet/docs/kep90-q4-spline.png            +0    -0
research/astronet/docs/kep90h-localglobal.png         +0    -0
research/astronet/docs/kep90i-localglobal.png         +0    -0
research/astronet/docs/kepler-943-transits.png        +0    -0
research/astronet/docs/kepler-943.png                 +0    -0
research/astronet/docs/tensorboard.png                +0    -0
research/astronet/docs/transit.gif                    +0    -0
research/astronet/light_curve_util/BUILD              +65   -0
research/astronet/light_curve_util/__init__.py        +14   -0
research/astronet/light_curve_util/cc/BUILD           +119  -0
research/astronet/astronet/util/BUILD  0 → 100644

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_library(
    name = "configdict",
    srcs = ["configdict.py"],
    srcs_version = "PY2AND3",
    deps = [],
)

py_test(
    name = "configdict_test",
    size = "small",
    srcs = ["configdict_test.py"],
    srcs_version = "PY2AND3",
    deps = [":configdict"],
)

py_library(
    name = "config_util",
    srcs = ["config_util.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "config_util_test",
    size = "small",
    srcs = ["config_util_test.py"],
    srcs_version = "PY2AND3",
    deps = [":config_util"],
)

py_library(
    name = "estimator_util",
    srcs = ["estimator_util.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//astronet/ops:dataset_ops",
        "//astronet/ops:metrics",
        "//astronet/ops:training",
    ],
)
research/astronet/astronet/util/__init__.py  0 → 100644

# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
research/astronet/astronet/util/config_util.py  0 → 100644

# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for configurations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os.path

import tensorflow as tf


def parse_json(json_string_or_file):
  """Parses values from a JSON string or JSON file.

  This function is useful for command line flags containing configuration
  overrides. Using this function, the flag can be passed either as a JSON
  string (e.g. '{"learning_rate": 1.0}') or the path to a JSON configuration
  file.

  Args:
    json_string_or_file: A JSON serialized string OR the path to a JSON file.

  Returns:
    A dictionary; the parsed JSON.

  Raises:
    ValueError: If the JSON could not be parsed.
  """
  # First, attempt to parse the string as a JSON dict.
  try:
    json_dict = json.loads(json_string_or_file)
  except ValueError as literal_json_parsing_error:
    try:
      # Otherwise, try to use it as a path to a JSON file.
      with tf.gfile.Open(json_string_or_file) as f:
        json_dict = json.load(f)
    except ValueError as json_file_parsing_error:
      raise ValueError("Unable to parse the content of the json file %s. "
                       "Parsing error: %s." %
                       (json_string_or_file, json_file_parsing_error.message))
    except tf.gfile.FileError:
      message = ("Unable to parse the input parameter neither as literal "
                 "JSON nor as the name of a file that exists.\n"
                 "JSON parsing error: %s\n\n Input parameter:\n %s." %
                 (literal_json_parsing_error.message, json_string_or_file))
      raise ValueError(message)

  return json_dict


def log_and_save_config(config, output_dir):
  """Logs and writes a JSON-serializable configuration object.

  Args:
    config: A JSON-serializable object.
    output_dir: Destination directory.
  """
  if hasattr(config, "to_json") and callable(config.to_json):
    config_json = config.to_json(indent=2)
  else:
    config_json = json.dumps(config, indent=2)

  tf.logging.info("config: %s", config_json)

  tf.gfile.MakeDirs(output_dir)
  with tf.gfile.Open(os.path.join(output_dir, "config.json"), "w") as f:
    f.write(config_json)


def unflatten(flat_config):
  """Transforms a flat configuration dictionary into a nested dictionary.

  Example:
    {
      "a": 1,
      "b.c": 2,
      "b.d.e": 3,
      "b.d.f": 4,
    }
  would be transformed to:
    {
      "a": 1,
      "b": {
        "c": 2,
        "d": {
          "e": 3,
          "f": 4,
        }
      }
    }

  Args:
    flat_config: A dictionary with strings as keys where nested configuration
      parameters are represented with period-separated names.

  Returns:
    A dictionary nested according to the keys of the input dictionary.
  """
  config = {}
  for path, value in flat_config.iteritems():
    path = path.split(".")
    final_key = path.pop()
    nested_config = config
    for key in path:
      nested_config = nested_config.setdefault(key, {})
    nested_config[final_key] = value
  return config
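
As an illustration of how these helpers compose, here is an editor's sketch (not part of the committed file; the flag value and keys below are made up):

# Illustrative sketch only; the override keys are hypothetical examples.
from astronet.util import config_util

# The flag may be a literal JSON string or the path to a JSON file.
overrides = config_util.parse_json(
    '{"learning_rate": 0.001, "model.hidden_dim": 128}')

# Period-separated keys become nested dictionaries.
nested = config_util.unflatten(overrides)
# nested == {"learning_rate": 0.001, "model": {"hidden_dim": 128}}
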
research/astronet/astronet/util/config_util_test.py  0 → 100644

# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for config_util.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from astronet.util import config_util


class ConfigUtilTest(tf.test.TestCase):

  def testUnflatten(self):
    # Empty dict.
    self.assertDictEqual(config_util.unflatten({}), {})

    # Already flat dict.
    self.assertDictEqual(
        config_util.unflatten({"a": 1, "b": 2}), {"a": 1, "b": 2})

    # Nested dict.
    self.assertDictEqual(
        config_util.unflatten({
            "a": 1,
            "b.c": 2,
            "b.d.e": 3,
            "b.d.f": 4,
        }), {
            "a": 1,
            "b": {
                "c": 2,
                "d": {
                    "e": 3,
                    "f": 4,
                }
            }
        })


if __name__ == "__main__":
  tf.test.main()
research/astronet/astronet/util/configdict.py  0 → 100644

# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration container for TensorFlow models.
A ConfigDict is simply a dict whose values can be accessed via both dot syntax
(config.key) and dict syntax (config['key']).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def _maybe_convert_dict(value):
  if isinstance(value, dict):
    return ConfigDict(value)

  return value


class ConfigDict(dict):
  """Configuration container class."""

  def __init__(self, initial_dictionary=None):
    """Creates an instance of ConfigDict.

    Args:
      initial_dictionary: Optional dictionary or ConfigDict containing initial
        parameters.
    """
    if initial_dictionary:
      for field, value in initial_dictionary.iteritems():
        initial_dictionary[field] = _maybe_convert_dict(value)
    super(ConfigDict, self).__init__(initial_dictionary)

  def __setattr__(self, attribute, value):
    self[attribute] = _maybe_convert_dict(value)

  def __getattr__(self, attribute):
    try:
      return self[attribute]
    except KeyError as e:
      raise AttributeError(e)

  def __delattr__(self, attribute):
    try:
      del self[attribute]
    except KeyError as e:
      raise AttributeError(e)

  def __setitem__(self, key, value):
    super(ConfigDict, self).__setitem__(key, _maybe_convert_dict(value))
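
The dot/dict duality described in the module docstring can be illustrated with a short sketch (illustrative only, not part of the committed file; the keys are made up):

# Illustrative sketch only; keys are hypothetical examples.
config = ConfigDict({"learning_rate": 0.001, "model": {"hidden_dim": 128}})

# Attribute access and key access are equivalent.
assert config.learning_rate == config["learning_rate"]

# Nested plain dicts are converted to ConfigDict, so chaining works.
assert config.model.hidden_dim == config["model"]["hidden_dim"] == 128

# Assigning a plain dict also converts it to a ConfigDict.
config.optimizer = {"name": "adam"}
assert config.optimizer.name == "adam"
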
research/astronet/astronet/util/configdict_test.py  0 → 100644

# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for config_util.configdict."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest

from astronet.util import configdict


class ConfigDictTest(absltest.TestCase):

  def setUp(self):
    super(ConfigDictTest, self).setUp()
    self._config = configdict.ConfigDict({
        "int": 1,
        "float": 2.0,
        "bool": True,
        "str": "hello",
        "nested": {
            "int": 3,
        },
        "double_nested": {
            "a": {
                "int": 3,
            },
            "b": {
                "float": 4.0,
            }
        }
    })

  def testAccess(self):
    # Simple types.
    self.assertEqual(1, self._config.int)
    self.assertEqual(1, self._config["int"])
    self.assertEqual(2.0, self._config.float)
    self.assertEqual(2.0, self._config["float"])
    self.assertTrue(self._config.bool)
    self.assertTrue(self._config["bool"])
    self.assertEqual("hello", self._config.str)
    self.assertEqual("hello", self._config["str"])

    # Single nested config.
    self.assertEqual(3, self._config.nested.int)
    self.assertEqual(3, self._config["nested"].int)
    self.assertEqual(3, self._config.nested["int"])
    self.assertEqual(3, self._config["nested"]["int"])

    # Double nested config.
    self.assertEqual(3, self._config["double_nested"].a.int)
    self.assertEqual(3, self._config["double_nested"]["a"].int)
    self.assertEqual(3, self._config["double_nested"].a["int"])
    self.assertEqual(3, self._config["double_nested"]["a"]["int"])
    self.assertEqual(4.0, self._config.double_nested.b.float)
    self.assertEqual(4.0, self._config.double_nested["b"].float)
    self.assertEqual(4.0, self._config.double_nested.b["float"])
    self.assertEqual(4.0, self._config.double_nested["b"]["float"])

    # Nonexistent parameters.
    with self.assertRaises(AttributeError):
      _ = self._config.nonexistent
    with self.assertRaises(KeyError):
      _ = self._config["nonexistent"]

  def testSetAttribute(self):
    # Overwrite existing simple type.
    self._config.int = 40
    self.assertEqual(40, self._config.int)

    # Overwrite existing nested simple type.
    self._config.nested.int = 40
    self.assertEqual(40, self._config.nested.int)

    # Overwrite existing nested config.
    self._config.double_nested.a = {"float": 50.0}
    self.assertIsInstance(self._config.double_nested.a, configdict.ConfigDict)
    self.assertEqual(50.0, self._config.double_nested.a.float)
    self.assertNotIn("int", self._config.double_nested.a)

    # Set new simple type.
    self._config.int_2 = 10
    self.assertEqual(10, self._config.int_2)

    # Set new nested simple type.
    self._config.nested.int_2 = 20
    self.assertEqual(20, self._config.nested.int_2)

    # Set new nested config.
    self._config.double_nested.c = {"int": 30}
    self.assertIsInstance(self._config.double_nested.c, configdict.ConfigDict)
    self.assertEqual(30, self._config.double_nested.c.int)

  def testSetItem(self):
    # Overwrite existing simple type.
    self._config["int"] = 40
    self.assertEqual(40, self._config.int)

    # Overwrite existing nested simple type.
    self._config["nested"].int = 40
    self.assertEqual(40, self._config.nested.int)
    self._config.nested["int"] = 50
    self.assertEqual(50, self._config.nested.int)

    # Overwrite existing nested config.
    self._config.double_nested["a"] = {"float": 50.0}
    self.assertIsInstance(self._config.double_nested.a, configdict.ConfigDict)
    self.assertEqual(50.0, self._config.double_nested.a.float)
    self.assertNotIn("int", self._config.double_nested.a)

    # Set new simple type.
    self._config["int_2"] = 10
    self.assertEqual(10, self._config.int_2)

    # Set new nested simple type.
    self._config.nested["int_2"] = 20
    self.assertEqual(20, self._config.nested.int_2)
    self._config.nested["int_3"] = 30
    self.assertEqual(30, self._config.nested.int_3)

    # Set new nested config.
    self._config.double_nested["c"] = {"int": 30}
    self.assertIsInstance(self._config.double_nested.c, configdict.ConfigDict)
    self.assertEqual(30, self._config.double_nested.c.int)

  def testDelete(self):
    # Simple types.
    self.assertEqual(1, self._config.int)
    del self._config.int
    with self.assertRaises(AttributeError):
      _ = self._config.int
    with self.assertRaises(KeyError):
      _ = self._config["int"]

    self.assertEqual(2.0, self._config["float"])
    del self._config["float"]
    with self.assertRaises(AttributeError):
      _ = self._config.float
    with self.assertRaises(KeyError):
      _ = self._config["float"]

    # Nested config.
    self.assertEqual(3, self._config.nested.int)
    del self._config.nested
    with self.assertRaises(AttributeError):
      _ = self._config.nested
    with self.assertRaises(KeyError):
      _ = self._config["nested"]


if __name__ == "__main__":
  absltest.main()
research/astronet/astronet/util/estimator_util.py  0 → 100644

# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for training models with the TensorFlow Estimator API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import tensorflow as tf

from astronet.ops import dataset_ops
from astronet.ops import metrics
from astronet.ops import training


def create_input_fn(file_pattern,
                    input_config,
                    mode,
                    shuffle_values_buffer=0,
                    repeat=1):
  """Creates an input_fn that reads a dataset from sharded TFRecord files.

  Args:
    file_pattern: File pattern matching input TFRecord files, e.g.
      "/tmp/train-?????-of-00100". May also be a comma-separated list of file
      patterns.
    input_config: ConfigDict containing feature and label specifications.
    mode: A tf.estimator.ModeKeys.
    shuffle_values_buffer: If > 0, shuffle examples using a buffer of this
      size.
    repeat: The number of times to repeat the dataset. If None or -1 the
      elements will be repeated indefinitely.

  Returns:
    A callable that builds an input pipeline and returns (features, labels).
  """
  include_labels = (
      mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL])
  reverse_time_series_prob = 0.5 if mode == tf.estimator.ModeKeys.TRAIN else 0
  shuffle_filenames = (mode == tf.estimator.ModeKeys.TRAIN)

  def input_fn(config, params):
    """Builds an input pipeline that reads a dataset from TFRecord files."""
    # Infer whether this input_fn was called by Estimator or TPUEstimator using
    # the config type.
    use_tpu = isinstance(config, tf.contrib.tpu.RunConfig)

    dataset = dataset_ops.build_dataset(
        file_pattern=file_pattern,
        input_config=input_config,
        batch_size=params["batch_size"],
        include_labels=include_labels,
        reverse_time_series_prob=reverse_time_series_prob,
        shuffle_filenames=shuffle_filenames,
        shuffle_values_buffer=shuffle_values_buffer,
        repeat=repeat,
        use_tpu=use_tpu)

    # We must use an initializable iterator, rather than a one-shot iterator,
    # because the input pipeline contains a stateful table that requires
    # initialization. We add the initializer to the TABLE_INITIALIZERS
    # collection to ensure it is run during initialization.
    iterator = dataset.make_initializable_iterator()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)

    inputs = iterator.get_next()
    return inputs, inputs.pop("labels", None)

  return input_fn


def create_model_fn(model_class, hparams, use_tpu=False):
  """Wraps model_class as an Estimator or TPUEstimator model_fn.

  Args:
    model_class: AstroModel or a subclass.
    hparams: ConfigDict of configuration parameters for building the model.
    use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
      Estimator model_fn is returned.

  Returns:
    model_fn: A callable that constructs the model and returns a
      TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
  """
  hparams = copy.deepcopy(hparams)

  def model_fn(features, labels, mode, params):
    """Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
    # For TPUEstimator, params contains the batch size per TPU core.
    if "batch_size" in params:
      hparams.batch_size = params["batch_size"]

    model = model_class(features, labels, hparams, mode)
    model.build()

    # Possibly create train_op.
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      learning_rate = training.create_learning_rate(hparams, model.global_step)
      optimizer = training.create_optimizer(hparams, learning_rate, use_tpu)
      train_op = training.create_train_op(model, optimizer)

    # Possibly create evaluation metrics.
    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
      eval_metrics = (
          metrics.create_metric_fn(model)
          if use_tpu else metrics.create_metrics(model))

    if use_tpu:
      estimator = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          predictions=model.predictions,
          loss=model.total_loss,
          train_op=train_op,
          eval_metrics=eval_metrics)
    else:
      estimator = tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=model.predictions,
          loss=model.total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metrics)

    return estimator

  return model_fn


def create_estimator(model_class,
                     hparams,
                     run_config=None,
                     model_dir=None,
                     eval_batch_size=None):
  """Wraps model_class as an Estimator or TPUEstimator.

  If run_config is None or a tf.estimator.RunConfig, an Estimator is returned.
  If run_config is a tf.contrib.tpu.RunConfig, a TPUEstimator is returned.

  Args:
    model_class: AstroModel or a subclass.
    hparams: ConfigDict of configuration parameters for building the model.
    run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
    model_dir: Optional directory for saving the model. If not passed
      explicitly, it must be specified in run_config.
    eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
      if run_config is a tf.contrib.tpu.RunConfig. Defaults to
      hparams.batch_size.

  Returns:
    An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
    TPUEstimator object if run_config is a tf.contrib.tpu.RunConfig.

  Raises:
    ValueError:
      If model_dir is not passed explicitly or in run_config.model_dir, or if
      eval_batch_size is specified and run_config is not a
      tf.contrib.tpu.RunConfig.
  """
  if run_config is None:
    run_config = tf.estimator.RunConfig()
  else:
    run_config = copy.deepcopy(run_config)

  if not model_dir and not run_config.model_dir:
    raise ValueError(
        "model_dir must be passed explicitly or specified in run_config")

  use_tpu = isinstance(run_config, tf.contrib.tpu.RunConfig)
  model_fn = create_model_fn(model_class, hparams, use_tpu)

  if use_tpu:
    eval_batch_size = eval_batch_size or hparams.batch_size
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=run_config,
        train_batch_size=hparams.batch_size,
        eval_batch_size=eval_batch_size)
  else:
    if eval_batch_size is not None:
      raise ValueError("eval_batch_size can only be specified for TPU.")
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=run_config,
        params={"batch_size": hparams.batch_size})

  return estimator


def evaluate(estimator, input_fn, eval_steps=None, eval_name="val"):
  """Runs evaluation on the latest model checkpoint.

  Args:
    estimator: Instance of tf.Estimator.
    input_fn: Input function returning a tuple (features, labels).
    eval_steps: The number of steps for which to evaluate the model. If None,
      evaluates until input_fn raises an end-of-input exception.
    eval_name: Name of the evaluation set, e.g. "train" or "val".

  Returns:
    A dict of metric values from the evaluation. May be empty, e.g. if the
    training job has not yet saved a checkpoint or the checkpoint is deleted by
    the time the TPU worker initializes.
  """
  values = {}  # Default return value if evaluation fails.

  latest_checkpoint = tf.train.latest_checkpoint(estimator.model_dir)
  if not latest_checkpoint:
    # This is expected if the training job has not yet saved a checkpoint.
    return values

  tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
  try:
    values = estimator.evaluate(input_fn, steps=eval_steps, name=eval_name)
  except tf.errors.NotFoundError:
    # Expected under some conditions, e.g. TPU worker does not finish
    # initializing until long after the CPU job tells it to start evaluating
    # and the checkpoint file is deleted already.
    tf.logging.info("Checkpoint %s no longer exists, skipping evaluation",
                    latest_checkpoint)

  return values


def continuous_eval(estimator,
                    input_fn,
                    train_steps=None,
                    eval_steps=None,
                    eval_name="val"):
  """Runs evaluation whenever there's a new checkpoint.

  Args:
    estimator: Instance of tf.Estimator.
    input_fn: Input function returning a tuple (features, labels).
    train_steps: The number of steps the model will train for. This function
      will terminate once the model has finished training. If None, this
      function will run forever.
    eval_steps: The number of steps for which to evaluate the model. If None,
      evaluates until input_fn raises an end-of-input exception.
    eval_name: Name of the evaluation set, e.g. "train" or "val".

  Yields:
    A dict of metric values from each evaluation. May be empty, e.g. if the
    training job has not yet saved a checkpoint or the checkpoint is deleted by
    the time the TPU worker initializes.
  """
  for _ in tf.contrib.training.checkpoints_iterator(estimator.model_dir):
    values = evaluate(estimator, input_fn, eval_steps, eval_name)
    yield values

    global_step = values.get("global_step", 0)
    if train_steps and global_step >= train_steps:
      break


def continuous_train_and_eval(estimator,
                              train_input_fn,
                              eval_input_fn,
                              local_eval_frequency=None,
                              train_hooks=None,
                              train_steps=None,
                              eval_steps=None,
                              eval_name="val"):
  """Alternates training and evaluation.

  Args:
    estimator: Instance of tf.Estimator.
    train_input_fn: Input function returning a tuple (features, labels).
    eval_input_fn: Input function returning a tuple (features, labels).
    local_eval_frequency: The number of training steps between evaluations. If
      None, trains until train_input_fn raises an end-of-input exception.
    train_hooks: List of SessionRunHook subclass instances. Used for callbacks
      inside the training call.
    train_steps: The total number of steps to train the model for.
    eval_steps: The number of steps for which to evaluate the model. If None,
      evaluates until eval_input_fn raises an end-of-input exception.
    eval_name: Name of the evaluation set, e.g. "train" or "val".

  Yields:
    A dict of metric values from each evaluation. May be empty, e.g. if the
    training job has not yet saved a checkpoint or the checkpoint is deleted by
    the time the TPU worker initializes.
  """
  while True:
    # We run evaluation before training in this loop to prevent evaluation from
    # being skipped if the process is interrupted.
    values = evaluate(estimator, eval_input_fn, eval_steps, eval_name)
    yield values

    global_step = values.get("global_step", 0)
    if train_steps and global_step >= train_steps:
      break

    # Decide how many steps before the next evaluation.
    steps = local_eval_frequency
    if train_steps:
      remaining_steps = train_steps - global_step
      steps = min(steps, remaining_steps) if steps else remaining_steps

    tf.logging.info("Starting training at global step %d", global_step)
    estimator.train(train_input_fn, hooks=train_hooks, steps=steps)
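
For orientation, a rough sketch of how these helpers might be wired together (not part of the committed file; AstroCNNModel, config, and the paths are placeholders for objects defined elsewhere in the package):

# Illustrative sketch only; AstroCNNModel, config and the file paths are
# placeholders, not names defined in this file.
train_input_fn = create_input_fn(
    file_pattern="/tmp/astronet/tfrecords/train-*",
    input_config=config.inputs,
    mode=tf.estimator.ModeKeys.TRAIN)
eval_input_fn = create_input_fn(
    file_pattern="/tmp/astronet/tfrecords/val-*",
    input_config=config.inputs,
    mode=tf.estimator.ModeKeys.EVAL)

estimator = create_estimator(
    AstroCNNModel, config.hparams, model_dir="/tmp/astronet/model")

# Alternate training and evaluation until train_steps is reached.
for eval_values in continuous_train_and_eval(
    estimator,
    train_input_fn=train_input_fn,
    eval_input_fn=eval_input_fn,
    local_eval_frequency=1000,
    train_steps=10000):
  tf.logging.info("Eval metrics: %s", eval_values)
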
Binary files added (no diff shown):

research/astronet/docs/kep90-all.png             0 → 100644   65.8 KB
research/astronet/docs/kep90-q4-normalized.png   0 → 100644   52.2 KB
research/astronet/docs/kep90-q4-raw.png          0 → 100644   50.5 KB
research/astronet/docs/kep90-q4-spline.png       0 → 100644   73.1 KB
research/astronet/docs/kep90h-localglobal.png    0 → 100644   30.7 KB
research/astronet/docs/kep90i-localglobal.png    0 → 100644   66.1 KB
research/astronet/docs/kepler-943-transits.png   0 → 100644   124 KB
research/astronet/docs/kepler-943.png            0 → 100644   156 KB
research/astronet/docs/tensorboard.png           0 → 100644   98.3 KB
research/astronet/docs/transit.gif               0 → 100644   6.52 MB
research/astronet/light_curve_util/BUILD  0 → 100644

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_library(
    name = "kepler_io",
    srcs = ["kepler_io.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "kepler_io_test",
    size = "small",
    srcs = ["kepler_io_test.py"],
    data = glob([
        "test_data/0114/011442793/kplr*.fits",
    ]),
    srcs_version = "PY2AND3",
    deps = [":kepler_io"],
)

py_library(
    name = "median_filter",
    srcs = ["median_filter.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "median_filter_test",
    size = "small",
    srcs = ["median_filter_test.py"],
    srcs_version = "PY2AND3",
    deps = [":median_filter"],
)

py_library(
    name = "periodic_event",
    srcs = ["periodic_event.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "periodic_event_test",
    size = "small",
    srcs = ["periodic_event_test.py"],
    srcs_version = "PY2AND3",
    deps = [":periodic_event"],
)

py_library(
    name = "util",
    srcs = ["util.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "util_test",
    size = "small",
    srcs = ["util_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":periodic_event",
        ":util",
    ],
)
research/astronet/light_curve_util/__init__.py  0 → 100644

# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
research/astronet/light_curve_util/cc/BUILD  0 → 100644

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

cc_library(
    name = "median",
    hdrs = ["median.h"],
)

cc_test(
    name = "median_test",
    size = "small",
    srcs = [
        "median_test.cc",
    ],
    deps = [
        ":median",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "median_filter",
    srcs = ["median_filter.cc"],
    hdrs = ["median_filter.h"],
    deps = [
        ":median",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "median_filter_test",
    size = "small",
    srcs = [
        "median_filter_test.cc",
    ],
    deps = [
        ":median_filter",
        ":test_util",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "phase_fold",
    srcs = ["phase_fold.cc"],
    hdrs = ["phase_fold.h"],
    deps = ["@com_google_absl//absl/strings"],
)

cc_test(
    name = "phase_fold_test",
    size = "small",
    srcs = [
        "phase_fold_test.cc",
    ],
    deps = [
        ":phase_fold",
        ":test_util",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "normalize",
    srcs = ["normalize.cc"],
    hdrs = ["normalize.h"],
    deps = [
        ":median",
        "@com_google_absl//absl/strings",
    ],
)

cc_test(
    name = "normalize_test",
    size = "small",
    srcs = [
        "normalize_test.cc",
    ],
    deps = [
        ":normalize",
        ":test_util",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "view_generator",
    srcs = ["view_generator.cc"],
    hdrs = ["view_generator.h"],
    deps = [
        ":median_filter",
        ":normalize",
        ":phase_fold",
        "@com_google_absl//absl/memory",
    ],
)

cc_test(
    name = "view_generator_test",
    size = "small",
    srcs = [
        "view_generator_test.cc",
    ],
    deps = [
        ":test_util",
        ":view_generator",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "test_util",
    hdrs = ["test_util.h"],
    deps = [
        "@com_google_googletest//:gtest",
    ],
)