Commit 763663de

Authored Oct 16, 2018 by Chris Shallue; committed by Christopher Shallue, Oct 16, 2018.

Project import generated by Copybara.

PiperOrigin-RevId: 217341274

Parent: ca2db9bd
Showing 20 changed files with 2,464 additions and 84 deletions (+2464, -84).
* research/astronet/README.md (+5, -1)
* research/astronet/astronet/util/config_util.py (+9, -4)
* research/astronet/astrowavenet/BUILD (+21, -5)
* research/astronet/astrowavenet/README.md (+44, -0)
* research/astronet/astrowavenet/__init__.py (+14, -0)
* research/astronet/astrowavenet/astrowavenet_model.py (+56, -44)
* research/astronet/astrowavenet/astrowavenet_model_test.py (+631, -0)
* research/astronet/astrowavenet/data/BUILD (+42, -0)
* research/astronet/astrowavenet/data/__init__.py (+14, -0)
* research/astronet/astrowavenet/data/base.py (+240, -0)
* research/astronet/astrowavenet/data/base_test.py (+778, -0)
* research/astronet/astrowavenet/data/kepler_light_curves.py (+50, -0)
* research/astronet/astrowavenet/data/synthetic_transit_maker.py (+3, -3)
* research/astronet/astrowavenet/data/synthetic_transit_maker_test.py (+8, -8)
* research/astronet/astrowavenet/data/synthetic_transits.py (+72, -0)
* research/astronet/astrowavenet/data/test_data/test-dataset.tfrecord (+0, -0)
* research/astronet/astrowavenet/trainer.py (+272, -0)
* research/astronet/astrowavenet/util/BUILD (+10, -0)
* research/astronet/astrowavenet/util/estimator_util.py (+178, -0)
* research/astronet/light_curve_util/util.py (+17, -19)
research/astronet/README.md

```diff
@@ -40,6 +40,10 @@ Full text available at [*The Astronomical Journal*](http://iopscience.iop.org/ar
     * Training and evaluating a new model.
     * Using a trained model to generate new predictions.
+* [astrowavenet/](astrowavenet/)
+    * A generative model for light curves.
 * [light_curve_util/](light_curve_util)
     * Utilities for operating on light curves. These include:
@@ -63,11 +67,11 @@ First, ensure that you have installed the following required packages:
 * **TensorFlow** ([instructions](https://www.tensorflow.org/install/))
 * **Pandas** ([instructions](http://pandas.pydata.org/pandas-docs/stable/install.html))
 * **NumPy** ([instructions](https://docs.scipy.org/doc/numpy/user/install.html))
 * **SciPy** ([instructions](https://scipy.org/install.html))
 * **AstroPy** ([instructions](http://www.astropy.org/))
 * **PyDl** ([instructions](https://pypi.python.org/pypi/pydl))
 * **Bazel** ([instructions](https://docs.bazel.build/versions/master/install.html))
 * **Abseil Python Common Libraries** ([instructions](https://github.com/abseil/abseil-py))
     * Optional: only required for unit tests.

 ### Optional: Run Unit Tests
```
research/astronet/astronet/util/config_util.py

```diff
@@ -63,6 +63,14 @@ def parse_json(json_string_or_file):
   return json_dict


+def to_json(config):
+  """Converts a JSON-serializable configuration object to a JSON string."""
+  if hasattr(config, "to_json") and callable(config.to_json):
+    return config.to_json(indent=2)
+  else:
+    return json.dumps(config, indent=2)
+
+
 def log_and_save_config(config, output_dir):
   """Logs and writes a JSON-serializable configuration object.
@@ -70,10 +78,7 @@ def log_and_save_config(config, output_dir):
     config: A JSON-serializable object.
     output_dir: Destination directory.
   """
-  if hasattr(config, "to_json") and callable(config.to_json):
-    config_json = config.to_json(indent=2)
-  else:
-    config_json = json.dumps(config, indent=2)
+  config_json = to_json(config)
   tf.logging.info("config: %s", config_json)
   tf.gfile.MakeDirs(output_dir)
```
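The refactor above extracts the duck-typed serialization into a reusable `to_json` helper (the new trainer below uses it to log dataset configs). A minimal sketch of its two branches, with illustrative values:

```python
# Sketch only: assumes the to_json helper above is in scope. FakeConfig is a
# hypothetical stand-in for an object with a to_json() method (as ConfigDict
# provides).
import json

class FakeConfig(object):
  def to_json(self, indent=None):
    return json.dumps({"batch_size": 16}, indent=indent)

print(to_json(FakeConfig()))        # delegates to the object's to_json(indent=2)
print(to_json({"batch_size": 16}))  # plain dicts fall back to json.dumps
```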
research/astronet/astrowavenet/BUILD

```diff
@@ -4,6 +4,22 @@ package(default_visibility = ["//visibility:public"])

 licenses(["notice"])  # Apache 2.0

+py_binary(
+    name = "trainer",
+    srcs = ["trainer.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":astrowavenet_model",
+        ":configurations",
+        "//astronet/util:config_util",
+        "//astronet/util:configdict",
+        "//astronet/util:estimator_runner",
+        "//astrowavenet/data:kepler_light_curves",
+        "//astrowavenet/data:synthetic_transits",
+        "//astrowavenet/util:estimator_util",
+    ],
+)
+
 py_library(
     name = "configurations",
     srcs = ["configurations.py"],
@@ -11,22 +27,22 @@ py_library(
 )

 py_library(
-    name = "astrowavenet",
-    srcs = ["astrowavenet.py"],
+    name = "astrowavenet_model",
+    srcs = ["astrowavenet_model.py"],
     srcs_version = "PY2AND3",
 )

 py_test(
-    name = "astrowavenet_test",
+    name = "astrowavenet_model_test",
     size = "small",
-    srcs = ["astrowavenet_test.py"],
+    srcs = ["astrowavenet_model_test.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":astrowavenet",
+        ":astrowavenet_model",
         ":configurations",
         "//astronet/util:configdict",
     ],
```
research/astronet/astrowavenet/README.md (new file)

# AstroWaveNet: A generative model for light curves.

Implementation based on "WaveNet: A Generative Model of Raw Audio":
https://arxiv.org/abs/1609.03499

## Code Authors

Alex Tamkin: [@atamkin](https://github.com/atamkin)

Chris Shallue: [@cshallue](https://github.com/cshallue)

## Pull Requests / Issues

Chris Shallue: [@cshallue](https://github.com/cshallue)

## Additional Dependencies

This package requires TensorFlow 1.12 or greater. As of October 2018, this
requires the **TensorFlow nightly build**
([instructions](https://www.tensorflow.org/install/pip)).

In addition to the dependencies listed in the top-level README, this package
requires:

* **TensorFlow Probability** ([instructions](https://www.tensorflow.org/probability/install))
* **Six** ([instructions](https://pypi.org/project/six/))

## Basic Usage

To train a model on synthetic transits:

```bash
bazel build astrowavenet/...
```

```bash
bazel-bin/astrowavenet/trainer \
  --dataset=synthetic_transits \
  --model_dir=/tmp/astrowavenet/ \
  --config_overrides='{"hparams": {"batch_size": 16, "num_residual_blocks": 2}}' \
  --schedule=train_and_eval \
  --eval_steps=100 \
  --save_checkpoints_steps=1000
```
research/astronet/astrowavenet/__init__.py (new file)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
research/astronet/astrowavenet/astrowavenet.py → research/astronet/astrowavenet/astrowavenet_model.py

```diff
@@ -23,6 +23,7 @@ from __future__ import division
 from __future__ import print_function

 import tensorflow as tf
+import tensorflow_probability as tfp


 def _shift_right(x):
@@ -64,18 +65,21 @@ class AstroWaveNet(object):
                    tf.estimator.ModeKeys.PREDICT]
     if mode not in valid_modes:
-      raise ValueError('Expected mode in {}. Got: {}'.format(valid_modes, mode))
+      raise ValueError("Expected mode in {}. Got: {}".format(valid_modes, mode))

     self.hparams = hparams
     self.mode = mode
-    self.autoregressive_input = features['autoregressive_input']
-    self.conditioning_stack = features['conditioning_stack']
-    self.weights = features.get('weights')
+    self.autoregressive_input = features["autoregressive_input"]
+    self.conditioning_stack = features["conditioning_stack"]
+    self.weights = features.get("weights")

     self.network_output = None  # Sum of skip connections from dilation stack.
     self.dist_params = None  # Dict of predicted distribution parameters.
     self.predicted_distributions = None  # Predicted distribution for examples.
     self.autoregressive_target = None  # Autoregressive target predictions.
     self.batch_losses = None  # Loss for each predicted distribution in batch.
     self.per_example_loss = None  # Loss for each example in batch.
+    self.num_nonzero_weight_examples = None  # Number of examples in batch.
     self.total_loss = None  # Overall loss for the batch.
     self.global_step = None  # Global step Tensor.
@@ -94,9 +98,9 @@ class AstroWaveNet(object):
     causal_conv_op = tf.keras.layers.Conv1D(
         output_size,
         kernel_width,
-        padding='causal',
+        padding="causal",
         dilation_rate=dilation_rate,
-        name='causal_conv')
+        name="causal_conv")
     return causal_conv_op(x)

   def conv_1x1_layer(self, x, output_size, activation=None):
@@ -111,7 +115,7 @@ class AstroWaveNet(object):
       Resulting tf.Tensor after applying the 1x1 convolution.
     """
     conv_1x1_op = tf.keras.layers.Conv1D(
-        output_size, 1, activation=activation, name='conv1x1')
+        output_size, 1, activation=activation, name="conv1x1")
     return conv_1x1_op(x)

   def gated_residual_layer(self, x, dilation_rate):
@@ -125,24 +129,26 @@ class AstroWaveNet(object):
       skip_connection: tf.Tensor; Skip connection to network_output layer.
       residual_connection: tf.Tensor; Sum of learned residual and input tensor.
     """
-    with tf.variable_scope('filter'):
-      x_filter_conv = self.causal_conv_layer(x, int(x.shape[-1]),
+    with tf.variable_scope("filter"):
+      x_filter_conv = self.causal_conv_layer(x, x.shape[-1].value,
                                              self.hparams.dilation_kernel_width,
                                              dilation_rate)
       cond_filter_conv = self.conv_1x1_layer(self.conditioning_stack,
-                                             int(x.shape[-1]))
-    with tf.variable_scope('gate'):
-      x_gate_conv = self.causal_conv_layer(x, int(x.shape[-1]),
+                                             x.shape[-1].value)
+    with tf.variable_scope("gate"):
+      x_gate_conv = self.causal_conv_layer(x, x.shape[-1].value,
                                            self.hparams.dilation_kernel_width,
                                            dilation_rate)
       cond_gate_conv = self.conv_1x1_layer(self.conditioning_stack,
-                                           int(x.shape[-1]))
+                                           x.shape[-1].value)

     gated_activation = (
         tf.tanh(x_filter_conv + cond_filter_conv) *
         tf.sigmoid(x_gate_conv + cond_gate_conv))

-    with tf.variable_scope('residual'):
-      residual = self.conv_1x1_layer(gated_activation, int(x.shape[-1]))
-    with tf.variable_scope('skip'):
+    with tf.variable_scope("residual"):
+      residual = self.conv_1x1_layer(gated_activation, x.shape[-1].value)
+    with tf.variable_scope("skip"):
       skip_connection = self.conv_1x1_layer(gated_activation,
                                             self.hparams.skip_output_dim)
@@ -167,13 +173,13 @@ class AstroWaveNet(object):
     """
     skip_connections = []
     x = _shift_right(self.autoregressive_input)
-    with tf.variable_scope('preprocess'):
+    with tf.variable_scope("preprocess"):
       x = self.causal_conv_layer(x, self.hparams.preprocess_output_size,
                                  self.hparams.preprocess_kernel_width)
     for i in range(self.hparams.num_residual_blocks):
-      with tf.variable_scope('block_{}'.format(i)):
+      with tf.variable_scope("block_{}".format(i)):
         for dilation_rate in self.hparams.dilation_rates:
-          with tf.variable_scope('dilation_{}'.format(dilation_rate)):
+          with tf.variable_scope("dilation_{}".format(dilation_rate)):
             skip_connection, x = self.gated_residual_layer(x, dilation_rate)
             skip_connections.append(skip_connection)
@@ -192,7 +198,7 @@ class AstroWaveNet(object):
       The parameters of each distribution, a tensor of shape [batch_size,
       time_series_length, outputs_size].
     """
-    with tf.variable_scope('dist_params'):
+    with tf.variable_scope("dist_params"):
       conv_outputs = self.conv_1x1_layer(x, outputs_size)
     return conv_outputs
@@ -212,36 +218,40 @@ class AstroWaveNet(object):
       self.network_outputs

     Outputs:
       self.dist_params
       self.predicted_distributions

     Raises:
       ValueError: If distribution type is neither 'categorical' nor 'normal'.
     """
-    with tf.variable_scope('postprocess'):
+    with tf.variable_scope("postprocess"):
       network_output = tf.keras.activations.relu(self.network_output)
       network_output = self.conv_1x1_layer(
-          network_output, output_size=int(network_output.shape[-1]),
-          activation='relu')
-      num_dists = int(self.autoregressive_input.shape[-1])
-      if self.hparams.output_distribution.type == 'categorical':
+          network_output,
+          output_size=network_output.shape[-1].value,
+          activation="relu")
+      num_dists = self.autoregressive_input.shape[-1].value
+      if self.hparams.output_distribution.type == "categorical":
         num_classes = self.hparams.output_distribution.num_classes
-        dist_params = self.dist_params_layer(network_output,
-                                             num_dists * num_classes)
-        dist_shape = tf.concat(
+        logits = self.dist_params_layer(network_output,
+                                        num_dists * num_classes)
+        logits_shape = tf.concat(
             [tf.shape(network_output)[:-1], [num_dists, num_classes]], 0)
-        dist_params = tf.reshape(dist_params, dist_shape)
-        dist = tf.distributions.Categorical(logits=dist_params)
-      elif self.hparams.output_distribution.type == 'normal':
-        dist_params = self.dist_params_layer(network_output, num_dists * 2)
-        loc, scale = tf.split(dist_params, 2, axis=-1)
+        logits = tf.reshape(logits, logits_shape)
+        dist = tfp.distributions.Categorical(logits=logits)
+        dist_params = {"logits": logits}
+      elif self.hparams.output_distribution.type == "normal":
+        loc_scale = self.dist_params_layer(network_output, num_dists * 2)
+        loc, scale = tf.split(loc_scale, 2, axis=-1)
         # Ensure scale is positive.
         scale = tf.nn.softplus(scale) + self.hparams.output_distribution.min_scale
-        dist = tf.distributions.Normal(loc, scale)
+        dist = tfp.distributions.Normal(loc, scale)
+        dist_params = {"loc": loc, "scale": scale}
       else:
-        raise ValueError('Unsupported distribution type {}'.format(
+        raise ValueError("Unsupported distribution type {}".format(
             self.hparams.output_distribution.type))

     self.dist_params = dist_params
     self.predicted_distributions = dist

   def build_losses(self):
@@ -257,7 +267,7 @@ class AstroWaveNet(object):
     autoregressive_target = self.autoregressive_input
     # Quantize the target if the output distribution is categorical.
-    if self.hparams.output_distribution.type == 'categorical':
+    if self.hparams.output_distribution.type == "categorical":
       min_val = self.hparams.output_distribution.min_quantization_value
       max_val = self.hparams.output_distribution.max_quantization_value
       num_classes = self.hparams.output_distribution.num_classes
@@ -270,7 +280,7 @@ class AstroWaveNet(object):
       # final quantized bucket a closed interval while all the other quantized
       # buckets are half-open intervals.
       quantized_target = tf.where(
-          quantized_target == num_classes,
+          quantized_target >= num_classes,
           tf.ones_like(quantized_target) * (num_classes - 1),
           quantized_target)
       autoregressive_target = quantized_target
@@ -280,22 +290,24 @@ class AstroWaveNet(object):
     if weights is None:
       weights = tf.ones_like(log_prob)
     weights_dim = len(weights.shape)
-    per_example_weight = tf.reduce_sum(weights, axis=range(1, weights_dim))
+    per_example_weight = tf.reduce_sum(
+        weights, axis=list(range(1, weights_dim)))
     per_example_indicator = tf.to_float(tf.greater(per_example_weight, 0))
-    num_examples = tf.reduce_sum(per_example_indicator,
-                                 name='num_nonzero_weight_examples')
+    num_examples = tf.reduce_sum(per_example_indicator)

     batch_losses = -log_prob * weights
-    losses_dim = len(batch_losses.shape)
+    losses_ndims = batch_losses.shape.ndims
     per_example_loss_sum = tf.reduce_sum(
-        batch_losses, axis=range(1, losses_dim))
+        batch_losses, axis=list(range(1, losses_ndims)))
     per_example_loss = tf.where(per_example_weight > 0,
                                 per_example_loss_sum / per_example_weight,
                                 tf.zeros_like(per_example_weight))
     total_loss = tf.reduce_sum(per_example_loss) / num_examples

     self.autoregressive_target = autoregressive_target
     self.batch_losses = batch_losses
     self.per_example_loss = per_example_loss
+    self.num_nonzero_weight_examples = num_examples
     self.total_loss = total_loss

   def build(self):
```
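One behavioral fix here is easy to miss among the quote-style changes: the clamp on the quantized categorical target now tests `>=` rather than `==`, so a target that quantizes *above* the top bucket (not just exactly onto its boundary) is also pulled back into range. A numpy sketch of the idea, assuming the same floor-based bucketing over [min_val, max_val] as the surrounding code:

```python
# Illustrative quantization (not the model's actual TF ops).
import numpy as np

min_val, max_val, num_classes = -1.0, 1.0, 4
target = np.array([-1.0, 0.2, 1.0, 1.5])  # 1.5 lies above max_val
quantized = np.floor((target - min_val) / (max_val - min_val) * num_classes)
# With "== num_classes" only the exact boundary value 1.0 would be clamped;
# ">=" also catches the out-of-range 1.5.
quantized = np.where(quantized >= num_classes, num_classes - 1, quantized)
print(quantized)  # [0. 2. 3. 3.]
```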
research/astronet/astrowavenet/astrowavenet_test.py → research/astronet/astrowavenet/astrowavenet_model_test.py

(Diff collapsed in the original view; 631 additions.)
research/astronet/astrowavenet/data/BUILD

```diff
@@ -2,6 +2,48 @@ package(default_visibility = ["//visibility:public"])

 licenses(["notice"])  # Apache 2.0

+py_library(
+    name = "base",
+    srcs = [
+        "base.py",
+    ],
+    deps = [
+        "//astronet/ops:dataset_ops",
+        "//astronet/util:configdict",
+    ],
+)
+
+py_test(
+    name = "base_test",
+    srcs = ["base_test.py"],
+    data = ["test_data/test-dataset.tfrecord"],
+    srcs_version = "PY2AND3",
+    deps = [":base"],
+)
+
+py_library(
+    name = "kepler_light_curves",
+    srcs = [
+        "kepler_light_curves.py",
+    ],
+    deps = [
+        ":base",
+        "//astronet/util:configdict",
+    ],
+)
+
+py_library(
+    name = "synthetic_transits",
+    srcs = [
+        "synthetic_transits.py",
+    ],
+    deps = [
+        ":base",
+        ":synthetic_transit_maker",
+        "//astronet/util:configdict",
+    ],
+)
+
 py_library(
     name = "synthetic_transit_maker",
     srcs = [
```
research/astronet/astrowavenet/data/__init__.py (new file)

```python
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
research/astronet/astrowavenet/data/base.py (new file)

```python
# Copyright 2018 The TensorFlow Authors.
# Licensed under the Apache License, Version 2.0 (standard header as above).

"""Base dataset builder classes for AstroWaveNet input pipelines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six
import tensorflow as tf

from astronet.util import configdict
from astronet.ops import dataset_ops


@six.add_metaclass(abc.ABCMeta)
class DatasetBuilder(object):
  """Base class for building a dataset input pipeline for AstroWaveNet."""

  def __init__(self, config_overrides=None):
    """Initializes the dataset builder.

    Args:
      config_overrides: Dict or ConfigDict containing overrides to the default
        configuration.
    """
    self.config = configdict.ConfigDict(self.default_config())
    if config_overrides is not None:
      self.config.update(config_overrides)

  @staticmethod
  def default_config():
    """Returns the default configuration as a ConfigDict or Python dict."""
    return {}

  @abc.abstractmethod
  def build(self, batch_size):
    """Builds the dataset input pipeline.

    Args:
      batch_size: The number of input examples in each batch.

    Returns:
      A tf.data.Dataset object.
    """
    raise NotImplementedError


@six.add_metaclass(abc.ABCMeta)
class _ShardedDatasetBuilder(DatasetBuilder):
  """Abstract base class for a dataset consisting of sharded files."""

  def __init__(self, file_pattern, mode, config_overrides=None, use_tpu=False):
    """Initializes the dataset builder.

    Args:
      file_pattern: File pattern matching input file shards, e.g.
        "/tmp/train-?????-of-00100". May also be a comma-separated list of file
        patterns.
      mode: A tf.estimator.ModeKeys.
      config_overrides: Dict or ConfigDict containing overrides to the default
        configuration.
      use_tpu: Whether to build the dataset for TPU.
    """
    super(_ShardedDatasetBuilder, self).__init__(config_overrides)
    self.file_pattern = file_pattern
    self.mode = mode
    self.use_tpu = use_tpu

  @staticmethod
  def default_config():
    config = super(_ShardedDatasetBuilder,
                   _ShardedDatasetBuilder).default_config()
    config.update({
        "max_length": 1024,
        "shuffle_values_buffer": 1000,
        "num_parallel_parser_calls": 4,
        "batches_buffer_size": None,  # Defaults to max(1, 256 / batch_size).
    })
    return config

  @abc.abstractmethod
  def file_reader(self):
    """Returns a function that reads a single sharded file."""
    raise NotImplementedError

  @abc.abstractmethod
  def create_example_parser(self):
    """Returns a function that parses a single tf.Example proto."""
    raise NotImplementedError

  def _batch_and_pad(self, dataset, batch_size):
    """Combines elements into batches of the same length, padding if needed."""
    if self.use_tpu:
      padded_length = self.config.max_length
      if not padded_length:
        raise ValueError("config.max_length is required when using TPU")
      # Pad with zeros up to padded_length. Note that this will pad the
      # "weights" Tensor with zeros as well, which ensures that padded elements
      # do not contribute to the loss.
      padded_shapes = {}
      for name, shape in dataset.output_shapes.iteritems():
        shape.assert_is_compatible_with([None, None])  # Expect a 2D sequence.
        dims = shape.as_list()
        dims[0] = padded_length
        shape = tf.TensorShape(dims)
        shape.assert_is_fully_defined()
        padded_shapes[name] = shape
    else:
      # Pad each batch up to the maximum size of each dimension in the batch.
      padded_shapes = dataset.output_shapes

    return dataset.padded_batch(batch_size, padded_shapes)

  def build(self, batch_size):
    """Builds the dataset input pipeline.

    Args:
      batch_size:

    Returns:
      A tf.data.Dataset.

    Raises:
      ValueError: If no files match self.file_pattern.
    """
    file_patterns = self.file_pattern.split(",")
    filenames = []
    for p in file_patterns:
      matches = tf.gfile.Glob(p)
      if not matches:
        raise ValueError("Found no input files matching {}".format(p))
      filenames.extend(matches)
    tf.logging.info(
        "Building input pipeline from %d files matching patterns: %s",
        len(filenames), file_patterns)

    is_training = self.mode == tf.estimator.ModeKeys.TRAIN

    # Create a string dataset of filenames, and possibly shuffle.
    filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
    if is_training and len(filenames) > 1:
      filename_dataset = filename_dataset.shuffle(len(filenames))

    # Read serialized Example protos.
    dataset = filename_dataset.apply(
        tf.contrib.data.parallel_interleave(
            self.file_reader(), cycle_length=8, block_length=8, sloppy=True))

    if is_training:
      # Shuffle and repeat. Note that shuffle() is before repeat(), so elements
      # are shuffled among each epoch of data, and not between epochs of data.
      if self.config.shuffle_values_buffer > 0:
        dataset = dataset.shuffle(self.config.shuffle_values_buffer)
      dataset = dataset.repeat()

    # Map the parser over the dataset.
    dataset = dataset.map(
        self.create_example_parser(),
        num_parallel_calls=self.config.num_parallel_parser_calls)

    def _prepare_wavenet_inputs(features):
      """Validates features, and clips lengths and adds weights if needed."""
      # Validate feature names.
      required_features = {"autoregressive_input", "conditioning_stack"}
      allowed_features = required_features | {"weights"}
      feature_names = features.keys()
      if not required_features.issubset(feature_names):
        raise ValueError("Features must contain all of: {}. Got: {}".format(
            required_features, feature_names))
      if not allowed_features.issuperset(feature_names):
        raise ValueError("Features can only contain: {}. Got: {}".format(
            allowed_features, feature_names))

      output = {}
      for name, value in features.items():
        # Validate shapes. The output dimension is [num_samples, dim].
        ndims = len(value.shape)
        if ndims == 1:
          # Add an extra dimension: [num_samples] -> [num_samples, 1].
          value = tf.expand_dims(value, -1)
        elif ndims != 2:
          raise ValueError(
              "Features should be 1D or 2D sequences. Got '{}' = {}".format(
                  name, value))
        if self.config.max_length:
          value = value[:self.config.max_length]
        output[name] = value

      if "weights" not in output:
        output["weights"] = tf.ones_like(output["autoregressive_input"])

      return output

    dataset = dataset.map(_prepare_wavenet_inputs)

    # Batch results by up to batch_size.
    dataset = self._batch_and_pad(dataset, batch_size)

    if is_training:
      # The dataset repeats infinitely before batching, so each batch has the
      # maximum number of elements.
      dataset = dataset_ops.set_batch_size(dataset, batch_size)
    elif self.use_tpu and self.mode == tf.estimator.ModeKeys.EVAL:
      # Pad to ensure that each batch has the same number of elements.
      dataset = dataset_ops.pad_dataset_to_batch_size(dataset, batch_size)

    # Prefetch batches.
    buffer_size = (self.config.batches_buffer_size or
                   max(1, int(256 / batch_size)))
    dataset = dataset.prefetch(buffer_size)

    return dataset


def tfrecord_reader(filename):
  """Returns a tf.data.Dataset that reads a single TFRecord file shard."""
  return tf.data.TFRecordDataset(filename, buffer_size=16 * 1000 * 1000)


class TFRecordDataset(_ShardedDatasetBuilder):
  """Builder for a dataset consisting of TFRecord files."""

  def file_reader(self):
    """Returns a function that reads a single file shard."""
    return tfrecord_reader
```
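A minimal usage sketch of the builder interface (the file path and batch size are illustrative; `KeplerLightCurves`, defined in the next file, is a concrete subclass that supplies `create_example_parser`):

```python
# Sketch only: assumes TFRecord files matching the pattern exist and contain
# the "flux"/"mask" features that KeplerLightCurves parses.
builder = kepler_light_curves.KeplerLightCurves(
    file_pattern="/tmp/kepler/train-*.tfrecord",
    mode=tf.estimator.ModeKeys.TRAIN)
dataset = builder.build(batch_size=16)
# Each element is a dict with "autoregressive_input", "conditioning_stack",
# and "weights", batched and padded to a common length.
```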
research/astronet/astrowavenet/data/base_test.py (new file)

(Diff collapsed in the original view; 778 additions.)
research/astronet/astrowavenet/data/kepler_light_curves.py (new file)

```python
# Copyright 2018 The TensorFlow Authors.
# Licensed under the Apache License, Version 2.0 (standard header as above).

"""Kepler light curve inputs to the AstroWaveNet model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from astrowavenet.data import base

COND_INPUT_KEY = "mask"
AR_INPUT_KEY = "flux"


class KeplerLightCurves(base.TFRecordDataset):
  """Kepler light curve inputs to the AstroWaveNet model."""

  def create_example_parser(self):

    def _example_parser(serialized):
      """Parses a single tf.Example proto."""
      features = tf.parse_single_example(
          serialized,
          features={
              AR_INPUT_KEY: tf.VarLenFeature(tf.float32),
              COND_INPUT_KEY: tf.VarLenFeature(tf.int64),
          })
      # Extract values from SparseTensor objects.
      autoregressive_input = features[AR_INPUT_KEY].values
      conditioning_stack = tf.to_float(features[COND_INPUT_KEY].values)
      return {
          "autoregressive_input": autoregressive_input,
          "conditioning_stack": conditioning_stack,
      }

    return _example_parser
```
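For reference, a hedged sketch (not part of this commit) of a `tf.Example` that the parser above would accept: `flux` stored as floats and `mask` as int64s, both variable-length:

```python
import tensorflow as tf

example = tf.train.Example(features=tf.train.Features(feature={
    "flux": tf.train.Feature(
        float_list=tf.train.FloatList(value=[1.0, 0.9, 1.1])),
    "mask": tf.train.Feature(
        int64_list=tf.train.Int64List(value=[1, 1, 0])),
}))
# Illustrative output path; tf.python_io is the TF 1.x API.
with tf.python_io.TFRecordWriter("/tmp/kepler/train-00000.tfrecord") as writer:
  writer.write(example.SerializeToString())
```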
research/astronet/astrowavenet/data/synthetic_transit_maker.py

```diff
@@ -43,8 +43,8 @@ class SyntheticTransitMaker(object):
       would translate the sine wave by half of the period. The most common
       reason to override this would be to generate light curves
       deterministically (with e.g. (0,0)).
-    noise_sd_range: A tuple of values in [0, 1) specifying the range of standard
-      deviations for the Gaussian noise applied to the sine wave.
+    noise_sd_range: A tuple of values in [0, 1) specifying the range of
+      standard deviations for the Gaussian noise applied to the sine wave.
   """

   def __init__(self,
```
research/astronet/astrowavenet/data/synthetic_transit_maker_test.py

```diff
@@ -29,30 +29,30 @@ class SyntheticTransitMakerTest(absltest.TestCase):
   def testBadRangesRaiseExceptions(self):
     # Period range cannot contain negative values.
-    with self.assertRaisesRegexp(ValueError, 'Period'):
+    with self.assertRaisesRegexp(ValueError, "Period"):
       synthetic_transit_maker.SyntheticTransitMaker(period_range=(-1, 10))

     # Amplitude range cannot contain negative values.
-    with self.assertRaisesRegexp(ValueError, 'Amplitude'):
+    with self.assertRaisesRegexp(ValueError, "Amplitude"):
       synthetic_transit_maker.SyntheticTransitMaker(amplitude_range=(-10, -1))

     # Threshold ratio range must be contained in the half-open interval [0, 1).
-    with self.assertRaisesRegexp(ValueError, 'Threshold ratio'):
+    with self.assertRaisesRegexp(ValueError, "Threshold ratio"):
       synthetic_transit_maker.SyntheticTransitMaker(
           threshold_ratio_range=(0, 1))

     # Noise standard deviation range must only contain nonnegative values.
-    with self.assertRaisesRegexp(ValueError, 'Noise standard deviation'):
+    with self.assertRaisesRegexp(ValueError, "Noise standard deviation"):
       synthetic_transit_maker.SyntheticTransitMaker(noise_sd_range=(-1, 1))

     # End of range may not be less than start.
     invalid_range = (0.2, 0.1)
     range_args = [
-        'period_range', 'threshold_ratio_range', 'amplitude_range',
-        'noise_sd_range', 'phase_range'
+        "period_range", "threshold_ratio_range", "amplitude_range",
+        "noise_sd_range", "phase_range"
     ]
     for range_arg in range_args:
-      with self.assertRaisesRegexp(ValueError, 'may not be less'):
+      with self.assertRaisesRegexp(ValueError, "may not be less"):
         synthetic_transit_maker.SyntheticTransitMaker(
             **{range_arg: invalid_range})
@@ -106,5 +106,5 @@ class SyntheticTransitMakerTest(absltest.TestCase):
     self.assertEqual(len(mask), 100)


-if __name__ == '__main__':
+if __name__ == "__main__":
   absltest.main()
```
research/astronet/astrowavenet/data/synthetic_transits.py (new file)

```python
# Copyright 2018 The TensorFlow Authors.
# Licensed under the Apache License, Version 2.0 (standard header as above).

"""Synthetic transit inputs to the AstroWaveNet model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from astronet.util import configdict
from astrowavenet.data import base
from astrowavenet.data import synthetic_transit_maker


def _prepare_wavenet_inputs(light_curve, mask):
  """Gathers synthetic transits into the format expected by AstroWaveNet."""
  return {
      "autoregressive_input": tf.expand_dims(light_curve, -1),
      "conditioning_stack": tf.expand_dims(mask, -1),
  }


class SyntheticTransits(base.DatasetBuilder):
  """Synthetic transit inputs to the AstroWaveNet model."""

  @staticmethod
  def default_config():
    return configdict.ConfigDict({
        "period_range": (0.5, 4),
        "amplitude_range": (1, 1),
        "threshold_ratio_range": (0, 0.99),
        "phase_range": (0, 1),
        "noise_sd_range": (0.1, 0.1),
        "mask_probability": 0.1,
        "light_curve_time_range": (0, 100),
        "light_curve_num_points": 1000
    })

  def build(self, batch_size):
    transit_maker = synthetic_transit_maker.SyntheticTransitMaker(
        period_range=self.config.period_range,
        amplitude_range=self.config.amplitude_range,
        threshold_ratio_range=self.config.threshold_ratio_range,
        phase_range=self.config.phase_range,
        noise_sd_range=self.config.noise_sd_range)

    t_start, t_end = self.config.light_curve_time_range
    time = np.linspace(t_start, t_end, self.config.light_curve_num_points)

    dataset = tf.data.Dataset.from_generator(
        transit_maker.random_light_curve_generator(
            time, mask_prob=self.config.mask_probability),
        output_types=(tf.float32, tf.float32),
        output_shapes=(
            tf.TensorShape((self.config.light_curve_num_points,)),
            tf.TensorShape((self.config.light_curve_num_points,))))

    dataset = dataset.map(_prepare_wavenet_inputs)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(-1)

    return dataset
```
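A minimal usage sketch (the override value is illustrative):

```python
# Builds an endless stream of random synthetic light curves; no input files
# are needed, which is why the trainer below skips the file-pattern flags for
# this dataset.
builder = SyntheticTransits(
    config_overrides={"light_curve_num_points": 500})
dataset = builder.build(batch_size=4)
# Each element maps "autoregressive_input" and "conditioning_stack" to
# tensors of shape [4, 500, 1].
```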
research/astronet/astrowavenet/data/test_data/test-dataset.tfrecord (new file)

Binary file added.
research/astronet/astrowavenet/trainer.py (new file)

```python
# Copyright 2018 The TensorFlow Authors.
# Licensed under the Apache License, Version 2.0 (standard header as above).

"""Script for training and evaluating AstroWaveNet models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os.path

from absl import flags
import tensorflow as tf

from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astrowavenet import astrowavenet_model
from astrowavenet import configurations
from astrowavenet.data import kepler_light_curves
from astrowavenet.data import synthetic_transits
from astrowavenet.util import estimator_util

FLAGS = flags.FLAGS

flags.DEFINE_enum(
    "dataset", None, ["synthetic_transits", "kepler_light_curves"],
    "Dataset for training and/or evaluation.")

flags.DEFINE_string("model_dir", None, "Base output directory.")

flags.DEFINE_string(
    "train_files", None,
    "Comma-separated list of file patterns matching the TFRecord files in the "
    "training dataset.")

flags.DEFINE_string(
    "eval_files", None,
    "Comma-separated list of file patterns matching the TFRecord files in the "
    "evaluation dataset.")

flags.DEFINE_string("config_name", "base",
                    "Name of the AstroWaveNet configuration.")

flags.DEFINE_string(
    "config_overrides", "{}",
    "JSON string or JSON file containing overrides to the base configuration.")

flags.DEFINE_enum("schedule", None,
                  ["train", "train_and_eval", "continuous_eval"],
                  "Schedule for running the model.")

flags.DEFINE_string("eval_name", "val", "Name of the evaluation task.")

flags.DEFINE_integer("train_steps", None,
                     "Total number of steps for training.")

flags.DEFINE_integer("eval_steps", None,
                     "Number of steps for each evaluation.")

flags.DEFINE_integer(
    "local_eval_frequency", 1000,
    "The number of training steps in between evaluation runs. Only applies "
    "when schedule == 'train_and_eval'.")

flags.DEFINE_integer("save_summary_steps", None,
                     "The frequency at which to save model summaries.")

flags.DEFINE_integer("save_checkpoints_steps", None,
                     "The frequency at which to save model checkpoints.")

flags.DEFINE_integer("save_checkpoints_secs", None,
                     "The frequency at which to save model checkpoints.")

flags.DEFINE_integer("keep_checkpoint_max", 1,
                     "The maximum number of model checkpoints to keep.")

# ------------------------------------------------------------------------------
# TPU-only flags
# ------------------------------------------------------------------------------

flags.DEFINE_boolean("use_tpu", False, "Whether to execute on TPU.")

flags.DEFINE_string("master", None, "Address of the TensorFlow TPU master.")

flags.DEFINE_integer("tpu_num_shards", 8, "Number of TPU shards.")

flags.DEFINE_integer("tpu_iterations_per_loop", 1000,
                     "Number of iterations per TPU training loop.")

flags.DEFINE_integer(
    "eval_batch_size", None,
    "Batch size for TPU evaluation. Defaults to the training batch size.")


def _create_run_config():
  """Creates a TPU RunConfig if FLAGS.use_tpu is True, else a RunConfig."""
  session_config = tf.ConfigProto(allow_soft_placement=True)
  run_config_kwargs = {
      "save_summary_steps": FLAGS.save_summary_steps,
      "save_checkpoints_steps": FLAGS.save_checkpoints_steps,
      "save_checkpoints_secs": FLAGS.save_checkpoints_secs,
      "session_config": session_config,
      "keep_checkpoint_max": FLAGS.keep_checkpoint_max
  }
  if FLAGS.use_tpu:
    if not FLAGS.master:
      raise ValueError("FLAGS.master must be set for TPUEstimator.")
    tpu_config = tf.contrib.tpu.TPUConfig(
        iterations_per_loop=FLAGS.tpu_iterations_per_loop,
        num_shards=FLAGS.tpu_num_shards,
        per_host_input_for_training=(FLAGS.tpu_num_shards <= 8))
    run_config = tf.contrib.tpu.RunConfig(
        tpu_config=tpu_config, master=FLAGS.master, **run_config_kwargs)
  else:
    if FLAGS.master:
      raise ValueError("FLAGS.master should only be set for TPUEstimator.")
    run_config = tf.estimator.RunConfig(**run_config_kwargs)

  return run_config


def _get_file_pattern(mode):
  """Gets the value of the file pattern flag for the specified mode."""
  flag_name = ("train_files"
               if mode == tf.estimator.ModeKeys.TRAIN else "eval_files")
  file_pattern = FLAGS[flag_name].value
  if file_pattern is None:
    raise ValueError("--{} is required for mode '{}'".format(flag_name, mode))
  return file_pattern


def _create_dataset_builder(mode, config_overrides=None):
  """Creates a dataset builder for the input pipeline."""
  if FLAGS.dataset == "synthetic_transits":
    return synthetic_transits.SyntheticTransits(config_overrides)

  file_pattern = _get_file_pattern(mode)
  if FLAGS.dataset == "kepler_light_curves":
    builder_class = kepler_light_curves.KeplerLightCurves
  else:
    raise ValueError("Unsupported dataset: {}".format(FLAGS.dataset))
  return builder_class(
      file_pattern,
      mode,
      config_overrides=config_overrides,
      use_tpu=FLAGS.use_tpu)


def _create_input_fn(mode, config_overrides=None):
  """Creates an Estimator input_fn."""
  builder = _create_dataset_builder(mode, config_overrides)
  tf.logging.info("Dataset config for mode '%s': %s", mode,
                  config_util.to_json(builder.config))
  return estimator_util.create_input_fn(builder)


def _create_eval_args(config_overrides=None):
  """Builds eval_args for estimator_runner.evaluate()."""
  if FLAGS.dataset == "synthetic_transits" and not FLAGS.eval_steps:
    raise ValueError(
        "Dataset '{}' requires --eval_steps for evaluation".format(
            FLAGS.dataset))
  input_fn = _create_input_fn(tf.estimator.ModeKeys.EVAL, config_overrides)
  return {FLAGS.eval_name: (input_fn, FLAGS.eval_steps)}


def main(argv):
  del argv  # Unused.

  config = configdict.ConfigDict(configurations.get_config(FLAGS.config_name))
  config_overrides = json.loads(FLAGS.config_overrides)
  for key in config_overrides:
    if key not in ["dataset", "hparams"]:
      raise ValueError("Unrecognized config override: {}".format(key))
  config.hparams.update(config_overrides.get("hparams", {}))

  # Log configs.
  configs_json = [
      ("config_overrides", config_util.to_json(config_overrides)),
      ("config", config_util.to_json(config)),
  ]
  for config_name, config_json in configs_json:
    tf.logging.info("%s: %s", config_name, config_json)

  # Create the estimator.
  run_config = _create_run_config()
  estimator = estimator_util.create_estimator(astrowavenet_model.AstroWaveNet,
                                              config.hparams, run_config,
                                              FLAGS.model_dir,
                                              FLAGS.eval_batch_size)

  if FLAGS.schedule in ["train", "train_and_eval"]:
    # Save configs.
    tf.gfile.MakeDirs(FLAGS.model_dir)
    for config_name, config_json in configs_json:
      filename = os.path.join(FLAGS.model_dir, "{}.json".format(config_name))
      with tf.gfile.Open(filename, "w") as f:
        f.write(config_json)

    train_input_fn = _create_input_fn(tf.estimator.ModeKeys.TRAIN,
                                      config_overrides.get("dataset"))
    train_hooks = []
    if FLAGS.schedule == "train":
      estimator.train(
          train_input_fn, hooks=train_hooks, max_steps=FLAGS.train_steps)
    else:
      assert FLAGS.schedule == "train_and_eval"
      eval_args = _create_eval_args(config_overrides.get("dataset"))
      for _ in estimator_runner.continuous_train_and_eval(
          estimator=estimator,
          train_input_fn=train_input_fn,
          eval_args=eval_args,
          local_eval_frequency=FLAGS.local_eval_frequency,
          train_hooks=train_hooks,
          train_steps=FLAGS.train_steps):
        # continuous_train_and_eval() yields evaluation metrics after each
        # FLAGS.local_eval_frequency. It also saves and logs them, so we don't
        # do anything here.
        pass
  else:
    assert FLAGS.schedule == "continuous_eval"
    eval_args = _create_eval_args(config_overrides.get("dataset"))
    for _ in estimator_runner.continuous_eval(
        estimator=estimator, eval_args=eval_args,
        train_steps=FLAGS.train_steps):
      # continuous_eval() yields evaluation metrics after each checkpoint. It
      # also saves and logs them, so we don't do anything here.
      pass


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  flags.mark_flags_as_required(["dataset", "model_dir", "schedule"])

  def _validate_schedule(flag_values):
    """Validates the --schedule flag and the flags it interacts with."""
    schedule = flag_values["schedule"]
    save_checkpoints_steps = flag_values["save_checkpoints_steps"]
    save_checkpoints_secs = flag_values["save_checkpoints_secs"]

    if schedule in ["train", "train_and_eval"]:
      if not (save_checkpoints_steps or save_checkpoints_secs):
        raise flags.ValidationError(
            "--schedule='%s' requires --save_checkpoints_steps or "
            "--save_checkpoints_secs." % schedule)

    return True

  flags.register_multi_flags_validator(
      ["schedule", "save_checkpoints_steps", "save_checkpoints_secs"],
      _validate_schedule)

  tf.app.run()
```
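Since `main()` rejects any top-level override key other than `dataset` and `hparams`, a valid `--config_overrides` payload looks like the sketch below (values illustrative; `max_length` is a real dataset config key from `base.py`, and `batch_size` is an hparam):

```python
import json

overrides = {
    "dataset": {"max_length": 512},  # forwarded to the dataset builder
    "hparams": {"batch_size": 16},   # merged into config.hparams
}
# Pass the resulting string as --config_overrides to the trainer.
print(json.dumps(overrides))
```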
research/astronet/astrowavenet/util/BUILD (new file)

```
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_library(
    name = "estimator_util",
    srcs = ["estimator_util.py"],
    srcs_version = "PY2AND3",
    deps = ["//astronet/ops:training"],
)
```
research/astronet/astrowavenet/util/estimator_util.py (new file)

```python
# Copyright 2018 The TensorFlow Authors.
# Licensed under the Apache License, Version 2.0 (standard header as above).

"""Helper functions for creating a TensorFlow Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import tensorflow as tf

from astronet.ops import training


class _InputFn(object):
  """Class that acts as a callable input function for Estimator train / eval."""

  def __init__(self, dataset_builder):
    """Initializes the input function.

    Args:
      dataset_builder: Instance of DatasetBuilder.
    """
    self._builder = dataset_builder

  def __call__(self, params):
    """Builds the input pipeline."""
    return self._builder.build(batch_size=params["batch_size"])


def create_input_fn(dataset_builder):
  """Creates an input_fn that builds an input pipeline.

  Args:
    dataset_builder: Instance of DatasetBuilder.

  Returns:
    A callable that builds an input pipeline and returns a tf.data.Dataset
    object.
  """
  return _InputFn(dataset_builder)


class _ModelFn(object):
  """Class that acts as a callable model function for Estimator train / eval."""

  def __init__(self, model_class, hparams, use_tpu=False):
    """Initializes the model function.

    Args:
      model_class: Model class.
      hparams: A HParams object containing hyperparameters for building and
        training the model.
      use_tpu: If True, a TPUEstimator will be returned. Otherwise an Estimator
        will be returned.
    """
    self._model_class = model_class
    self._base_hparams = hparams
    self._use_tpu = use_tpu

  def __call__(self, features, mode, params):
    """Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
    hparams = copy.deepcopy(self._base_hparams)
    if "batch_size" in params:
      hparams.batch_size = params["batch_size"]

    model = self._model_class(features, hparams, mode)
    model.build()

    # Possibly create train_op.
    use_tpu = self._use_tpu
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      learning_rate = training.create_learning_rate(hparams, model.global_step)
      optimizer = training.create_optimizer(hparams, learning_rate, use_tpu)
      train_op = training.create_train_op(model, optimizer)

    if use_tpu:
      estimator = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, loss=model.total_loss, train_op=train_op)
    else:
      estimator = tf.estimator.EstimatorSpec(
          mode=mode, loss=model.total_loss, train_op=train_op)

    return estimator


def create_model_fn(model_class, hparams, use_tpu=False):
  """Wraps model_class as an Estimator or TPUEstimator model_fn.

  Args:
    model_class: AstroModel or a subclass.
    hparams: ConfigDict of configuration parameters for building the model.
    use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
      Estimator model_fn is returned.

  Returns:
    model_fn: A callable that constructs the model and returns a
      TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
  """
  return _ModelFn(model_class, hparams, use_tpu)


def create_estimator(model_class,
                     hparams,
                     run_config=None,
                     model_dir=None,
                     eval_batch_size=None):
  """Wraps model_class as an Estimator or TPUEstimator.

  If run_config is None or a tf.estimator.RunConfig, an Estimator is returned.
  If run_config is a tf.contrib.tpu.RunConfig, a TPUEstimator is returned.

  Args:
    model_class: AstroWaveNet or a subclass.
    hparams: ConfigDict of configuration parameters for building the model.
    run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
    model_dir: Optional directory for saving the model. If not passed
      explicitly, it must be specified in run_config.
    eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
      if run_config is a tf.contrib.tpu.RunConfig. Defaults to
      hparams.batch_size.

  Returns:
    An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
    TPUEstimator object if run_config is a tf.contrib.tpu.RunConfig.

  Raises:
    ValueError: If model_dir is not passed explicitly or in
      run_config.model_dir, or if eval_batch_size is specified and run_config
      is not a tf.contrib.tpu.RunConfig.
  """
  if run_config is None:
    run_config = tf.estimator.RunConfig()
  else:
    run_config = copy.deepcopy(run_config)

  if not model_dir and not run_config.model_dir:
    raise ValueError(
        "model_dir must be passed explicitly or specified in run_config")

  use_tpu = isinstance(run_config, tf.contrib.tpu.RunConfig)
  model_fn = create_model_fn(model_class, hparams, use_tpu)

  if use_tpu:
    eval_batch_size = eval_batch_size or hparams.batch_size
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=run_config,
        train_batch_size=hparams.batch_size,
        eval_batch_size=eval_batch_size)
  else:
    if eval_batch_size is not None:
      raise ValueError("eval_batch_size can only be specified for TPU.")
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=run_config,
        params={"batch_size": hparams.batch_size})

  return estimator
```
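A minimal sketch of wiring these helpers together on CPU/GPU (`hparams` and `builder` are assumed to come from `configurations.get_config` and one of the data builders, as the trainer above does):

```python
# Sketch only: hparams must carry batch_size plus the optimizer fields that
# astronet.ops.training expects; builder is any DatasetBuilder instance.
run_config = tf.estimator.RunConfig(model_dir="/tmp/astrowavenet/")
estimator = create_estimator(
    astrowavenet_model.AstroWaveNet, hparams, run_config=run_config)
estimator.train(create_input_fn(builder), max_steps=10000)
```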
research/astronet/light_curve_util/util.py

```diff
@@ -220,14 +220,13 @@ def reshard_arrays(xs, ys):
   return np.split(concat_x, boundaries)


-def uniform_cadence_light_curve(all_cadence_no, all_time, all_flux):
+def uniform_cadence_light_curve(cadence_no, time, flux):
   """Combines data into a single light curve with uniform cadence numbers.

   Args:
-    all_cadence_no: A list of numpy arrays; the cadence numbers of the light
-      curve.
-    all_time: A list of numpy arrays; the time values of the light curve.
-    all_flux: A list of numpy arrays; the flux values of the light curve.
+    cadence_no: numpy array; the cadence numbers of the light curve.
+    time: numpy array; the time values of the light curve.
+    flux: numpy array; the flux values of the light curve.

   Returns:
     cadence_no: numpy array; the cadence numbers of the light curve with no
@@ -245,16 +244,15 @@ def uniform_cadence_light_curve(all_cadence_no, all_time, all_flux):
   Raises:
     ValueError: If there are duplicate cadence numbers in the input.
   """
-  min_cadence_no = np.min([np.min(c) for c in all_cadence_no])
-  max_cadence_no = np.max([np.max(c) for c in all_cadence_no])
+  min_cadence_no = np.min(cadence_no)
+  max_cadence_no = np.max(cadence_no)
   out_cadence_no = np.arange(
-      min_cadence_no, max_cadence_no + 1, dtype=all_cadence_no[0].dtype)
-  out_time = np.zeros_like(out_cadence_no, dtype=all_time[0].dtype)
-  out_flux = np.zeros_like(out_cadence_no, dtype=all_flux[0].dtype)
+      min_cadence_no, max_cadence_no + 1, dtype=cadence_no.dtype)
+  out_time = np.zeros_like(out_cadence_no, dtype=time.dtype)
+  out_flux = np.zeros_like(out_cadence_no, dtype=flux.dtype)
   out_mask = np.zeros_like(out_cadence_no, dtype=np.bool)
-  for cadence_no, time, flux in zip(all_cadence_no, all_time, all_flux):
-    for c, t, f in zip(cadence_no, time, flux):
-      if np.isfinite(c) and np.isfinite(t) and np.isfinite(f):
-        i = int(c - min_cadence_no)
+  for c, t, f in zip(cadence_no, time, flux):
+    if np.isfinite(c) and np.isfinite(t) and np.isfinite(f):
+      i = int(c - min_cadence_no)
```
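A worked sketch of the simplified single-array interface (values illustrative; behavior as described in the docstring, and assuming, as the `out_*` names suggest, that the function returns the cadence, time, flux, and mask arrays):

```python
import numpy as np

cadence_no = np.array([100, 101, 103])  # cadence 102 is missing
time = np.array([0.0, 0.5, 1.5])
flux = np.array([1.0, 0.9, 1.1])
out_cad, out_time, out_flux, out_mask = uniform_cadence_light_curve(
    cadence_no, time, flux)
# out_cad  -> [100, 101, 102, 103]
# out_mask -> [True, True, False, True]  (the gap at 102 is zero-filled,
# with mask False)
```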