Commit 763663de authored by Chris Shallue, committed by Christopher Shallue

Project import generated by Copybara.

PiperOrigin-RevId: 217341274
parent ca2db9bd
@@ -40,6 +40,10 @@ Full text available at [*The Astronomical Journal*](http://iopscience.iop.org/ar
* Training and evaluating a new model.
* Using a trained model to generate new predictions.
[astrowavenet/](astrowavenet/)
* A generative model for light curves.
[light_curve_util/](light_curve_util)
* Utilities for operating on light curves. These include:
@@ -63,11 +67,11 @@ First, ensure that you have installed the following required packages:
* **TensorFlow** ([instructions](https://www.tensorflow.org/install/))
* **Pandas** ([instructions](http://pandas.pydata.org/pandas-docs/stable/install.html))
* **NumPy** ([instructions](https://docs.scipy.org/doc/numpy/user/install.html))
* **SciPy** ([instructions](https://scipy.org/install.html))
* **AstroPy** ([instructions](http://www.astropy.org/))
* **PyDl** ([instructions](https://pypi.python.org/pypi/pydl))
* **Bazel** ([instructions](https://docs.bazel.build/versions/master/install.html))
* **Abseil Python Common Libraries** ([instructions](https://github.com/abseil/abseil-py))
* Optional: only required for unit tests.
### Optional: Run Unit Tests
@@ -63,6 +63,14 @@ def parse_json(json_string_or_file):
return json_dict
def to_json(config):
"""Converts a JSON-serializable configuration object to a JSON string."""
if hasattr(config, "to_json") and callable(config.to_json):
return config.to_json(indent=2)
else:
return json.dumps(config, indent=2)
def log_and_save_config(config, output_dir):
"""Logs and writes a JSON-serializable configuration object.
@@ -70,10 +78,7 @@ def log_and_save_config(config, output_dir):
config: A JSON-serializable object.
output_dir: Destination directory.
"""
if hasattr(config, "to_json") and callable(config.to_json):
config_json = config.to_json(indent=2)
else:
config_json = json.dumps(config, indent=2)
config_json = to_json(config)
tf.logging.info("config: %s", config_json)
tf.gfile.MakeDirs(output_dir)
@@ -4,6 +4,22 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_binary(
name = "trainer",
srcs = ["trainer.py"],
srcs_version = "PY2AND3",
deps = [
":astrowavenet_model",
":configurations",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astrowavenet/data:kepler_light_curves",
"//astrowavenet/data:synthetic_transits",
"//astrowavenet/util:estimator_util",
],
)
py_library(
name = "configurations",
srcs = ["configurations.py"],
@@ -11,22 +27,22 @@ py_library(
)
py_library(
name = "astrowavenet",
name = "astrowavenet_model",
srcs = [
"astrowavenet.py",
"astrowavenet_model.py",
],
srcs_version = "PY2AND3",
)
py_test(
name = "astrowavenet_test",
name = "astrowavenet_model_test",
size = "small",
srcs = [
"astrowavenet_test.py",
"astrowavenet_model_test.py",
],
srcs_version = "PY2AND3",
deps = [
":astrowavenet",
":astrowavenet_model",
":configurations",
"//astronet/util:configdict",
],
# AstroWaveNet: A generative model for light curves.
Implementation based on "WaveNet: A Generative Model of Raw Audio":
https://arxiv.org/abs/1609.03499
## Code Authors
Alex Tamkin: [@atamkin](https://github.com/atamkin)
Chris Shallue: [@cshallue](https://github.com/cshallue)
## Pull Requests / Issues
Chris Shallue: [@cshallue](https://github.com/cshallue)
## Additional Dependencies
This package requires TensorFlow 1.12 or greater. As of October 2018, this
requires the **TensorFlow nightly build**
([instructions](https://www.tensorflow.org/install/pip)).
In addition to the dependencies listed in the top-level README, this package
requires:
* **TensorFlow Probability** ([instructions](https://www.tensorflow.org/probability/install))
* **Six** ([instructions](https://pypi.org/project/six/))
## Basic Usage
To train a model on synthetic transits, first build the package and then run the trainer:
```bash
bazel build astrowavenet/...
```
```bash
bazel-bin/astrowavenet/trainer \
--dataset=synthetic_transits \
--model_dir=/tmp/astrowavenet/ \
--config_overrides='{"hparams": {"batch_size": 16, "num_residual_blocks": 2}}' \
--schedule=train_and_eval \
--eval_steps=100 \
--save_checkpoints_steps=1000
```
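The same pipeline can also be driven from Python without the `trainer` binary. The sketch below is illustrative only; it assumes the `astrowavenet` and `astronet` packages are importable and that the `base` configuration in `configurations.py` defines every hyperparameter the model reads:

```python
# Minimal sketch: build the synthetic-transit input pipeline and the model
# graph directly (assumption: the "base" config provides all required hparams).
import tensorflow as tf

from astronet.util import configdict
from astrowavenet import astrowavenet_model
from astrowavenet import configurations
from astrowavenet.data import synthetic_transits

config = configdict.ConfigDict(configurations.get_config("base"))
dataset = synthetic_transits.SyntheticTransits().build(batch_size=16)
features = dataset.make_one_shot_iterator().get_next()
model = astrowavenet_model.AstroWaveNet(
    features, config.hparams, tf.estimator.ModeKeys.TRAIN)
model.build()  # Populates model.total_loss, model.predicted_distributions, etc.
```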
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -23,6 +23,7 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow_probability as tfp
def _shift_right(x):
@@ -64,18 +65,21 @@ class AstroWaveNet(object):
tf.estimator.ModeKeys.PREDICT
]
if mode not in valid_modes:
raise ValueError('Expected mode in {}. Got: {}'.format(valid_modes, mode))
raise ValueError("Expected mode in {}. Got: {}".format(valid_modes, mode))
self.hparams = hparams
self.mode = mode
self.autoregressive_input = features['autoregressive_input']
self.conditioning_stack = features['conditioning_stack']
self.weights = features.get('weights')
self.autoregressive_input = features["autoregressive_input"]
self.conditioning_stack = features["conditioning_stack"]
self.weights = features.get("weights")
self.network_output = None # Sum of skip connections from dilation stack.
self.dist_params = None # Dict of predicted distribution parameters.
self.predicted_distributions = None # Predicted distribution for examples.
self.autoregressive_target = None # Autoregressive target predictions.
self.batch_losses = None # Loss for each predicted distribution in batch.
self.per_example_loss = None # Loss for each example in batch.
self.num_nonzero_weight_examples = None # Number of examples in batch.
self.total_loss = None # Overall loss for the batch.
self.global_step = None # Global step Tensor.
@@ -94,9 +98,9 @@ class AstroWaveNet(object):
causal_conv_op = tf.keras.layers.Conv1D(
output_size,
kernel_width,
padding='causal',
padding="causal",
dilation_rate=dilation_rate,
name='causal_conv')
name="causal_conv")
return causal_conv_op(x)
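As a quick illustration of what `padding="causal"` provides here, the toy check below (not part of this change) shows that each output timestep depends only on current and past inputs:

```python
# With a width-2 kernel, dilation rate 2, and all-ones weights, y[t] = x[t] + x[t-2].
import numpy as np
import tensorflow as tf

x = tf.constant(np.arange(6, dtype=np.float32).reshape(1, 6, 1))
conv = tf.keras.layers.Conv1D(1, kernel_size=2, padding="causal", dilation_rate=2,
                              kernel_initializer="ones", use_bias=False)
y = conv(x)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(y)[0, :, 0])  # [0. 1. 2. 4. 6. 8.]
```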
def conv_1x1_layer(self, x, output_size, activation=None):
@@ -111,7 +115,7 @@ class AstroWaveNet(object):
Resulting tf.Tensor after applying the 1x1 convolution.
"""
conv_1x1_op = tf.keras.layers.Conv1D(
output_size, 1, activation=activation, name='conv1x1')
output_size, 1, activation=activation, name="conv1x1")
return conv_1x1_op(x)
def gated_residual_layer(self, x, dilation_rate):
@@ -125,24 +129,26 @@ class AstroWaveNet(object):
skip_connection: tf.Tensor; Skip connection to network_output layer.
residual_connection: tf.Tensor; Sum of learned residual and input tensor.
"""
with tf.variable_scope('filter'):
x_filter_conv = self.causal_conv_layer(x, int(
x.shape[-1]), self.hparams.dilation_kernel_width, dilation_rate)
with tf.variable_scope("filter"):
x_filter_conv = self.causal_conv_layer(x, x.shape[-1].value,
self.hparams.dilation_kernel_width,
dilation_rate)
cond_filter_conv = self.conv_1x1_layer(self.conditioning_stack,
int(x.shape[-1]))
with tf.variable_scope('gate'):
x_gate_conv = self.causal_conv_layer(x, int(
x.shape[-1]), self.hparams.dilation_kernel_width, dilation_rate)
x.shape[-1].value)
with tf.variable_scope("gate"):
x_gate_conv = self.causal_conv_layer(x, x.shape[-1].value,
self.hparams.dilation_kernel_width,
dilation_rate)
cond_gate_conv = self.conv_1x1_layer(self.conditioning_stack,
int(x.shape[-1]))
x.shape[-1].value)
gated_activation = (
tf.tanh(x_filter_conv + cond_filter_conv) *
tf.sigmoid(x_gate_conv + cond_gate_conv))
with tf.variable_scope('residual'):
residual = self.conv_1x1_layer(gated_activation, int(x.shape[-1]))
with tf.variable_scope('skip'):
with tf.variable_scope("residual"):
residual = self.conv_1x1_layer(gated_activation, x.shape[-1].value)
with tf.variable_scope("skip"):
skip_connection = self.conv_1x1_layer(gated_activation,
self.hparams.skip_output_dim)
@@ -167,13 +173,13 @@ class AstroWaveNet(object):
"""
skip_connections = []
x = _shift_right(self.autoregressive_input)
with tf.variable_scope('preprocess'):
with tf.variable_scope("preprocess"):
x = self.causal_conv_layer(x, self.hparams.preprocess_output_size,
self.hparams.preprocess_kernel_width)
for i in range(self.hparams.num_residual_blocks):
with tf.variable_scope('block_{}'.format(i)):
with tf.variable_scope("block_{}".format(i)):
for dilation_rate in self.hparams.dilation_rates:
with tf.variable_scope('dilation_{}'.format(dilation_rate)):
with tf.variable_scope("dilation_{}".format(dilation_rate)):
skip_connection, x = self.gated_residual_layer(x, dilation_rate)
skip_connections.append(skip_connection)
@@ -192,7 +198,7 @@ class AstroWaveNet(object):
The parameters of each distribution, a tensor of shape [batch_size,
time_series_length, outputs_size].
"""
with tf.variable_scope('dist_params'):
with tf.variable_scope("dist_params"):
conv_outputs = self.conv_1x1_layer(x, outputs_size)
return conv_outputs
@@ -212,36 +218,40 @@ class AstroWaveNet(object):
self.network_output
Outputs:
self.dist_params
self.predicted_distributions
Raises:
ValueError: If distribution type is neither 'categorical' nor 'normal'.
"""
with tf.variable_scope('postprocess'):
with tf.variable_scope("postprocess"):
network_output = tf.keras.activations.relu(self.network_output)
network_output = self.conv_1x1_layer(
network_output,
output_size=int(network_output.shape[-1]),
activation='relu')
num_dists = int(self.autoregressive_input.shape[-1])
output_size=network_output.shape[-1].value,
activation="relu")
num_dists = self.autoregressive_input.shape[-1].value
if self.hparams.output_distribution.type == 'categorical':
if self.hparams.output_distribution.type == "categorical":
num_classes = self.hparams.output_distribution.num_classes
dist_params = self.dist_params_layer(network_output,
num_dists * num_classes)
dist_shape = tf.concat(
logits = self.dist_params_layer(network_output, num_dists * num_classes)
logits_shape = tf.concat(
[tf.shape(network_output)[:-1], [num_dists, num_classes]], 0)
dist_params = tf.reshape(dist_params, dist_shape)
dist = tf.distributions.Categorical(logits=dist_params)
elif self.hparams.output_distribution.type == 'normal':
dist_params = self.dist_params_layer(network_output, num_dists * 2)
loc, scale = tf.split(dist_params, 2, axis=-1)
logits = tf.reshape(logits, logits_shape)
dist = tfp.distributions.Categorical(logits=logits)
dist_params = {"logits": logits}
elif self.hparams.output_distribution.type == "normal":
loc_scale = self.dist_params_layer(network_output, num_dists * 2)
loc, scale = tf.split(loc_scale, 2, axis=-1)
# Ensure scale is positive.
scale = tf.nn.softplus(scale) + self.hparams.output_distribution.min_scale
dist = tf.distributions.Normal(loc, scale)
dist = tfp.distributions.Normal(loc, scale)
dist_params = {"loc": loc, "scale": scale}
else:
raise ValueError('Unsupported distribution type {}'.format(
raise ValueError("Unsupported distribution type {}".format(
self.hparams.output_distribution.type))
self.dist_params = dist_params
self.predicted_distributions = dist
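For the "normal" case, the shape bookkeeping can be seen in isolation with this small sketch (illustrative only; the constant 0.001 stands in for `output_distribution.min_scale`):

```python
# One output distribution per input channel: the 1x1 conv emits 2 * num_dists
# channels per timestep, split into loc and scale.
import tensorflow as tf
import tensorflow_probability as tfp

network_output = tf.zeros([4, 16, 32])        # [batch, time, channels]
loc_scale = tf.keras.layers.Conv1D(2, 1)(network_output)  # num_dists = 1
loc, scale = tf.split(loc_scale, 2, axis=-1)  # each [4, 16, 1]
scale = tf.nn.softplus(scale) + 0.001         # keep the scale strictly positive
dist = tfp.distributions.Normal(loc, scale)
print(dist.batch_shape)                       # (4, 16, 1)
```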
def build_losses(self):
@@ -257,7 +267,7 @@ class AstroWaveNet(object):
autoregressive_target = self.autoregressive_input
# Quantize the target if the output distribution is categorical.
if self.hparams.output_distribution.type == 'categorical':
if self.hparams.output_distribution.type == "categorical":
min_val = self.hparams.output_distribution.min_quantization_value
max_val = self.hparams.output_distribution.max_quantization_value
num_classes = self.hparams.output_distribution.num_classes
@@ -270,7 +280,7 @@ class AstroWaveNet(object):
# final quantized bucket a closed interval while all the other quantized
# buckets are half-open intervals.
quantized_target = tf.where(
quantized_target == num_classes,
quantized_target >= num_classes,
tf.ones_like(quantized_target) * (num_classes - 1), quantized_target)
autoregressive_target = quantized_target
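The clamp matters exactly at the upper boundary: a value equal to max_quantization_value would otherwise land in bucket num_classes, one past the last valid bucket. The sketch below illustrates this; the linear quantization rule shown is an assumption, since the diff only shows the clamp:

```python
# Assumed quantization rule: floor(num_classes * (x - min) / (max - min)),
# followed by the clamp applied above.
import numpy as np

num_classes, min_val, max_val = 256, -1.0, 1.0
x = np.array([min_val, 0.0, max_val])
q = np.floor(num_classes * (x - min_val) / (max_val - min_val)).astype(int)
q = np.where(q >= num_classes, num_classes - 1, q)
print(q)  # [  0 128 255]
```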
@@ -280,22 +290,24 @@ class AstroWaveNet(object):
if weights is None:
weights = tf.ones_like(log_prob)
weights_dim = len(weights.shape)
per_example_weight = tf.reduce_sum(weights, axis=range(1, weights_dim))
per_example_weight = tf.reduce_sum(
weights, axis=list(range(1, weights_dim)))
per_example_indicator = tf.to_float(tf.greater(per_example_weight, 0))
num_examples = tf.reduce_sum(
per_example_indicator, name='num_nonzero_weight_examples')
num_examples = tf.reduce_sum(per_example_indicator)
batch_losses = -log_prob * weights
losses_dim = len(batch_losses.shape)
losses_ndims = batch_losses.shape.ndims
per_example_loss_sum = tf.reduce_sum(
batch_losses, axis=range(1, losses_dim))
batch_losses, axis=list(range(1, losses_ndims)))
per_example_loss = tf.where(per_example_weight > 0,
per_example_loss_sum / per_example_weight,
tf.zeros_like(per_example_weight))
total_loss = tf.reduce_sum(per_example_loss) / num_examples
self.autoregressive_target = autoregressive_target
self.batch_losses = batch_losses
self.per_example_loss = per_example_loss
self.num_nonzero_weight_examples = num_examples
self.total_loss = total_loss
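To make the weighting concrete, here is a small NumPy sketch of the same reduction (illustrative only): padded timesteps carry zero weight, fully padded examples contribute zero loss, and the batch loss averages over examples with nonzero weight:

```python
import numpy as np

log_prob = np.random.randn(4, 8, 1)  # stand-in for dist.log_prob(target)
weights = np.ones_like(log_prob)
weights[3] = 0.0                     # a fully padded example

batch_losses = -log_prob * weights
per_example_weight = weights.sum(axis=(1, 2))
per_example_loss = np.where(
    per_example_weight > 0,
    batch_losses.sum(axis=(1, 2)) / np.maximum(per_example_weight, 1e-12),
    0.0)
total_loss = per_example_loss.sum() / (per_example_weight > 0).sum()
```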
def build(self):
@@ -2,6 +2,48 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "base",
srcs = [
"base.py",
],
deps = [
"//astronet/ops:dataset_ops",
"//astronet/util:configdict",
],
)
py_test(
name = "base_test",
srcs = ["base_test.py"],
data = ["test_data/test-dataset.tfrecord"],
srcs_version = "PY2AND3",
deps = [":base"],
)
py_library(
name = "kepler_light_curves",
srcs = [
"kepler_light_curves.py",
],
deps = [
":base",
"//astronet/util:configdict",
],
)
py_library(
name = "synthetic_transits",
srcs = [
"synthetic_transits.py",
],
deps = [
":base",
":synthetic_transit_maker",
"//astronet/util:configdict",
],
)
py_library(
name = "synthetic_transit_maker",
srcs = [
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base dataset builder classes for AstroWaveNet input pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
import tensorflow as tf
from astronet.util import configdict
from astronet.ops import dataset_ops
@six.add_metaclass(abc.ABCMeta)
class DatasetBuilder(object):
"""Base class for building a dataset input pipeline for AstroWaveNet."""
def __init__(self, config_overrides=None):
"""Initializes the dataset builder.
Args:
config_overrides: Dict or ConfigDict containing overrides to the default
configuration.
"""
self.config = configdict.ConfigDict(self.default_config())
if config_overrides is not None:
self.config.update(config_overrides)
@staticmethod
def default_config():
"""Returns the default configuration as a ConfigDict or Python dict."""
return {}
@abc.abstractmethod
def build(self, batch_size):
"""Builds the dataset input pipeline.
Args:
batch_size: The number of input examples in each batch.
Returns:
A tf.data.Dataset object.
"""
raise NotImplementedError
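A hypothetical minimal subclass (illustration only, not part of this package) makes the contract concrete: override default_config() and build(batch_size), and return a tf.data.Dataset of feature dictionaries:

```python
# Relies on DatasetBuilder and the imports defined above in this file.
class ConstantDataset(DatasetBuilder):
  """Toy builder that emits a constant light curve."""

  @staticmethod
  def default_config():
    return configdict.ConfigDict({"value": 1.0, "length": 8})

  def build(self, batch_size):
    features = {
        "autoregressive_input": tf.fill([self.config.length, 1],
                                        self.config.value),
        "conditioning_stack": tf.zeros([self.config.length, 1]),
    }
    return tf.data.Dataset.from_tensors(features).repeat().batch(batch_size)
```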
@six.add_metaclass(abc.ABCMeta)
class _ShardedDatasetBuilder(DatasetBuilder):
"""Abstract base class for a dataset consisting of sharded files."""
def __init__(self, file_pattern, mode, config_overrides=None, use_tpu=False):
"""Initializes the dataset builder.
Args:
file_pattern: File pattern matching input file shards, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
mode: A tf.estimator.ModeKeys.
config_overrides: Dict or ConfigDict containing overrides to the default
configuration.
use_tpu: Whether to build the dataset for TPU.
"""
super(_ShardedDatasetBuilder, self).__init__(config_overrides)
self.file_pattern = file_pattern
self.mode = mode
self.use_tpu = use_tpu
@staticmethod
def default_config():
config = super(_ShardedDatasetBuilder,
_ShardedDatasetBuilder).default_config()
config.update({
"max_length": 1024,
"shuffle_values_buffer": 1000,
"num_parallel_parser_calls": 4,
"batches_buffer_size": None, # Defaults to max(1, 256 / batch_size).
})
return config
@abc.abstractmethod
def file_reader(self):
"""Returns a function that reads a single sharded file."""
raise NotImplementedError
@abc.abstractmethod
def create_example_parser(self):
"""Returns a function that parses a single tf.Example proto."""
raise NotImplementedError
def _batch_and_pad(self, dataset, batch_size):
"""Combines elements into batches of the same length, padding if needed."""
if self.use_tpu:
padded_length = self.config.max_length
if not padded_length:
raise ValueError("config.max_length is required when using TPU")
# Pad with zeros up to padded_length. Note that this will pad the
# "weights" Tensor with zeros as well, which ensures that padded elements
# do not contribute to the loss.
padded_shapes = {}
for name, shape in six.iteritems(dataset.output_shapes):
shape.assert_is_compatible_with([None, None]) # Expect a 2D sequence.
dims = shape.as_list()
dims[0] = padded_length
shape = tf.TensorShape(dims)
shape.assert_is_fully_defined()
padded_shapes[name] = shape
else:
# Pad each batch up to the maximum size of each dimension in the batch.
padded_shapes = dataset.output_shapes
return dataset.padded_batch(batch_size, padded_shapes)
def build(self, batch_size):
"""Builds the dataset input pipeline.
Args:
batch_size: The number of input examples in each batch.
Returns:
A tf.data.Dataset.
Raises:
ValueError: If no files match self.file_pattern.
"""
file_patterns = self.file_pattern.split(",")
filenames = []
for p in file_patterns:
matches = tf.gfile.Glob(p)
if not matches:
raise ValueError("Found no input files matching {}".format(p))
filenames.extend(matches)
tf.logging.info(
"Building input pipeline from %d files matching patterns: %s",
len(filenames), file_patterns)
is_training = self.mode == tf.estimator.ModeKeys.TRAIN
# Create a string dataset of filenames, and possibly shuffle.
filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
if is_training and len(filenames) > 1:
filename_dataset = filename_dataset.shuffle(len(filenames))
# Read serialized Example protos.
dataset = filename_dataset.apply(
tf.contrib.data.parallel_interleave(
self.file_reader(), cycle_length=8, block_length=8, sloppy=True))
if is_training:
# Shuffle and repeat. Note that shuffle() is before repeat(), so elements
# are shuffled among each epoch of data, and not between epochs of data.
if self.config.shuffle_values_buffer > 0:
dataset = dataset.shuffle(self.config.shuffle_values_buffer)
dataset = dataset.repeat()
# Map the parser over the dataset.
dataset = dataset.map(
self.create_example_parser(),
num_parallel_calls=self.config.num_parallel_parser_calls)
def _prepare_wavenet_inputs(features):
"""Validates features, and clips lengths and adds weights if needed."""
# Validate feature names.
required_features = {"autoregressive_input", "conditioning_stack"}
allowed_features = required_features | {"weights"}
feature_names = features.keys()
if not required_features.issubset(feature_names):
raise ValueError("Features must contain all of: {}. Got: {}".format(
required_features, feature_names))
if not allowed_features.issuperset(feature_names):
raise ValueError("Features can only contain: {}. Got: {}".format(
allowed_features, feature_names))
output = {}
for name, value in features.items():
# Validate shapes. The output dimension is [num_samples, dim].
ndims = len(value.shape)
if ndims == 1:
# Add an extra dimension: [num_samples] -> [num_samples, 1].
value = tf.expand_dims(value, -1)
elif ndims != 2:
raise ValueError(
"Features should be 1D or 2D sequences. Got '{}' = {}".format(
name, value))
if self.config.max_length:
value = value[:self.config.max_length]
output[name] = value
if "weights" not in output:
output["weights"] = tf.ones_like(output["autoregressive_input"])
return output
dataset = dataset.map(_prepare_wavenet_inputs)
# Batch results by up to batch_size.
dataset = self._batch_and_pad(dataset, batch_size)
if is_training:
# The dataset repeats infinitely before batching, so each batch has the
# maximum number of elements.
dataset = dataset_ops.set_batch_size(dataset, batch_size)
elif self.use_tpu and self.mode == tf.estimator.ModeKeys.EVAL:
# Pad to ensure that each batch has the same number of elements.
dataset = dataset_ops.pad_dataset_to_batch_size(dataset, batch_size)
# Prefetch batches.
buffer_size = (
self.config.batches_buffer_size or max(1, int(256 / batch_size)))
dataset = dataset.prefetch(buffer_size)
return dataset
def tfrecord_reader(filename):
"""Returns a tf.data.Dataset that reads a single TFRecord file shard."""
return tf.data.TFRecordDataset(filename, buffer_size=16 * 1000 * 1000)
class TFRecordDataset(_ShardedDatasetBuilder):
"""Builder for a dataset consisting of TFRecord files."""
def file_reader(self):
"""Returns a function that reads a single file shard."""
return tfrecord_reader
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Kepler light curve inputs to the AstroWaveNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from astrowavenet.data import base
COND_INPUT_KEY = "mask"
AR_INPUT_KEY = "flux"
class KeplerLightCurves(base.TFRecordDataset):
"""Kepler light curve inputs to the AstroWaveNet model."""
def create_example_parser(self):
def _example_parser(serialized):
"""Parses a single tf.Example proto."""
features = tf.parse_single_example(
serialized,
features={
AR_INPUT_KEY: tf.VarLenFeature(tf.float32),
COND_INPUT_KEY: tf.VarLenFeature(tf.int64),
})
# Extract values from SparseTensor objects.
autoregressive_input = features[AR_INPUT_KEY].values
conditioning_stack = tf.to_float(features[COND_INPUT_KEY].values)
return {
"autoregressive_input": autoregressive_input,
"conditioning_stack": conditioning_stack,
}
return _example_parser
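For reference, a record in the expected format can be produced with a sketch like the following (illustrative only; the output path and values are made up, and real pipelines would write many examples per shard):

```python
# Writes one tf.Example with the "flux" (float) and "mask" (int64) features
# that _example_parser above expects.
import tensorflow as tf

def make_example(flux_values, mask_values):
  return tf.train.Example(features=tf.train.Features(feature={
      "flux": tf.train.Feature(float_list=tf.train.FloatList(value=flux_values)),
      "mask": tf.train.Feature(int64_list=tf.train.Int64List(value=mask_values)),
  }))

with tf.python_io.TFRecordWriter("/tmp/kepler-00000-of-00001") as writer:
  writer.write(make_example([1.0, 0.98, 1.01], [1, 1, 0]).SerializeToString())
```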
@@ -43,8 +43,8 @@ class SyntheticTransitMaker(object):
would translate the sine wave by half of the period. The most common
reason to override this would be to generate light curves
deterministically (with e.g. (0,0)).
noise_sd_range: A tuple of values in [0, 1) specifying the range of
standard deviations for the Gaussian noise applied to the sine wave.
noise_sd_range: A tuple of values in [0, 1) specifying the range of standard
deviations for the Gaussian noise applied to the sine wave.
"""
def __init__(self,
@@ -125,7 +125,7 @@ class SyntheticTransitMaker(object):
Args:
time: An np.array of x-values to sample from the thresholded sine wave.
mask_prob: Value in [0,1], probability an individual datapoint is set to
zero.
zero.
Returns:
A generator yielding random light curves.
@@ -29,30 +29,30 @@ class SyntheticTransitMakerTest(absltest.TestCase):
def testBadRangesRaiseExceptions(self):
# Period range cannot contain negative values.
with self.assertRaisesRegexp(ValueError, 'Period'):
with self.assertRaisesRegexp(ValueError, "Period"):
synthetic_transit_maker.SyntheticTransitMaker(period_range=(-1, 10))
# Amplitude range cannot contain negative values.
with self.assertRaisesRegexp(ValueError, 'Amplitude'):
with self.assertRaisesRegexp(ValueError, "Amplitude"):
synthetic_transit_maker.SyntheticTransitMaker(amplitude_range=(-10, -1))
# Threshold ratio range must be contained in the half-open interval [0, 1).
with self.assertRaisesRegexp(ValueError, 'Threshold ratio'):
with self.assertRaisesRegexp(ValueError, "Threshold ratio"):
synthetic_transit_maker.SyntheticTransitMaker(
threshold_ratio_range=(0, 1))
# Noise standard deviation range must only contain nonnegative values.
with self.assertRaisesRegexp(ValueError, 'Noise standard deviation'):
with self.assertRaisesRegexp(ValueError, "Noise standard deviation"):
synthetic_transit_maker.SyntheticTransitMaker(noise_sd_range=(-1, 1))
# End of range may not be less than start.
invalid_range = (0.2, 0.1)
range_args = [
'period_range', 'threshold_ratio_range', 'amplitude_range',
'noise_sd_range', 'phase_range'
"period_range", "threshold_ratio_range", "amplitude_range",
"noise_sd_range", "phase_range"
]
for range_arg in range_args:
with self.assertRaisesRegexp(ValueError, 'may not be less'):
with self.assertRaisesRegexp(ValueError, "may not be less"):
synthetic_transit_maker.SyntheticTransitMaker(
**{range_arg: invalid_range})
@@ -106,5 +106,5 @@ class SyntheticTransitMakerTest(absltest.TestCase):
self.assertEqual(len(mask), 100)
if __name__ == '__main__':
if __name__ == "__main__":
absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Synthetic transit inputs to the AstroWaveNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import configdict
from astrowavenet.data import base
from astrowavenet.data import synthetic_transit_maker
def _prepare_wavenet_inputs(light_curve, mask):
"""Gathers synthetic transits into the format expected by AstroWaveNet."""
return {
"autoregressive_input": tf.expand_dims(light_curve, -1),
"conditioning_stack": tf.expand_dims(mask, -1),
}
class SyntheticTransits(base.DatasetBuilder):
"""Synthetic transit inputs to the AstroWaveNet model."""
@staticmethod
def default_config():
return configdict.ConfigDict({
"period_range": (0.5, 4),
"amplitude_range": (1, 1),
"threshold_ratio_range": (0, 0.99),
"phase_range": (0, 1),
"noise_sd_range": (0.1, 0.1),
"mask_probability": 0.1,
"light_curve_time_range": (0, 100),
"light_curve_num_points": 1000
})
def build(self, batch_size):
transit_maker = synthetic_transit_maker.SyntheticTransitMaker(
period_range=self.config.period_range,
amplitude_range=self.config.amplitude_range,
threshold_ratio_range=self.config.threshold_ratio_range,
phase_range=self.config.phase_range,
noise_sd_range=self.config.noise_sd_range)
t_start, t_end = self.config.light_curve_time_range
time = np.linspace(t_start, t_end, self.config.light_curve_num_points)
dataset = tf.data.Dataset.from_generator(
transit_maker.random_light_curve_generator(
time, mask_prob=self.config.mask_probability),
output_types=(tf.float32, tf.float32),
output_shapes=(tf.TensorShape((self.config.light_curve_num_points,)),
tf.TensorShape((self.config.light_curve_num_points,))))
dataset = dataset.map(_prepare_wavenet_inputs)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(-1)
return dataset
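A quick way to inspect the pipeline is to materialize a single batch (illustrative sketch, not part of this change):

```python
import tensorflow as tf
from astrowavenet.data import synthetic_transits

builder = synthetic_transits.SyntheticTransits({"light_curve_num_points": 200})
batch = builder.build(batch_size=4).make_one_shot_iterator().get_next()
with tf.Session() as sess:
  features = sess.run(batch)
  print(features["autoregressive_input"].shape)  # (4, 200, 1)
```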
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training and evaluating AstroWaveNet models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os.path
from absl import flags
import tensorflow as tf
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astrowavenet import astrowavenet_model
from astrowavenet import configurations
from astrowavenet.data import kepler_light_curves
from astrowavenet.data import synthetic_transits
from astrowavenet.util import estimator_util
FLAGS = flags.FLAGS
flags.DEFINE_enum("dataset", None,
["synthetic_transits", "kepler_light_curves"],
"Dataset for training and/or evaluation.")
flags.DEFINE_string("model_dir", None, "Base output directory.")
flags.DEFINE_string(
"train_files", None,
"Comma-separated list of file patterns matching the TFRecord files in the "
"training dataset.")
flags.DEFINE_string(
"eval_files", None,
"Comma-separated list of file patterns matching the TFRecord files in the "
"evaluation dataset.")
flags.DEFINE_string("config_name", "base",
"Name of the AstroWaveNet configuration.")
flags.DEFINE_string(
"config_overrides", "{}",
"JSON string or JSON file containing overrides to the base configuration.")
flags.DEFINE_enum("schedule", None,
["train", "train_and_eval", "continuous_eval"],
"Schedule for running the model.")
flags.DEFINE_string("eval_name", "val", "Name of the evaluation task.")
flags.DEFINE_integer("train_steps", None, "Total number of steps for training.")
flags.DEFINE_integer("eval_steps", None, "Number of steps for each evaluation.")
flags.DEFINE_integer(
"local_eval_frequency", 1000,
"The number of training steps in between evaluation runs. Only applies "
"when schedule == 'train_and_eval'.")
flags.DEFINE_integer("save_summary_steps", None,
"The frequency at which to save model summaries.")
flags.DEFINE_integer("save_checkpoints_steps", None,
"The frequency at which to save model checkpoints.")
flags.DEFINE_integer("save_checkpoints_secs", None,
"The frequency at which to save model checkpoints.")
flags.DEFINE_integer("keep_checkpoint_max", 1,
"The maximum number of model checkpoints to keep.")
# ------------------------------------------------------------------------------
# TPU-only flags
# ------------------------------------------------------------------------------
flags.DEFINE_boolean("use_tpu", False, "Whether to execute on TPU.")
flags.DEFINE_string("master", None, "Address of the TensorFlow TPU master.")
flags.DEFINE_integer("tpu_num_shards", 8, "Number of TPU shards.")
flags.DEFINE_integer("tpu_iterations_per_loop", 1000,
"Number of iterations per TPU training loop.")
flags.DEFINE_integer(
"eval_batch_size", None,
"Batch size for TPU evaluation. Defaults to the training batch size.")
def _create_run_config():
"""Creates a TPU RunConfig if FLAGS.use_tpu is True, else a RunConfig."""
session_config = tf.ConfigProto(allow_soft_placement=True)
run_config_kwargs = {
"save_summary_steps": FLAGS.save_summary_steps,
"save_checkpoints_steps": FLAGS.save_checkpoints_steps,
"save_checkpoints_secs": FLAGS.save_checkpoints_secs,
"session_config": session_config,
"keep_checkpoint_max": FLAGS.keep_checkpoint_max
}
if FLAGS.use_tpu:
if not FLAGS.master:
raise ValueError("FLAGS.master must be set for TPUEstimator.")
tpu_config = tf.contrib.tpu.TPUConfig(
iterations_per_loop=FLAGS.tpu_iterations_per_loop,
num_shards=FLAGS.tpu_num_shards,
per_host_input_for_training=(FLAGS.tpu_num_shards <= 8))
run_config = tf.contrib.tpu.RunConfig(
tpu_config=tpu_config, master=FLAGS.master, **run_config_kwargs)
else:
if FLAGS.master:
raise ValueError("FLAGS.master should only be set for TPUEstimator.")
run_config = tf.estimator.RunConfig(**run_config_kwargs)
return run_config
def _get_file_pattern(mode):
"""Gets the value of the file pattern flag for the specified mode."""
flag_name = ("train_files"
if mode == tf.estimator.ModeKeys.TRAIN else "eval_files")
file_pattern = FLAGS[flag_name].value
if file_pattern is None:
raise ValueError("--{} is required for mode '{}'".format(flag_name, mode))
return file_pattern
def _create_dataset_builder(mode, config_overrides=None):
"""Creates a dataset builder for the input pipeline."""
if FLAGS.dataset == "synthetic_transits":
return synthetic_transits.SyntheticTransits(config_overrides)
file_pattern = _get_file_pattern(mode)
if FLAGS.dataset == "kepler_light_curves":
builder_class = kepler_light_curves.KeplerLightCurves
else:
raise ValueError("Unsupported dataset: {}".format(FLAGS.dataset))
return builder_class(
file_pattern,
mode,
config_overrides=config_overrides,
use_tpu=FLAGS.use_tpu)
def _create_input_fn(mode, config_overrides=None):
"""Creates an Estimator input_fn."""
builder = _create_dataset_builder(mode, config_overrides)
tf.logging.info("Dataset config for mode '%s': %s", mode,
config_util.to_json(builder.config))
return estimator_util.create_input_fn(builder)
def _create_eval_args(config_overrides=None):
"""Builds eval_args for estimator_runner.evaluate()."""
if FLAGS.dataset == "synthetic_transits" and not FLAGS.eval_steps:
raise ValueError("Dataset '{}' requires --eval_steps for evaluation".format(
FLAGS.dataset))
input_fn = _create_input_fn(tf.estimator.ModeKeys.EVAL, config_overrides)
return {FLAGS.eval_name: (input_fn, FLAGS.eval_steps)}
def main(argv):
del argv # Unused.
config = configdict.ConfigDict(configurations.get_config(FLAGS.config_name))
config_overrides = json.loads(FLAGS.config_overrides)
for key in config_overrides:
if key not in ["dataset", "hparams"]:
raise ValueError("Unrecognized config override: {}".format(key))
config.hparams.update(config_overrides.get("hparams", {}))
# Log configs.
configs_json = [
("config_overrides", config_util.to_json(config_overrides)),
("config", config_util.to_json(config)),
]
for config_name, config_json in configs_json:
tf.logging.info("%s: %s", config_name, config_json)
# Create the estimator.
run_config = _create_run_config()
estimator = estimator_util.create_estimator(
astrowavenet_model.AstroWaveNet, config.hparams, run_config,
FLAGS.model_dir, FLAGS.eval_batch_size)
if FLAGS.schedule in ["train", "train_and_eval"]:
# Save configs.
tf.gfile.MakeDirs(FLAGS.model_dir)
for config_name, config_json in configs_json:
filename = os.path.join(FLAGS.model_dir, "{}.json".format(config_name))
with tf.gfile.Open(filename, "w") as f:
f.write(config_json)
train_input_fn = _create_input_fn(tf.estimator.ModeKeys.TRAIN,
config_overrides.get("dataset"))
train_hooks = []
if FLAGS.schedule == "train":
estimator.train(
train_input_fn, hooks=train_hooks, max_steps=FLAGS.train_steps)
else:
assert FLAGS.schedule == "train_and_eval"
eval_args = _create_eval_args(config_overrides.get("dataset"))
for _ in estimator_runner.continuous_train_and_eval(
estimator=estimator,
train_input_fn=train_input_fn,
eval_args=eval_args,
local_eval_frequency=FLAGS.local_eval_frequency,
train_hooks=train_hooks,
train_steps=FLAGS.train_steps):
# continuous_train_and_eval() yields evaluation metrics after each
# FLAGS.local_eval_frequency. It also saves and logs them, so we don't
# do anything here.
pass
else:
assert FLAGS.schedule == "continuous_eval"
eval_args = _create_eval_args(config_overrides.get("dataset"))
for _ in estimator_runner.continuous_eval(
estimator=estimator, eval_args=eval_args,
train_steps=FLAGS.train_steps):
# continuous_eval() yields evaluation metrics after each
# checkpoint. It also saves and logs them, so we don't do anything here.
pass
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
flags.mark_flags_as_required(["dataset", "model_dir", "schedule"])
def _validate_schedule(flag_values):
"""Validates the --schedule flag and the flags it interacts with."""
schedule = flag_values["schedule"]
save_checkpoints_steps = flag_values["save_checkpoints_steps"]
save_checkpoints_secs = flag_values["save_checkpoints_secs"]
if schedule in ["train", "train_and_eval"]:
if not (save_checkpoints_steps or save_checkpoints_secs):
raise flags.ValidationError(
"--schedule='%s' requires --save_checkpoints_steps or "
"--save_checkpoints_secs." % schedule)
return True
flags.register_multi_flags_validator(
["schedule", "save_checkpoints_steps", "save_checkpoints_secs"],
_validate_schedule)
tf.app.run()
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "estimator_util",
srcs = ["estimator_util.py"],
srcs_version = "PY2AND3",
deps = ["//astronet/ops:training"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for creating a TensorFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
from astronet.ops import training
class _InputFn(object):
"""Class that acts as a callable input function for Estimator train / eval."""
def __init__(self, dataset_builder):
"""Initializes the input function.
Args:
dataset_builder: Instance of DatasetBuilder.
"""
self._builder = dataset_builder
def __call__(self, params):
"""Builds the input pipeline."""
return self._builder.build(batch_size=params["batch_size"])
def create_input_fn(dataset_builder):
"""Creates an input_fn that that builds an input pipeline.
Args:
dataset_builder: Instance of DatasetBuilder.
Returns:
A callable that builds an input pipeline and returns a tf.data.Dataset
object.
"""
return _InputFn(dataset_builder)
class _ModelFn(object):
"""Class that acts as a callable model function for Estimator train / eval."""
def __init__(self, model_class, hparams, use_tpu=False):
"""Initializes the model function.
Args:
model_class: Model class.
hparams: A HParams object containing hyperparameters for building and
training the model.
use_tpu: If True, a TPUEstimator will be returned. Otherwise an Estimator
will be returned.
"""
self._model_class = model_class
self._base_hparams = hparams
self._use_tpu = use_tpu
def __call__(self, features, mode, params):
"""Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
hparams = copy.deepcopy(self._base_hparams)
if "batch_size" in params:
hparams.batch_size = params["batch_size"]
model = self._model_class(features, hparams, mode)
model.build()
# Possibly create train_op.
use_tpu = self._use_tpu
train_op = None
if mode == tf.estimator.ModeKeys.TRAIN:
learning_rate = training.create_learning_rate(hparams, model.global_step)
optimizer = training.create_optimizer(hparams, learning_rate, use_tpu)
train_op = training.create_train_op(model, optimizer)
if use_tpu:
estimator = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode, loss=model.total_loss, train_op=train_op)
else:
estimator = tf.estimator.EstimatorSpec(
mode=mode, loss=model.total_loss, train_op=train_op)
return estimator
def create_model_fn(model_class, hparams, use_tpu=False):
"""Wraps model_class as an Estimator or TPUEstimator model_fn.
Args:
model_class: AstroWaveNet or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
Estimator model_fn is returned.
Returns:
model_fn: A callable that constructs the model and returns a
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
"""
return _ModelFn(model_class, hparams, use_tpu)
def create_estimator(model_class,
hparams,
run_config=None,
model_dir=None,
eval_batch_size=None):
"""Wraps model_class as an Estimator or TPUEstimator.
If run_config is None or a tf.estimator.RunConfig, an Estimator is returned.
If run_config is a tf.contrib.tpu.RunConfig, a TPUEstimator is returned.
Args:
model_class: AstroWaveNet or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
model_dir: Optional directory for saving the model. If not passed
explicitly, it must be specified in run_config.
eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
Returns:
An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
TPUEstimator object if run_config is a tf.contrib.tpu.RunConfig.
Raises:
ValueError:
If model_dir is not passed explicitly or in run_config.model_dir, or if
eval_batch_size is specified and run_config is not a
tf.contrib.tpu.RunConfig.
"""
if run_config is None:
run_config = tf.estimator.RunConfig()
else:
run_config = copy.deepcopy(run_config)
if not model_dir and not run_config.model_dir:
raise ValueError(
"model_dir must be passed explicitly or specified in run_config")
use_tpu = isinstance(run_config, tf.contrib.tpu.RunConfig)
model_fn = create_model_fn(model_class, hparams, use_tpu)
if use_tpu:
eval_batch_size = eval_batch_size or hparams.batch_size
estimator = tf.contrib.tpu.TPUEstimator(
model_fn=model_fn,
model_dir=model_dir,
config=run_config,
train_batch_size=hparams.batch_size,
eval_batch_size=eval_batch_size)
else:
if eval_batch_size is not None:
raise ValueError("eval_batch_size can only be specified for TPU.")
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=model_dir,
config=run_config,
params={"batch_size": hparams.batch_size})
return estimator
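As a usage note (illustrative; trainer.py above is the canonical caller), the helper returns a standard Estimator when given a tf.estimator.RunConfig and a TPUEstimator when given a tf.contrib.tpu.RunConfig. The sketch assumes the "base" configuration provides hparams including batch_size:

```python
import tensorflow as tf

from astronet.util import configdict
from astrowavenet import astrowavenet_model
from astrowavenet import configurations
from astrowavenet.util import estimator_util

hparams = configdict.ConfigDict(configurations.get_config("base")).hparams
estimator = estimator_util.create_estimator(
    astrowavenet_model.AstroWaveNet, hparams,
    run_config=tf.estimator.RunConfig(), model_dir="/tmp/astrowavenet")
```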
@@ -220,14 +220,13 @@ def reshard_arrays(xs, ys):
return np.split(concat_x, boundaries)
def uniform_cadence_light_curve(all_cadence_no, all_time, all_flux):
def uniform_cadence_light_curve(cadence_no, time, flux):
"""Combines data into a single light curve with uniform cadence numbers.
Args:
all_cadence_no: A list of numpy arrays; the cadence numbers of the light
curve.
all_time: A list of numpy arrays; the time values of the light curve.
all_flux: A list of numpy arrays; the flux values of the light curve.
cadence_no: numpy array; the cadence numbers of the light curve.
time: numpy array; the time values of the light curve.
flux: numpy array; the flux values of the light curve.
Returns:
cadence_no: numpy array; the cadence numbers of the light curve with no
@@ -245,24 +244,23 @@ def uniform_cadence_light_curve(all_cadence_no, all_time, all_flux):
Raises:
ValueError: If there are duplicate cadence numbers in the input.
"""
min_cadence_no = np.min([np.min(c) for c in all_cadence_no])
max_cadence_no = np.max([np.max(c) for c in all_cadence_no])
min_cadence_no = np.min(cadence_no)
max_cadence_no = np.max(cadence_no)
out_cadence_no = np.arange(
min_cadence_no, max_cadence_no + 1, dtype=all_cadence_no[0].dtype)
out_time = np.zeros_like(out_cadence_no, dtype=all_time[0].dtype)
out_flux = np.zeros_like(out_cadence_no, dtype=all_flux[0].dtype)
min_cadence_no, max_cadence_no + 1, dtype=cadence_no.dtype)
out_time = np.zeros_like(out_cadence_no, dtype=time.dtype)
out_flux = np.zeros_like(out_cadence_no, dtype=flux.dtype)
out_mask = np.zeros_like(out_cadence_no, dtype=np.bool)
for cadence_no, time, flux in zip(all_cadence_no, all_time, all_flux):
for c, t, f in zip(cadence_no, time, flux):
if np.isfinite(c) and np.isfinite(t) and np.isfinite(f):
i = int(c - min_cadence_no)
if out_mask[i]:
raise ValueError("Duplicate cadence number: {}".format(c))
out_time[i] = t
out_flux[i] = f
out_mask[i] = True
for c, t, f in zip(cadence_no, time, flux):
if np.isfinite(c) and np.isfinite(t) and np.isfinite(f):
i = int(c - min_cadence_no)
if out_mask[i]:
raise ValueError("Duplicate cadence number: {}".format(c))
out_time[i] = t
out_flux[i] = f
out_mask[i] = True
return out_cadence_no, out_time, out_flux, out_mask
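To make the new single-light-curve interface concrete, here is a small self-contained sketch of the same gap-filling behavior (illustrative only; duplicate and non-finite handling omitted):

```python
import numpy as np

cadence_no = np.array([100, 101, 103])
time = np.array([10.0, 10.1, 10.3])
flux = np.array([1.00, 0.99, 1.01])

out_cadence_no = np.arange(cadence_no.min(), cadence_no.max() + 1)
out_time = np.zeros_like(out_cadence_no, dtype=time.dtype)
out_flux = np.zeros_like(out_cadence_no, dtype=flux.dtype)
out_mask = np.zeros_like(out_cadence_no, dtype=bool)
for c, t, f in zip(cadence_no, time, flux):
  i = int(c - cadence_no.min())
  out_time[i], out_flux[i], out_mask[i] = t, f, True

print(out_cadence_no)  # [100 101 102 103]
print(out_mask)        # [ True  True False  True]
```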