Commit 763663de authored by Chris Shallue, committed by Christopher Shallue

Project import generated by Copybara.

PiperOrigin-RevId: 217341274
parent ca2db9bd
@@ -40,6 +40,10 @@ Full text available at [*The Astronomical Journal*](http://iopscience.iop.org/ar
   * Training and evaluating a new model.
   * Using a trained model to generate new predictions.

+[astrowavenet/](astrowavenet/)
+
+* A generative model for light curves.
+
 [light_curve_util/](light_curve_util)

 * Utilities for operating on light curves. These include:
@@ -63,11 +67,11 @@ First, ensure that you have installed the following required packages:
 * **TensorFlow** ([instructions](https://www.tensorflow.org/install/))
 * **Pandas** ([instructions](http://pandas.pydata.org/pandas-docs/stable/install.html))
 * **NumPy** ([instructions](https://docs.scipy.org/doc/numpy/user/install.html))
+* **SciPy** ([instructions](https://scipy.org/install.html))
 * **AstroPy** ([instructions](http://www.astropy.org/))
 * **PyDl** ([instructions](https://pypi.python.org/pypi/pydl))
 * **Bazel** ([instructions](https://docs.bazel.build/versions/master/install.html))
 * **Abseil Python Common Libraries** ([instructions](https://github.com/abseil/abseil-py))
-  * Optional: only required for unit tests.

 ### Optional: Run Unit Tests

...
@@ -63,6 +63,14 @@ def parse_json(json_string_or_file):
   return json_dict


+def to_json(config):
+  """Converts a JSON-serializable configuration object to a JSON string."""
+  if hasattr(config, "to_json") and callable(config.to_json):
+    return config.to_json(indent=2)
+  else:
+    return json.dumps(config, indent=2)
+
+
 def log_and_save_config(config, output_dir):
   """Logs and writes a JSON-serializable configuration object.

@@ -70,10 +78,7 @@ def log_and_save_config(config, output_dir):
     config: A JSON-serializable object.
     output_dir: Destination directory.
   """
-  if hasattr(config, "to_json") and callable(config.to_json):
-    config_json = config.to_json(indent=2)
-  else:
-    config_json = json.dumps(config, indent=2)
+  config_json = to_json(config)
   tf.logging.info("config: %s", config_json)
   tf.gfile.MakeDirs(output_dir)

...
@@ -4,6 +4,22 @@ package(default_visibility = ["//visibility:public"])
 licenses(["notice"])  # Apache 2.0

+py_binary(
+    name = "trainer",
+    srcs = ["trainer.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":astrowavenet_model",
+        ":configurations",
+        "//astronet/util:config_util",
+        "//astronet/util:configdict",
+        "//astronet/util:estimator_runner",
+        "//astrowavenet/data:kepler_light_curves",
+        "//astrowavenet/data:synthetic_transits",
+        "//astrowavenet/util:estimator_util",
+    ],
+)
+
 py_library(
     name = "configurations",
     srcs = ["configurations.py"],

@@ -11,22 +27,22 @@ py_library(
 )

 py_library(
-    name = "astrowavenet",
+    name = "astrowavenet_model",
     srcs = [
-        "astrowavenet.py",
+        "astrowavenet_model.py",
     ],
     srcs_version = "PY2AND3",
 )

 py_test(
-    name = "astrowavenet_test",
+    name = "astrowavenet_model_test",
     size = "small",
     srcs = [
-        "astrowavenet_test.py",
+        "astrowavenet_model_test.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
-        ":astrowavenet",
+        ":astrowavenet_model",
         ":configurations",
         "//astronet/util:configdict",
     ],

...
# AstroWaveNet: A generative model for light curves.

Implementation based on "WaveNet: A Generative Model of Raw Audio":
https://arxiv.org/abs/1609.03499

## Code Authors

Alex Tamkin: [@atamkin](https://github.com/atamkin)

Chris Shallue: [@cshallue](https://github.com/cshallue)

## Pull Requests / Issues

Chris Shallue: [@cshallue](https://github.com/cshallue)

## Additional Dependencies

This package requires TensorFlow 1.12 or greater. As of October 2018, this
requires the **TensorFlow nightly build**
([instructions](https://www.tensorflow.org/install/pip)).

In addition to the dependencies listed in the top-level README, this package
requires:

* **TensorFlow Probability** ([instructions](https://www.tensorflow.org/probability/install))
* **Six** ([instructions](https://pypi.org/project/six/))
## Basic Usage

To train a model on synthetic transits, first build the package:

```bash
bazel build astrowavenet/...
```

Then run the trainer:

```bash
bazel-bin/astrowavenet/trainer \
  --dataset=synthetic_transits \
  --model_dir=/tmp/astrowavenet/ \
  --config_overrides='{"hparams": {"batch_size": 16, "num_residual_blocks": 2}}' \
  --schedule=train_and_eval \
  --eval_steps=100 \
  --save_checkpoints_steps=1000
```
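Once checkpoints appear in the model directory, a separate job can evaluate them
continuously. A possible invocation reusing the flags above (the hparams
overrides are illustrative and must match the training run):

```bash
bazel-bin/astrowavenet/trainer \
  --dataset=synthetic_transits \
  --model_dir=/tmp/astrowavenet/ \
  --config_overrides='{"hparams": {"batch_size": 16, "num_residual_blocks": 2}}' \
  --schedule=continuous_eval \
  --eval_steps=100
```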
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -23,6 +23,7 @@ from __future__ import division
 from __future__ import print_function

 import tensorflow as tf
+import tensorflow_probability as tfp


 def _shift_right(x):
@@ -64,18 +65,21 @@ class AstroWaveNet(object):
         tf.estimator.ModeKeys.PREDICT
     ]
     if mode not in valid_modes:
-      raise ValueError('Expected mode in {}. Got: {}'.format(valid_modes, mode))
+      raise ValueError("Expected mode in {}. Got: {}".format(valid_modes, mode))

     self.hparams = hparams
     self.mode = mode
-    self.autoregressive_input = features['autoregressive_input']
-    self.conditioning_stack = features['conditioning_stack']
-    self.weights = features.get('weights')
+    self.autoregressive_input = features["autoregressive_input"]
+    self.conditioning_stack = features["conditioning_stack"]
+    self.weights = features.get("weights")

     self.network_output = None  # Sum of skip connections from dilation stack.
+    self.dist_params = None  # Dict of predicted distribution parameters.
     self.predicted_distributions = None  # Predicted distribution for examples.
+    self.autoregressive_target = None  # Autoregressive target predictions.
     self.batch_losses = None  # Loss for each predicted distribution in batch.
     self.per_example_loss = None  # Loss for each example in batch.
+    self.num_nonzero_weight_examples = None  # Number of examples in batch.
     self.total_loss = None  # Overall loss for the batch.
     self.global_step = None  # Global step Tensor.
@@ -94,9 +98,9 @@ class AstroWaveNet(object):
     causal_conv_op = tf.keras.layers.Conv1D(
         output_size,
         kernel_width,
-        padding='causal',
+        padding="causal",
         dilation_rate=dilation_rate,
-        name='causal_conv')
+        name="causal_conv")
     return causal_conv_op(x)

   def conv_1x1_layer(self, x, output_size, activation=None):

@@ -111,7 +115,7 @@ class AstroWaveNet(object):
       Resulting tf.Tensor after applying the 1x1 convolution.
     """
     conv_1x1_op = tf.keras.layers.Conv1D(
-        output_size, 1, activation=activation, name='conv1x1')
+        output_size, 1, activation=activation, name="conv1x1")
     return conv_1x1_op(x)

   def gated_residual_layer(self, x, dilation_rate):
@@ -125,24 +129,26 @@ class AstroWaveNet(object):
       skip_connection: tf.Tensor; Skip connection to network_output layer.
       residual_connection: tf.Tensor; Sum of learned residual and input tensor.
     """
-    with tf.variable_scope('filter'):
-      x_filter_conv = self.causal_conv_layer(x, int(
-          x.shape[-1]), self.hparams.dilation_kernel_width, dilation_rate)
+    with tf.variable_scope("filter"):
+      x_filter_conv = self.causal_conv_layer(x, x.shape[-1].value,
+                                             self.hparams.dilation_kernel_width,
+                                             dilation_rate)
       cond_filter_conv = self.conv_1x1_layer(self.conditioning_stack,
-                                             int(x.shape[-1]))
+                                             x.shape[-1].value)
-    with tf.variable_scope('gate'):
-      x_gate_conv = self.causal_conv_layer(x, int(
-          x.shape[-1]), self.hparams.dilation_kernel_width, dilation_rate)
+    with tf.variable_scope("gate"):
+      x_gate_conv = self.causal_conv_layer(x, x.shape[-1].value,
+                                           self.hparams.dilation_kernel_width,
+                                           dilation_rate)
       cond_gate_conv = self.conv_1x1_layer(self.conditioning_stack,
-                                           int(x.shape[-1]))
+                                           x.shape[-1].value)

     gated_activation = (
         tf.tanh(x_filter_conv + cond_filter_conv) *
         tf.sigmoid(x_gate_conv + cond_gate_conv))

-    with tf.variable_scope('residual'):
-      residual = self.conv_1x1_layer(gated_activation, int(x.shape[-1]))
-    with tf.variable_scope('skip'):
+    with tf.variable_scope("residual"):
+      residual = self.conv_1x1_layer(gated_activation, x.shape[-1].value)
+    with tf.variable_scope("skip"):
       skip_connection = self.conv_1x1_layer(gated_activation,
                                             self.hparams.skip_output_dim)
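For reference, the filter/gate pair above implements the conditioned gated
activation unit from the WaveNet paper, where $x$ is the layer input, $c$ is
the conditioning stack, the $W$ are the dilated causal convolutions, and the
$V$ are the 1x1 convolutions:

$$z = \tanh(W_f \ast x + V_f \ast c) \odot \sigma(W_g \ast x + V_g \ast c)$$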
@@ -167,13 +173,13 @@ class AstroWaveNet(object):
     """
     skip_connections = []
     x = _shift_right(self.autoregressive_input)
-    with tf.variable_scope('preprocess'):
+    with tf.variable_scope("preprocess"):
       x = self.causal_conv_layer(x, self.hparams.preprocess_output_size,
                                  self.hparams.preprocess_kernel_width)
     for i in range(self.hparams.num_residual_blocks):
-      with tf.variable_scope('block_{}'.format(i)):
+      with tf.variable_scope("block_{}".format(i)):
         for dilation_rate in self.hparams.dilation_rates:
-          with tf.variable_scope('dilation_{}'.format(dilation_rate)):
+          with tf.variable_scope("dilation_{}".format(dilation_rate)):
             skip_connection, x = self.gated_residual_layer(x, dilation_rate)
             skip_connections.append(skip_connection)
@@ -192,7 +198,7 @@ class AstroWaveNet(object):
       The parameters of each distribution, a tensor of shape [batch_size,
       time_series_length, outputs_size].
     """
-    with tf.variable_scope('dist_params'):
+    with tf.variable_scope("dist_params"):
       conv_outputs = self.conv_1x1_layer(x, outputs_size)
     return conv_outputs
@@ -212,36 +218,40 @@ class AstroWaveNet(object):
       self.network_outputs

     Outputs:
+      self.dist_params
       self.predicted_distributions

     Raises:
       ValueError: If distribution type is neither 'categorical' nor 'normal'.
     """
-    with tf.variable_scope('postprocess'):
+    with tf.variable_scope("postprocess"):
       network_output = tf.keras.activations.relu(self.network_output)
       network_output = self.conv_1x1_layer(
           network_output,
-          output_size=int(network_output.shape[-1]),
-          activation='relu')
+          output_size=network_output.shape[-1].value,
+          activation="relu")

-    num_dists = int(self.autoregressive_input.shape[-1])
-    if self.hparams.output_distribution.type == 'categorical':
+    num_dists = self.autoregressive_input.shape[-1].value
+    if self.hparams.output_distribution.type == "categorical":
       num_classes = self.hparams.output_distribution.num_classes
-      dist_params = self.dist_params_layer(network_output,
-                                           num_dists * num_classes)
-      dist_shape = tf.concat(
+      logits = self.dist_params_layer(network_output, num_dists * num_classes)
+      logits_shape = tf.concat(
           [tf.shape(network_output)[:-1], [num_dists, num_classes]], 0)
-      dist_params = tf.reshape(dist_params, dist_shape)
-      dist = tf.distributions.Categorical(logits=dist_params)
-    elif self.hparams.output_distribution.type == 'normal':
-      dist_params = self.dist_params_layer(network_output, num_dists * 2)
-      loc, scale = tf.split(dist_params, 2, axis=-1)
+      logits = tf.reshape(logits, logits_shape)
+      dist = tfp.distributions.Categorical(logits=logits)
+      dist_params = {"logits": logits}
+    elif self.hparams.output_distribution.type == "normal":
+      loc_scale = self.dist_params_layer(network_output, num_dists * 2)
+      loc, scale = tf.split(loc_scale, 2, axis=-1)
       # Ensure scale is positive.
       scale = tf.nn.softplus(scale) + self.hparams.output_distribution.min_scale
-      dist = tf.distributions.Normal(loc, scale)
+      dist = tfp.distributions.Normal(loc, scale)
+      dist_params = {"loc": loc, "scale": scale}
     else:
-      raise ValueError('Unsupported distribution type {}'.format(
+      raise ValueError("Unsupported distribution type {}".format(
           self.hparams.output_distribution.type))

+    self.dist_params = dist_params
     self.predicted_distributions = dist
def build_losses(self): def build_losses(self):
...@@ -257,7 +267,7 @@ class AstroWaveNet(object): ...@@ -257,7 +267,7 @@ class AstroWaveNet(object):
autoregressive_target = self.autoregressive_input autoregressive_target = self.autoregressive_input
# Quantize the target if the output distribution is categorical. # Quantize the target if the output distribution is categorical.
if self.hparams.output_distribution.type == 'categorical': if self.hparams.output_distribution.type == "categorical":
min_val = self.hparams.output_distribution.min_quantization_value min_val = self.hparams.output_distribution.min_quantization_value
max_val = self.hparams.output_distribution.max_quantization_value max_val = self.hparams.output_distribution.max_quantization_value
num_classes = self.hparams.output_distribution.num_classes num_classes = self.hparams.output_distribution.num_classes
...@@ -270,7 +280,7 @@ class AstroWaveNet(object): ...@@ -270,7 +280,7 @@ class AstroWaveNet(object):
# final quantized bucket a closed interval while all the other quantized # final quantized bucket a closed interval while all the other quantized
# buckets are half-open intervals. # buckets are half-open intervals.
quantized_target = tf.where( quantized_target = tf.where(
quantized_target == num_classes, quantized_target >= num_classes,
tf.ones_like(quantized_target) * (num_classes - 1), quantized_target) tf.ones_like(quantized_target) * (num_classes - 1), quantized_target)
autoregressive_target = quantized_target autoregressive_target = quantized_target
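The `>=` comparison introduced above clamps the boundary case described in the
comment: with half-open buckets, a target exactly equal to the maximum
quantization value would land one bucket past the end. A minimal NumPy sketch
of the idea (the bucketing formula here is illustrative; the actual
quantization code is elided from this diff):

```python
import numpy as np

min_val, max_val, num_classes = -1.0, 1.0, 4
target = np.array([-1.0, 0.0, 1.0])

# Half-open buckets map [min_val, max_val) onto 0 .. num_classes - 1, but a
# target exactly equal to max_val maps to bucket num_classes.
quantized = np.floor((target - min_val) / (max_val - min_val) * num_classes)
print(quantized)  # [0. 2. 4.] -- bucket 4 is out of range.

# Clamp the boundary so the final bucket is a closed interval.
quantized = np.where(quantized >= num_classes, num_classes - 1, quantized)
print(quantized)  # [0. 2. 3.]
```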
@@ -280,22 +290,24 @@ class AstroWaveNet(object):
     if weights is None:
       weights = tf.ones_like(log_prob)
     weights_dim = len(weights.shape)
-    per_example_weight = tf.reduce_sum(weights, axis=range(1, weights_dim))
+    per_example_weight = tf.reduce_sum(
+        weights, axis=list(range(1, weights_dim)))
     per_example_indicator = tf.to_float(tf.greater(per_example_weight, 0))
-    num_examples = tf.reduce_sum(
-        per_example_indicator, name='num_nonzero_weight_examples')
+    num_examples = tf.reduce_sum(per_example_indicator)

     batch_losses = -log_prob * weights
-    losses_dim = len(batch_losses.shape)
+    losses_ndims = batch_losses.shape.ndims
     per_example_loss_sum = tf.reduce_sum(
-        batch_losses, axis=range(1, losses_dim))
+        batch_losses, axis=list(range(1, losses_ndims)))
     per_example_loss = tf.where(per_example_weight > 0,
                                 per_example_loss_sum / per_example_weight,
                                 tf.zeros_like(per_example_weight))
     total_loss = tf.reduce_sum(per_example_loss) / num_examples

+    self.autoregressive_target = autoregressive_target
     self.batch_losses = batch_losses
     self.per_example_loss = per_example_loss
+    self.num_nonzero_weight_examples = num_examples
     self.total_loss = total_loss

   def build(self):

...
@@ -2,6 +2,48 @@ package(default_visibility = ["//visibility:public"])
 licenses(["notice"])  # Apache 2.0

+py_library(
+    name = "base",
+    srcs = [
+        "base.py",
+    ],
+    deps = [
+        "//astronet/ops:dataset_ops",
+        "//astronet/util:configdict",
+    ],
+)
+
+py_test(
+    name = "base_test",
+    srcs = ["base_test.py"],
+    data = ["test_data/test-dataset.tfrecord"],
+    srcs_version = "PY2AND3",
+    deps = [":base"],
+)
+
+py_library(
+    name = "kepler_light_curves",
+    srcs = [
+        "kepler_light_curves.py",
+    ],
+    deps = [
+        ":base",
+        "//astronet/util:configdict",
+    ],
+)
+
+py_library(
+    name = "synthetic_transits",
+    srcs = [
+        "synthetic_transits.py",
+    ],
+    deps = [
+        ":base",
+        ":synthetic_transit_maker",
+        "//astronet/util:configdict",
+    ],
+)
+
 py_library(
     name = "synthetic_transit_maker",
     srcs = [

...
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base dataset builder classes for AstroWaveNet input pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
import tensorflow as tf
from astronet.util import configdict
from astronet.ops import dataset_ops
@six.add_metaclass(abc.ABCMeta)
class DatasetBuilder(object):
"""Base class for building a dataset input pipeline for AstroWaveNet."""
def __init__(self, config_overrides=None):
"""Initializes the dataset builder.
Args:
config_overrides: Dict or ConfigDict containing overrides to the default
configuration.
"""
self.config = configdict.ConfigDict(self.default_config())
if config_overrides is not None:
self.config.update(config_overrides)
@staticmethod
def default_config():
"""Returns the default configuration as a ConfigDict or Python dict."""
return {}
@abc.abstractmethod
def build(self, batch_size):
"""Builds the dataset input pipeline.
Args:
batch_size: The number of input examples in each batch.
Returns:
A tf.data.Dataset object.
"""
raise NotImplementedError
@six.add_metaclass(abc.ABCMeta)
class _ShardedDatasetBuilder(DatasetBuilder):
"""Abstract base class for a dataset consisting of sharded files."""
def __init__(self, file_pattern, mode, config_overrides=None, use_tpu=False):
"""Initializes the dataset builder.
Args:
file_pattern: File pattern matching input file shards, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
mode: A tf.estimator.ModeKeys.
config_overrides: Dict or ConfigDict containing overrides to the default
configuration.
use_tpu: Whether to build the dataset for TPU.
"""
super(_ShardedDatasetBuilder, self).__init__(config_overrides)
self.file_pattern = file_pattern
self.mode = mode
self.use_tpu = use_tpu
@staticmethod
def default_config():
config = super(_ShardedDatasetBuilder,
_ShardedDatasetBuilder).default_config()
config.update({
"max_length": 1024,
"shuffle_values_buffer": 1000,
"num_parallel_parser_calls": 4,
"batches_buffer_size": None, # Defaults to max(1, 256 / batch_size).
})
return config
@abc.abstractmethod
def file_reader(self):
"""Returns a function that reads a single sharded file."""
raise NotImplementedError
@abc.abstractmethod
def create_example_parser(self):
"""Returns a function that parses a single tf.Example proto."""
raise NotImplementedError
def _batch_and_pad(self, dataset, batch_size):
"""Combines elements into batches of the same length, padding if needed."""
if self.use_tpu:
padded_length = self.config.max_length
if not padded_length:
raise ValueError("config.max_length is required when using TPU")
# Pad with zeros up to padded_length. Note that this will pad the
# "weights" Tensor with zeros as well, which ensures that padded elements
# do not contribute to the loss.
padded_shapes = {}
for name, shape in dataset.output_shapes.iteritems():
shape.assert_is_compatible_with([None, None]) # Expect a 2D sequence.
dims = shape.as_list()
dims[0] = padded_length
shape = tf.TensorShape(dims)
shape.assert_is_fully_defined()
padded_shapes[name] = shape
else:
# Pad each batch up to the maximum size of each dimension in the batch.
padded_shapes = dataset.output_shapes
return dataset.padded_batch(batch_size, padded_shapes)
def build(self, batch_size):
"""Builds the dataset input pipeline.
Args:
batch_size:
Returns:
A tf.data.Dataset.
Raises:
ValueError: If no files match self.file_pattern.
"""
file_patterns = self.file_pattern.split(",")
filenames = []
for p in file_patterns:
matches = tf.gfile.Glob(p)
if not matches:
raise ValueError("Found no input files matching {}".format(p))
filenames.extend(matches)
tf.logging.info(
"Building input pipeline from %d files matching patterns: %s",
len(filenames), file_patterns)
is_training = self.mode == tf.estimator.ModeKeys.TRAIN
# Create a string dataset of filenames, and possibly shuffle.
filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
if is_training and len(filenames) > 1:
filename_dataset = filename_dataset.shuffle(len(filenames))
# Read serialized Example protos.
dataset = filename_dataset.apply(
tf.contrib.data.parallel_interleave(
self.file_reader(), cycle_length=8, block_length=8, sloppy=True))
if is_training:
# Shuffle and repeat. Note that shuffle() is before repeat(), so elements
# are shuffled among each epoch of data, and not between epochs of data.
if self.config.shuffle_values_buffer > 0:
dataset = dataset.shuffle(self.config.shuffle_values_buffer)
dataset = dataset.repeat()
# Map the parser over the dataset.
dataset = dataset.map(
self.create_example_parser(),
num_parallel_calls=self.config.num_parallel_parser_calls)
def _prepare_wavenet_inputs(features):
"""Validates features, and clips lengths and adds weights if needed."""
# Validate feature names.
required_features = {"autoregressive_input", "conditioning_stack"}
allowed_features = required_features | {"weights"}
feature_names = features.keys()
if not required_features.issubset(feature_names):
raise ValueError("Features must contain all of: {}. Got: {}".format(
required_features, feature_names))
if not allowed_features.issuperset(feature_names):
raise ValueError("Features can only contain: {}. Got: {}".format(
allowed_features, feature_names))
output = {}
for name, value in features.items():
# Validate shapes. The output dimension is [num_samples, dim].
ndims = len(value.shape)
if ndims == 1:
# Add an extra dimension: [num_samples] -> [num_samples, 1].
value = tf.expand_dims(value, -1)
elif ndims != 2:
raise ValueError(
"Features should be 1D or 2D sequences. Got '{}' = {}".format(
name, value))
if self.config.max_length:
value = value[:self.config.max_length]
output[name] = value
if "weights" not in output:
output["weights"] = tf.ones_like(output["autoregressive_input"])
return output
dataset = dataset.map(_prepare_wavenet_inputs)
# Batch results by up to batch_size.
dataset = self._batch_and_pad(dataset, batch_size)
if is_training:
# The dataset repeats infinitely before batching, so each batch has the
# maximum number of elements.
dataset = dataset_ops.set_batch_size(dataset, batch_size)
elif self.use_tpu and self.mode == tf.estimator.ModeKeys.EVAL:
# Pad to ensure that each batch has the same number of elements.
dataset = dataset_ops.pad_dataset_to_batch_size(dataset, batch_size)
# Prefetch batches.
buffer_size = (
self.config.batches_buffer_size or max(1, int(256 / batch_size)))
dataset = dataset.prefetch(buffer_size)
return dataset
def tfrecord_reader(filename):
"""Returns a tf.data.Dataset that reads a single TFRecord file shard."""
return tf.data.TFRecordDataset(filename, buffer_size=16 * 1000 * 1000)
class TFRecordDataset(_ShardedDatasetBuilder):
"""Builder for a dataset consisting of TFRecord files."""
def file_reader(self):
"""Returns a function that reads a single file shard."""
return tfrecord_reader
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Kepler light curve inputs to the AstroWaveNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from astrowavenet.data import base
COND_INPUT_KEY = "mask"
AR_INPUT_KEY = "flux"
class KeplerLightCurves(base.TFRecordDataset):
"""Kepler light curve inputs to the AstroWaveNet model."""
def create_example_parser(self):
def _example_parser(serialized):
"""Parses a single tf.Example proto."""
features = tf.parse_single_example(
serialized,
features={
AR_INPUT_KEY: tf.VarLenFeature(tf.float32),
COND_INPUT_KEY: tf.VarLenFeature(tf.int64),
})
# Extract values from SparseTensor objects.
autoregressive_input = features[AR_INPUT_KEY].values
conditioning_stack = tf.to_float(features[COND_INPUT_KEY].values)
return {
"autoregressive_input": autoregressive_input,
"conditioning_stack": conditioning_stack,
}
return _example_parser
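A minimal usage sketch of this builder outside the trainer; the shard pattern
is a hypothetical placeholder, and any TFRecord files containing the "flux" and
"mask" features parsed above would work:

```python
import tensorflow as tf

from astrowavenet.data import kepler_light_curves

# See _ShardedDatasetBuilder in base.py for the constructor arguments. The
# builder globs the pattern and yields batches of "autoregressive_input",
# "conditioning_stack" and "weights" Tensors.
builder = kepler_light_curves.KeplerLightCurves(
    file_pattern="/tmp/kepler/train-?????-of-00100",
    mode=tf.estimator.ModeKeys.TRAIN)
dataset = builder.build(batch_size=16)  # A tf.data.Dataset of feature dicts.
```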
@@ -43,8 +43,8 @@ class SyntheticTransitMaker(object):
       would translate the sine wave by half of the period. The most common
       reason to override this would be to generate light curves
       deterministically (with e.g. (0,0)).
-    noise_sd_range: A tuple of values in [0, 1) specifying the range of
-      standard deviations for the Gaussian noise applied to the sine wave.
+    noise_sd_range: A tuple of values in [0, 1) specifying the range of standard
+      deviations for the Gaussian noise applied to the sine wave.
   """

   def __init__(self,

...
@@ -29,30 +29,30 @@ class SyntheticTransitMakerTest(absltest.TestCase):

   def testBadRangesRaiseExceptions(self):
     # Period range cannot contain negative values.
-    with self.assertRaisesRegexp(ValueError, 'Period'):
+    with self.assertRaisesRegexp(ValueError, "Period"):
       synthetic_transit_maker.SyntheticTransitMaker(period_range=(-1, 10))

     # Amplitude range cannot contain negative values.
-    with self.assertRaisesRegexp(ValueError, 'Amplitude'):
+    with self.assertRaisesRegexp(ValueError, "Amplitude"):
       synthetic_transit_maker.SyntheticTransitMaker(amplitude_range=(-10, -1))

     # Threshold ratio range must be contained in the half-open interval [0, 1).
-    with self.assertRaisesRegexp(ValueError, 'Threshold ratio'):
+    with self.assertRaisesRegexp(ValueError, "Threshold ratio"):
       synthetic_transit_maker.SyntheticTransitMaker(
           threshold_ratio_range=(0, 1))

     # Noise standard deviation range must only contain nonnegative values.
-    with self.assertRaisesRegexp(ValueError, 'Noise standard deviation'):
+    with self.assertRaisesRegexp(ValueError, "Noise standard deviation"):
       synthetic_transit_maker.SyntheticTransitMaker(noise_sd_range=(-1, 1))

     # End of range may not be less than start.
     invalid_range = (0.2, 0.1)
     range_args = [
-        'period_range', 'threshold_ratio_range', 'amplitude_range',
-        'noise_sd_range', 'phase_range'
+        "period_range", "threshold_ratio_range", "amplitude_range",
+        "noise_sd_range", "phase_range"
     ]
     for range_arg in range_args:
-      with self.assertRaisesRegexp(ValueError, 'may not be less'):
+      with self.assertRaisesRegexp(ValueError, "may not be less"):
         synthetic_transit_maker.SyntheticTransitMaker(
             **{range_arg: invalid_range})

@@ -106,5 +106,5 @@ class SyntheticTransitMakerTest(absltest.TestCase):
     self.assertEqual(len(mask), 100)


-if __name__ == '__main__':
+if __name__ == "__main__":
   absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Synthetic transit inputs to the AstroWaveNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import configdict
from astrowavenet.data import base
from astrowavenet.data import synthetic_transit_maker
def _prepare_wavenet_inputs(light_curve, mask):
"""Gathers synthetic transits into the format expected by AstroWaveNet."""
return {
"autoregressive_input": tf.expand_dims(light_curve, -1),
"conditioning_stack": tf.expand_dims(mask, -1),
}
class SyntheticTransits(base.DatasetBuilder):
"""Synthetic transit inputs to the AstroWaveNet model."""
@staticmethod
def default_config():
return configdict.ConfigDict({
"period_range": (0.5, 4),
"amplitude_range": (1, 1),
"threshold_ratio_range": (0, 0.99),
"phase_range": (0, 1),
"noise_sd_range": (0.1, 0.1),
"mask_probability": 0.1,
"light_curve_time_range": (0, 100),
"light_curve_num_points": 1000
})
def build(self, batch_size):
transit_maker = synthetic_transit_maker.SyntheticTransitMaker(
period_range=self.config.period_range,
amplitude_range=self.config.amplitude_range,
threshold_ratio_range=self.config.threshold_ratio_range,
phase_range=self.config.phase_range,
noise_sd_range=self.config.noise_sd_range)
t_start, t_end = self.config.light_curve_time_range
time = np.linspace(t_start, t_end, self.config.light_curve_num_points)
dataset = tf.data.Dataset.from_generator(
transit_maker.random_light_curve_generator(
time, mask_prob=self.config.mask_probability),
output_types=(tf.float32, tf.float32),
output_shapes=(tf.TensorShape((self.config.light_curve_num_points,)),
tf.TensorShape((self.config.light_curve_num_points,))))
dataset = dataset.map(_prepare_wavenet_inputs)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(-1)
return dataset
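Since this builder generates data on the fly, it needs no input files. A short
sketch of direct use, overriding two of the defaults listed in
default_config() above:

```python
from astrowavenet.data import synthetic_transits

# Shorter light curves with more aggressive masking than the defaults.
builder = synthetic_transits.SyntheticTransits(
    config_overrides={"light_curve_num_points": 500, "mask_probability": 0.2})
dataset = builder.build(batch_size=8)
# Each batch maps "autoregressive_input" and "conditioning_stack" to Tensors
# of shape [8, 500, 1].
```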
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training and evaluating AstroWaveNet models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os.path
from absl import flags
import tensorflow as tf
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astrowavenet import astrowavenet_model
from astrowavenet import configurations
from astrowavenet.data import kepler_light_curves
from astrowavenet.data import synthetic_transits
from astrowavenet.util import estimator_util
FLAGS = flags.FLAGS
flags.DEFINE_enum("dataset", None,
["synthetic_transits", "kepler_light_curves"],
"Dataset for training and/or evaluation.")
flags.DEFINE_string("model_dir", None, "Base output directory.")
flags.DEFINE_string(
"train_files", None,
"Comma-separated list of file patterns matching the TFRecord files in the "
"training dataset.")
flags.DEFINE_string(
"eval_files", None,
"Comma-separated list of file patterns matching the TFRecord files in the "
"evaluation dataset.")
flags.DEFINE_string("config_name", "base",
"Name of the AstroWaveNet configuration.")
flags.DEFINE_string(
"config_overrides", "{}",
"JSON string or JSON file containing overrides to the base configuration.")
flags.DEFINE_enum("schedule", None,
["train", "train_and_eval", "continuous_eval"],
"Schedule for running the model.")
flags.DEFINE_string("eval_name", "val", "Name of the evaluation task.")
flags.DEFINE_integer("train_steps", None, "Total number of steps for training.")
flags.DEFINE_integer("eval_steps", None, "Number of steps for each evaluation.")
flags.DEFINE_integer(
"local_eval_frequency", 1000,
"The number of training steps in between evaluation runs. Only applies "
"when schedule == 'train_and_eval'.")
flags.DEFINE_integer("save_summary_steps", None,
"The frequency at which to save model summaries.")
flags.DEFINE_integer("save_checkpoints_steps", None,
"The frequency at which to save model checkpoints.")
flags.DEFINE_integer("save_checkpoints_secs", None,
"The frequency at which to save model checkpoints.")
flags.DEFINE_integer("keep_checkpoint_max", 1,
"The maximum number of model checkpoints to keep.")
# ------------------------------------------------------------------------------
# TPU-only flags
# ------------------------------------------------------------------------------
flags.DEFINE_boolean("use_tpu", False, "Whether to execute on TPU.")
flags.DEFINE_string("master", None, "Address of the TensorFlow TPU master.")
flags.DEFINE_integer("tpu_num_shards", 8, "Number of TPU shards.")
flags.DEFINE_integer("tpu_iterations_per_loop", 1000,
"Number of iterations per TPU training loop.")
flags.DEFINE_integer(
"eval_batch_size", None,
"Batch size for TPU evaluation. Defaults to the training batch size.")
def _create_run_config():
"""Creates a TPU RunConfig if FLAGS.use_tpu is True, else a RunConfig."""
session_config = tf.ConfigProto(allow_soft_placement=True)
run_config_kwargs = {
"save_summary_steps": FLAGS.save_summary_steps,
"save_checkpoints_steps": FLAGS.save_checkpoints_steps,
"save_checkpoints_secs": FLAGS.save_checkpoints_secs,
"session_config": session_config,
"keep_checkpoint_max": FLAGS.keep_checkpoint_max
}
if FLAGS.use_tpu:
if not FLAGS.master:
raise ValueError("FLAGS.master must be set for TPUEstimator.")
tpu_config = tf.contrib.tpu.TPUConfig(
iterations_per_loop=FLAGS.tpu_iterations_per_loop,
num_shards=FLAGS.tpu_num_shards,
per_host_input_for_training=(FLAGS.tpu_num_shards <= 8))
run_config = tf.contrib.tpu.RunConfig(
tpu_config=tpu_config, master=FLAGS.master, **run_config_kwargs)
else:
if FLAGS.master:
raise ValueError("FLAGS.master should only be set for TPUEstimator.")
run_config = tf.estimator.RunConfig(**run_config_kwargs)
return run_config
def _get_file_pattern(mode):
"""Gets the value of the file pattern flag for the specified mode."""
flag_name = ("train_files"
if mode == tf.estimator.ModeKeys.TRAIN else "eval_files")
file_pattern = FLAGS[flag_name].value
if file_pattern is None:
raise ValueError("--{} is required for mode '{}'".format(flag_name, mode))
return file_pattern
def _create_dataset_builder(mode, config_overrides=None):
"""Creates a dataset builder for the input pipeline."""
if FLAGS.dataset == "synthetic_transits":
return synthetic_transits.SyntheticTransits(config_overrides)
file_pattern = _get_file_pattern(mode)
if FLAGS.dataset == "kepler_light_curves":
builder_class = kepler_light_curves.KeplerLightCurves
else:
raise ValueError("Unsupported dataset: {}".format(FLAGS.dataset))
return builder_class(
file_pattern,
mode,
config_overrides=config_overrides,
use_tpu=FLAGS.use_tpu)
def _create_input_fn(mode, config_overrides=None):
"""Creates an Estimator input_fn."""
builder = _create_dataset_builder(mode, config_overrides)
tf.logging.info("Dataset config for mode '%s': %s", mode,
config_util.to_json(builder.config))
return estimator_util.create_input_fn(builder)
def _create_eval_args(config_overrides=None):
"""Builds eval_args for estimator_runner.evaluate()."""
if FLAGS.dataset == "synthetic_transits" and not FLAGS.eval_steps:
raise ValueError("Dataset '{}' requires --eval_steps for evaluation".format(
FLAGS.dataset))
input_fn = _create_input_fn(tf.estimator.ModeKeys.EVAL, config_overrides)
return {FLAGS.eval_name: (input_fn, FLAGS.eval_steps)}
def main(argv):
del argv # Unused.
config = configdict.ConfigDict(configurations.get_config(FLAGS.config_name))
config_overrides = json.loads(FLAGS.config_overrides)
for key in config_overrides:
if key not in ["dataset", "hparams"]:
raise ValueError("Unrecognized config override: {}".format(key))
config.hparams.update(config_overrides.get("hparams", {}))
# Log configs.
configs_json = [
("config_overrides", config_util.to_json(config_overrides)),
("config", config_util.to_json(config)),
]
for config_name, config_json in configs_json:
tf.logging.info("%s: %s", config_name, config_json)
# Create the estimator.
run_config = _create_run_config()
estimator = estimator_util.create_estimator(
astrowavenet_model.AstroWaveNet, config.hparams, run_config,
FLAGS.model_dir, FLAGS.eval_batch_size)
if FLAGS.schedule in ["train", "train_and_eval"]:
# Save configs.
tf.gfile.MakeDirs(FLAGS.model_dir)
for config_name, config_json in configs_json:
filename = os.path.join(FLAGS.model_dir, "{}.json".format(config_name))
with tf.gfile.Open(filename, "w") as f:
f.write(config_json)
train_input_fn = _create_input_fn(tf.estimator.ModeKeys.TRAIN,
config_overrides.get("dataset"))
train_hooks = []
if FLAGS.schedule == "train":
estimator.train(
train_input_fn, hooks=train_hooks, max_steps=FLAGS.train_steps)
else:
assert FLAGS.schedule == "train_and_eval"
eval_args = _create_eval_args(config_overrides.get("dataset"))
for _ in estimator_runner.continuous_train_and_eval(
estimator=estimator,
train_input_fn=train_input_fn,
eval_args=eval_args,
local_eval_frequency=FLAGS.local_eval_frequency,
train_hooks=train_hooks,
train_steps=FLAGS.train_steps):
# continuous_train_and_eval() yields evaluation metrics after each
# FLAGS.local_eval_frequency. It also saves and logs them, so we don't
# do anything here.
pass
else:
assert FLAGS.schedule == "continuous_eval"
eval_args = _create_eval_args(config_overrides.get("dataset"))
for _ in estimator_runner.continuous_eval(
estimator=estimator, eval_args=eval_args,
train_steps=FLAGS.train_steps):
# continuous_train_and_eval() yields evaluation metrics after each
# checkpoint. It also saves and logs them, so we don't do anything here.
pass
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
flags.mark_flags_as_required(["dataset", "model_dir", "schedule"])
def _validate_schedule(flag_values):
"""Validates the --schedule flag and the flags it interacts with."""
schedule = flag_values["schedule"]
save_checkpoints_steps = flag_values["save_checkpoints_steps"]
save_checkpoints_secs = flag_values["save_checkpoints_secs"]
if schedule in ["train", "train_and_eval"]:
if not (save_checkpoints_steps or save_checkpoints_secs):
raise flags.ValidationError(
"--schedule='%s' requires --save_checkpoints_steps or "
"--save_checkpoints_secs." % schedule)
return True
flags.register_multi_flags_validator(
["schedule", "save_checkpoints_steps", "save_checkpoints_secs"],
_validate_schedule)
tf.app.run()
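For completeness, a possible training invocation for the kepler_light_curves
dataset using the flags defined above; the TFRecord shard patterns are
placeholders, and the files must contain the "flux" and "mask" features
expected by kepler_light_curves.py:

```bash
bazel-bin/astrowavenet/trainer \
  --dataset=kepler_light_curves \
  --train_files="/tmp/kepler/train-?????-of-00100" \
  --eval_files="/tmp/kepler/val-?????-of-00010" \
  --model_dir=/tmp/astrowavenet_kepler/ \
  --schedule=train_and_eval \
  --eval_steps=100 \
  --save_checkpoints_steps=1000
```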
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_library(
    name = "estimator_util",
    srcs = ["estimator_util.py"],
    srcs_version = "PY2AND3",
    deps = ["//astronet/ops:training"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for creating a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
from astronet.ops import training
class _InputFn(object):
"""Class that acts as a callable input function for Estimator train / eval."""
def __init__(self, dataset_builder):
"""Initializes the input function.
Args:
dataset_builder: Instance of DatasetBuilder.
"""
self._builder = dataset_builder
def __call__(self, params):
"""Builds the input pipeline."""
return self._builder.build(batch_size=params["batch_size"])
def create_input_fn(dataset_builder):
"""Creates an input_fn that that builds an input pipeline.
Args:
dataset_builder: Instance of DatasetBuilder.
Returns:
A callable that builds an input pipeline and returns a tf.data.Dataset
object.
"""
return _InputFn(dataset_builder)
class _ModelFn(object):
"""Class that acts as a callable model function for Estimator train / eval."""
def __init__(self, model_class, hparams, use_tpu=False):
"""Initializes the model function.
Args:
model_class: Model class.
hparams: A HParams object containing hyperparameters for building and
training the model.
use_tpu: If True, a TPUEstimator will be returned. Otherwise an Estimator
will be returned.
"""
self._model_class = model_class
self._base_hparams = hparams
self._use_tpu = use_tpu
def __call__(self, features, mode, params):
"""Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
hparams = copy.deepcopy(self._base_hparams)
if "batch_size" in params:
hparams.batch_size = params["batch_size"]
model = self._model_class(features, hparams, mode)
model.build()
# Possibly create train_op.
use_tpu = self._use_tpu
train_op = None
if mode == tf.estimator.ModeKeys.TRAIN:
learning_rate = training.create_learning_rate(hparams, model.global_step)
optimizer = training.create_optimizer(hparams, learning_rate, use_tpu)
train_op = training.create_train_op(model, optimizer)
if use_tpu:
estimator = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode, loss=model.total_loss, train_op=train_op)
else:
estimator = tf.estimator.EstimatorSpec(
mode=mode, loss=model.total_loss, train_op=train_op)
return estimator
def create_model_fn(model_class, hparams, use_tpu=False):
"""Wraps model_class as an Estimator or TPUEstimator model_fn.
Args:
model_class: AstroModel or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
Estimator model_fn is returned.
Returns:
model_fn: A callable that constructs the model and returns a
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
"""
return _ModelFn(model_class, hparams, use_tpu)
def create_estimator(model_class,
hparams,
run_config=None,
model_dir=None,
eval_batch_size=None):
"""Wraps model_class as an Estimator or TPUEstimator.
If run_config is None or a tf.estimator.RunConfig, an Estimator is returned.
If run_config is a tf.contrib.tpu.RunConfig, a TPUEstimator is returned.
Args:
model_class: AstroWaveNet or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
model_dir: Optional directory for saving the model. If not passed
explicitly, it must be specified in run_config.
eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
Returns:
An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
TPUEstimator object if run_config is a tf.contrib.tpu.RunConfig.
Raises:
ValueError:
If model_dir is not passed explicitly or in run_config.model_dir, or if
eval_batch_size is specified and run_config is not a
tf.contrib.tpu.RunConfig.
"""
if run_config is None:
run_config = tf.estimator.RunConfig()
else:
run_config = copy.deepcopy(run_config)
if not model_dir and not run_config.model_dir:
raise ValueError(
"model_dir must be passed explicitly or specified in run_config")
use_tpu = isinstance(run_config, tf.contrib.tpu.RunConfig)
model_fn = create_model_fn(model_class, hparams, use_tpu)
if use_tpu:
eval_batch_size = eval_batch_size or hparams.batch_size
estimator = tf.contrib.tpu.TPUEstimator(
model_fn=model_fn,
model_dir=model_dir,
config=run_config,
train_batch_size=hparams.batch_size,
eval_batch_size=eval_batch_size)
else:
if eval_batch_size is not None:
raise ValueError("eval_batch_size can only be specified for TPU.")
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=model_dir,
config=run_config,
params={"batch_size": hparams.batch_size})
return estimator
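A minimal end-to-end sketch of wiring these helpers together outside the
trainer script, assuming the "base" configuration and the synthetic-transits
pipeline (the model directory is a placeholder):

```python
from astronet.util import configdict
from astrowavenet import astrowavenet_model
from astrowavenet import configurations
from astrowavenet.data import synthetic_transits
from astrowavenet.util import estimator_util

# Build an Estimator whose model_fn constructs an AstroWaveNet.
config = configdict.ConfigDict(configurations.get_config("base"))
estimator = estimator_util.create_estimator(
    astrowavenet_model.AstroWaveNet, config.hparams,
    model_dir="/tmp/astrowavenet_sketch/")

# The input_fn receives {"batch_size": hparams.batch_size} via params.
input_fn = estimator_util.create_input_fn(
    synthetic_transits.SyntheticTransits())
estimator.train(input_fn, max_steps=10)  # A short smoke-test run.
```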
@@ -220,14 +220,13 @@ def reshard_arrays(xs, ys):
   return np.split(concat_x, boundaries)


-def uniform_cadence_light_curve(all_cadence_no, all_time, all_flux):
+def uniform_cadence_light_curve(cadence_no, time, flux):
   """Combines data into a single light curve with uniform cadence numbers.

   Args:
-    all_cadence_no: A list of numpy arrays; the cadence numbers of the light
-      curve.
-    all_time: A list of numpy arrays; the time values of the light curve.
-    all_flux: A list of numpy arrays; the flux values of the light curve.
+    cadence_no: numpy array; the cadence numbers of the light curve.
+    time: numpy array; the time values of the light curve.
+    flux: numpy array; the flux values of the light curve.

   Returns:
     cadence_no: numpy array; the cadence numbers of the light curve with no

@@ -245,16 +244,15 @@ def uniform_cadence_light_curve(all_cadence_no, all_time, all_flux):
   Raises:
     ValueError: If there are duplicate cadence numbers in the input.
   """
-  min_cadence_no = np.min([np.min(c) for c in all_cadence_no])
-  max_cadence_no = np.max([np.max(c) for c in all_cadence_no])
+  min_cadence_no = np.min(cadence_no)
+  max_cadence_no = np.max(cadence_no)
   out_cadence_no = np.arange(
-      min_cadence_no, max_cadence_no + 1, dtype=all_cadence_no[0].dtype)
-  out_time = np.zeros_like(out_cadence_no, dtype=all_time[0].dtype)
-  out_flux = np.zeros_like(out_cadence_no, dtype=all_flux[0].dtype)
+      min_cadence_no, max_cadence_no + 1, dtype=cadence_no.dtype)
+  out_time = np.zeros_like(out_cadence_no, dtype=time.dtype)
+  out_flux = np.zeros_like(out_cadence_no, dtype=flux.dtype)
   out_mask = np.zeros_like(out_cadence_no, dtype=np.bool)
-  for cadence_no, time, flux in zip(all_cadence_no, all_time, all_flux):
-    for c, t, f in zip(cadence_no, time, flux):
-      if np.isfinite(c) and np.isfinite(t) and np.isfinite(f):
-        i = int(c - min_cadence_no)
+  for c, t, f in zip(cadence_no, time, flux):
+    if np.isfinite(c) and np.isfinite(t) and np.isfinite(f):
+      i = int(c - min_cadence_no)

...