Unverified Commit 69b01644 authored by Chris Shallue's avatar Chris Shallue Committed by GitHub
Browse files

Merge pull request #5546 from cshallue/master

Improvements to AstroNet and add AstroWaveNet
parents 91b2debd 763663de
......@@ -40,6 +40,10 @@ Full text available at [*The Astronomical Journal*](http://iopscience.iop.org/ar
* Training and evaluating a new model.
* Using a trained model to generate new predictions.
[astrowavenet/](astrowavenet/)
* A generative model for light curves.
[light_curve_util/](light_curve_util)
* Utilities for operating on light curves. These include:
......@@ -63,11 +67,11 @@ First, ensure that you have installed the following required packages:
* **TensorFlow** ([instructions](https://www.tensorflow.org/install/))
* **Pandas** ([instructions](http://pandas.pydata.org/pandas-docs/stable/install.html))
* **NumPy** ([instructions](https://docs.scipy.org/doc/numpy/user/install.html))
* **SciPy** ([instructions](https://scipy.org/install.html))
* **AstroPy** ([instructions](http://www.astropy.org/))
* **PyDl** ([instructions](https://pypi.python.org/pypi/pydl))
* **Bazel** ([instructions](https://docs.bazel.build/versions/master/install.html))
* **Abseil Python Common Libraries** ([instructions](https://github.com/abseil/abseil-py))
* Optional: only required for unit tests.
### Optional: Run Unit Tests
......@@ -207,7 +211,7 @@ the second deepest transits).
To train a model to identify exoplanets, you will need to provide TensorFlow
with training data in
[TFRecord](https://www.tensorflow.org/guide/datasets) format. The
[TFRecord](https://www.tensorflow.org/programmers_guide/datasets) format. The
TFRecord format consists of a set of sharded files containing serialized
`tf.Example` [protocol buffers](https://developers.google.com/protocol-buffers/).
......@@ -343,7 +347,7 @@ bazel-bin/astronet/train \
--model_dir=${MODEL_DIR}
```
Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard)
Optionally, you can also run a [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard)
server in a separate process for real-time
monitoring of training progress and evaluation metrics.
......
......@@ -25,6 +25,7 @@ py_binary(
":models",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util",
],
)
......@@ -37,6 +38,7 @@ py_binary(
":models",
"//astronet/util:config_util",
"//astronet/util:configdict",
"//astronet/util:estimator_runner",
"//astronet/util:estimator_util",
],
)
......
......@@ -54,24 +54,6 @@ from astronet.astro_model import astro_model
class AstroCNNModel(astro_model.AstroModel):
"""A model for classifying light curves using a convolutional neural net."""
def __init__(self, features, labels, hparams, mode):
"""Basic setup. The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "time_series_features" and
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
labels: An int64 Tensor with shape [batch_size]. May be None if mode is
tf.estimator.ModeKeys.PREDICT.
hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction.
Raises:
ValueError: If mode is invalid.
"""
super(AstroCNNModel, self).__init__(features, labels, hparams, mode)
def _build_cnn_layers(self, inputs, hparams, scope="cnn"):
"""Builds convolutional layers.
......@@ -95,7 +77,7 @@ class AstroCNNModel(astro_model.AstroModel):
for i in range(hparams.cnn_num_blocks):
num_filters = int(hparams.cnn_initial_num_filters *
hparams.cnn_block_filter_factor**i)
with tf.variable_scope("block_%d" % (i + 1)):
with tf.variable_scope("block_{}".format(i + 1)):
for j in range(hparams.cnn_block_size):
net = tf.layers.conv1d(
inputs=net,
......@@ -103,7 +85,7 @@ class AstroCNNModel(astro_model.AstroModel):
kernel_size=int(hparams.cnn_kernel_size),
padding=hparams.convolution_padding,
activation=tf.nn.relu,
name="conv_%d" % (j + 1))
name="conv_{}".format(j + 1))
if hparams.pool_size > 1: # pool_size 0 or 1 denotes no pooling
net = tf.layers.max_pooling1d(
......
......@@ -35,8 +35,7 @@ class AstroCNNModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -58,24 +58,6 @@ from astronet.astro_model import astro_model
class AstroFCModel(astro_model.AstroModel):
"""A model for classifying light curves using fully connected layers."""
def __init__(self, features, labels, hparams, mode):
"""Basic setup. The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "time_series_features" and
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
labels: An int64 Tensor with shape [batch_size]. May be None if mode is
tf.estimator.ModeKeys.PREDICT.
hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction.
Raises:
ValueError: If mode is invalid.
"""
super(AstroFCModel, self).__init__(features, labels, hparams, mode)
def _build_local_fc_layers(self, inputs, hparams, scope):
"""Builds locally fully connected layers.
......@@ -120,8 +102,8 @@ class AstroFCModel(astro_model.AstroModel):
elif hparams.pooling_type == "avg":
net = tf.reduce_mean(net, axis=1, name="avg_pool")
else:
raise ValueError(
"Unrecognized pooling_type: %s" % hparams.pooling_type)
raise ValueError("Unrecognized pooling_type: {}".format(
hparams.pooling_type))
remaining_layers = hparams.num_local_layers - 1
else:
......@@ -133,7 +115,7 @@ class AstroFCModel(astro_model.AstroModel):
inputs=net,
num_outputs=hparams.local_layer_size,
activation_fn=tf.nn.relu,
scope="fully_connected_%d" % (i + 1))
scope="fully_connected_{}".format(i + 1))
if hparams.dropout_rate > 0:
net = tf.layers.dropout(
......
......@@ -35,8 +35,7 @@ class AstroFCModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -73,17 +73,19 @@ class AstroModel(object):
"""A TensorFlow model for classifying astrophysical light curves."""
def __init__(self, features, labels, hparams, mode):
"""Basic setup. The actual TensorFlow graph is constructed in build().
"""Basic setup.
The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "time_series_features" and
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
labels: An int64 Tensor with shape [batch_size]. May be None if mode is
tf.estimator.ModeKeys.PREDICT.
tf.estimator.ModeKeys.PREDICT.
hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction.
for training, evaluation or prediction.
Raises:
ValueError: If mode is invalid.
......@@ -93,7 +95,7 @@ class AstroModel(object):
tf.estimator.ModeKeys.PREDICT
]
if mode not in valid_modes:
raise ValueError("Expected mode in %s. Got: %s" % (valid_modes, mode))
raise ValueError("Expected mode in {}. Got: {}".format(valid_modes, mode))
self.hparams = hparams
self.mode = mode
......@@ -201,10 +203,9 @@ class AstroModel(object):
if len(hidden_layers) == 1:
pre_logits_concat = hidden_layers[0][1]
else:
pre_logits_concat = tf.concat(
[layer[1] for layer in hidden_layers],
axis=1,
name="pre_logits_concat")
pre_logits_concat = tf.concat([layer[1] for layer in hidden_layers],
axis=1,
name="pre_logits_concat")
net = pre_logits_concat
with tf.variable_scope("pre_logits_hidden"):
......@@ -213,7 +214,7 @@ class AstroModel(object):
inputs=net,
units=self.hparams.pre_logits_hidden_layer_size,
activation=tf.nn.relu,
name="fully_connected_%s" % (i + 1))
name="fully_connected_{}".format(i + 1))
if self.hparams.pre_logits_dropout_rate > 0:
net = tf.layers.dropout(
......
......@@ -35,8 +35,7 @@ class AstroModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -45,6 +45,9 @@ def base():
"PC": 1, # Planet Candidate.
"AFP": 0, # Astrophysical False Positive.
"NTP": 0, # Non-Transiting Phenomenon.
"SCR1": 0, # TCE from scrambled light curve with SCR1 order.
"INV": 0, # TCE from inverted light curve.
"INJ1": 1, # Injected Planet.
},
},
# Hyperparameters for building and training the model.
......@@ -60,10 +63,10 @@ def base():
"pre_logits_dropout_rate": 0.0,
# Number of examples per training batch.
"batch_size": 64,
"batch_size": 256,
# Learning rate parameters.
"learning_rate": 1e-5,
"learning_rate": 2e-4,
"learning_rate_decay_steps": 0,
"learning_rate_decay_factor": 0,
"learning_rate_decay_staircase": True,
......
......@@ -88,7 +88,6 @@ import tensorflow as tf
from astronet.data import preprocess
parser = argparse.ArgumentParser()
_DR24_TCE_URL = ("https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/"
......@@ -100,7 +99,7 @@ parser.add_argument(
required=True,
help="CSV file containing the Q1-Q17 DR24 Kepler TCE table. Must contain "
"columns: rowid, kepid, tce_plnt_num, tce_period, tce_duration, "
"tce_time0bk. Download from: %s" % _DR24_TCE_URL)
"tce_time0bk. Download from: {}".format(_DR24_TCE_URL))
parser.add_argument(
"--kepler_data_dir",
......@@ -219,14 +218,16 @@ def main(argv):
for i in range(FLAGS.num_train_shards):
start = boundaries[i]
end = boundaries[i + 1]
file_shards.append((train_tces[start:end], os.path.join(
FLAGS.output_dir, "train-%.5d-of-%.5d" % (i, FLAGS.num_train_shards))))
filename = os.path.join(
FLAGS.output_dir, "train-{:05d}-of-{:05d}".format(
i, FLAGS.num_train_shards))
file_shards.append((train_tces[start:end], filename))
# Validation and test sets each have a single shard.
file_shards.append((val_tces, os.path.join(FLAGS.output_dir,
"val-00000-of-00001")))
file_shards.append((test_tces, os.path.join(FLAGS.output_dir,
"test-00000-of-00001")))
file_shards.append((val_tces,
os.path.join(FLAGS.output_dir, "val-00000-of-00001")))
file_shards.append((test_tces,
os.path.join(FLAGS.output_dir, "test-00000-of-00001")))
num_file_shards = len(file_shards)
# Launch subprocesses for the file shards.
......
......@@ -34,7 +34,7 @@ def read_light_curve(kepid, kepler_data_dir):
Args:
kepid: Kepler id of the target star.
kepler_data_dir: Base directory containing Kepler data. See
kepler_io.kepler_filenames().
kepler_io.kepler_filenames().
Returns:
all_time: A list of numpy arrays; the time values of the raw light curve.
......@@ -47,8 +47,8 @@ def read_light_curve(kepid, kepler_data_dir):
# Read the Kepler light curve.
file_names = kepler_io.kepler_filenames(kepler_data_dir, kepid)
if not file_names:
raise IOError("Failed to find .fits files in %s for Kepler ID %s" %
(kepler_data_dir, kepid))
raise IOError("Failed to find .fits files in {} for Kepler ID {}".format(
kepler_data_dir, kepid))
return kepler_io.read_kepler_light_curve(file_names)
......@@ -59,7 +59,7 @@ def process_light_curve(all_time, all_flux):
Args:
all_time: A list of numpy arrays; the time values of the raw light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in
all_time.
all_time.
Returns:
time: 1D NumPy array; the time values of the light curve.
......@@ -192,7 +192,7 @@ def local_view(time,
num_bins: The number of intervals to divide the time axis into.
bin_width_factor: Width of the bins, as a fraction of duration.
num_durations: The number of durations to consider on either side of 0 (the
event is assumed to be centered at 0).
event is assumed to be centered at 0).
Returns:
1D NumPy array of size num_bins containing the median flux values of
......@@ -214,7 +214,7 @@ def generate_example_for_tce(time, flux, tce):
time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve.
tce: Dict-like object containing at least 'tce_period', 'tce_duration', and
'tce_time0bk'. Additional items are included as features in the output.
'tce_time0bk'. Additional items are included as features in the output.
Returns:
A tf.train.Example containing features 'global_view', 'local_view', and all
......
......@@ -26,6 +26,7 @@ import tensorflow as tf
from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_runner
from astronet.util import estimator_util
parser = argparse.ArgumentParser()
......@@ -86,7 +87,9 @@ def main(_):
# Run evaluation. This will log the result to stderr and also write a summary
# file in the model_dir.
estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name)
eval_steps = None # Evaluate over all examples in the file.
eval_args = {FLAGS.eval_name: (input_fn, eval_steps)}
estimator_runner.evaluate(estimator, eval_args)
if __name__ == "__main__":
......
......@@ -46,7 +46,7 @@ def get_model_class(model_name):
ValueError: If model_name is unrecognized.
"""
if model_name not in _MODELS:
raise ValueError("Unrecognized model name: %s" % model_name)
raise ValueError("Unrecognized model name: {}".format(model_name))
return _MODELS[model_name][0]
......@@ -57,7 +57,7 @@ def get_model_config(model_name, config_name):
Args:
model_name: Name of the model class.
config_name: Name of a configuration-builder function from the model's
configurations module.
configurations module.
Returns:
model_class: The requested model class.
......@@ -67,11 +67,12 @@ def get_model_config(model_name, config_name):
ValueError: If model_name or config_name is unrecognized.
"""
if model_name not in _MODELS:
raise ValueError("Unrecognized model name: %s" % model_name)
raise ValueError("Unrecognized model name: {}".format(model_name))
config_module = _MODELS[model_name][1]
try:
return getattr(config_module, config_name)()
except AttributeError:
raise ValueError("Config name '%s' not found in configuration module: %s" %
(config_name, config_module.__name__))
raise ValueError(
"Config name '{}' not found in configuration module: {}".format(
config_name, config_module.__name__))
......@@ -69,7 +69,7 @@ def _recursive_pad_to_batch_size(tensor_or_collection, batch_size):
for t in tensor_or_collection
]
raise ValueError("Unknown input type: %s" % tensor_or_collection)
raise ValueError("Unknown input type: {}".format(tensor_or_collection))
def pad_dataset_to_batch_size(dataset, batch_size):
......@@ -119,7 +119,7 @@ def _recursive_set_batch_size(tensor_or_collection, batch_size):
for t in tensor_or_collection:
_recursive_set_batch_size(t, batch_size)
else:
raise ValueError("Unknown input type: %s" % tensor_or_collection)
raise ValueError("Unknown input type: {}".format(tensor_or_collection))
return tensor_or_collection
......@@ -142,19 +142,19 @@ def build_dataset(file_pattern,
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
input_config: ConfigDict containing feature and label specifications.
batch_size: The number of examples per batch.
include_labels: Whether to read labels from the input files.
reverse_time_series_prob: If > 0, the time series features will be randomly
reversed with this probability. Within a given example, either all time
series features will be reversed, or none will be reversed.
reversed with this probability. Within a given example, either all time
series features will be reversed, or none will be reversed.
shuffle_filenames: Whether to shuffle the order of TFRecord files between
epochs.
epochs.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the dataset
will repeat indefinitely.
will repeat indefinitely.
use_tpu: Whether to build the dataset for TPU.
Raises:
......@@ -170,7 +170,7 @@ def build_dataset(file_pattern,
for p in file_patterns:
matches = tf.gfile.Glob(p)
if not matches:
raise ValueError("Found no input files matching %s" % p)
raise ValueError("Found no input files matching {}".format(p))
filenames.extend(matches)
tf.logging.info("Building input pipeline from %d files matching patterns: %s",
len(filenames), file_patterns)
......@@ -180,8 +180,8 @@ def build_dataset(file_pattern,
label_ids = set(input_config.label_map.values())
if label_ids != set(range(len(label_ids))):
raise ValueError(
"Label IDs must be contiguous integers starting at 0. Got: %s" %
label_ids)
"Label IDs must be contiguous integers starting at 0. Got: {}".format(
label_ids))
# Create a HashTable mapping label strings to integer ids.
table_initializer = tf.contrib.lookup.KeyValueTensorInitializer(
......
......@@ -48,11 +48,6 @@ class DatasetOpsTest(tf.test.TestCase):
self.assertEqual([5], tensor_1d.shape)
self.assertAllEqual([0, 1, 2, 3, 4], tensor_1d.eval())
# Invalid to pad Tensor with batch size 5 to batch size 3.
tensor_1d_pad3 = dataset_ops.pad_tensor_to_batch_size(tensor_1d, 3)
with self.assertRaises(tf.errors.InvalidArgumentError):
tensor_1d_pad3.eval()
tensor_1d_pad5 = dataset_ops.pad_tensor_to_batch_size(tensor_1d, 5)
self.assertEqual([5], tensor_1d_pad5.shape)
self.assertAllEqual([0, 1, 2, 3, 4], tensor_1d_pad5.eval())
......@@ -66,11 +61,6 @@ class DatasetOpsTest(tf.test.TestCase):
self.assertEqual([3, 3], tensor_2d.shape)
self.assertAllEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], tensor_2d.eval())
tensor_2d_pad2 = dataset_ops.pad_tensor_to_batch_size(tensor_2d, 2)
# Invalid to pad Tensor with batch size 2 to batch size 2.
with self.assertRaises(tf.errors.InvalidArgumentError):
tensor_2d_pad2.eval()
tensor_2d_pad3 = dataset_ops.pad_tensor_to_batch_size(tensor_2d, 3)
self.assertEqual([3, 3], tensor_2d_pad3.shape)
self.assertAllEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]],
......
......@@ -27,11 +27,10 @@ def prepare_feed_dict(model, features, labels=None, is_training=None):
Args:
model: An instance of AstroModel.
features: Dictionary containing "time_series_features" and "aux_features".
Each is a dictionary of named numpy arrays of shape
[batch_size, length].
Each is a dictionary of named numpy arrays of shape [batch_size, length].
labels: (Optional). Numpy array of shape [batch_size].
is_training: (Optional). Python boolean to feed to the model.is_training
Tensor (if None, no value is fed).
Tensor (if None, no value is fed).
Returns:
feed_dict: A dictionary of input Tensor to numpy array.
......
......@@ -31,9 +31,9 @@ class InputOpsTest(tf.test.TestCase):
Args:
expected_shapes: Dictionary of expected Tensor shapes, as lists,
corresponding to the structure of 'features'.
corresponding to the structure of 'features'.
features: Dictionary of feature placeholders of the format returned by
input_ops.build_feature_placeholders().
input_ops.build_feature_placeholders().
"""
actual_shapes = {}
for feature_type in features:
......
......@@ -30,7 +30,7 @@ def _metric_variable(name, shape, dtype):
collections=[tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES])
def _build_metrics(labels, predictions, weights, batch_losses):
def _build_metrics(labels, predictions, weights, batch_losses, output_dim=1):
"""Builds TensorFlow operations to compute model evaluation metrics.
Args:
......@@ -38,14 +38,16 @@ def _build_metrics(labels, predictions, weights, batch_losses):
predictions: Tensor with shape [batch_size, output_dim].
weights: Tensor with shape [batch_size].
batch_losses: Tensor with shape [batch_size].
    output_dim: Dimension of the model output.
Returns:
A dictionary {metric_name: (metric_value, update_op).
"""
# Compute the predicted labels.
assert len(predictions.shape) == 2
binary_classification = (predictions.shape[1] == 1)
binary_classification = output_dim == 1
if binary_classification:
assert predictions.shape[1] == 1
predictions = tf.squeeze(predictions, axis=[1])
predicted_labels = tf.to_int32(
tf.greater(predictions, 0.5), name="predicted_labels")
......@@ -73,35 +75,31 @@ def _build_metrics(labels, predictions, weights, batch_losses):
metrics["losses/weighted_cross_entropy"] = tf.metrics.mean(
batch_losses, weights=weights, name="cross_entropy_loss")
# Possibly create additional metrics for binary classification.
def _count_condition(name, labels_value, predicted_value):
"""Creates a counter for given values of predictions and labels."""
count = _metric_variable(name, [], tf.float32)
is_equal = tf.to_float(
tf.logical_and(
tf.equal(labels, labels_value),
tf.equal(predicted_labels, predicted_value)))
update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
return count.read_value(), update_op
# Confusion matrix metrics.
num_labels = 2 if binary_classification else output_dim
for gold_label in range(num_labels):
for pred_label in range(num_labels):
metric_name = "confusion_matrix/label_{}_pred_{}".format(
gold_label, pred_label)
metrics[metric_name] = _count_condition(
metric_name, labels_value=gold_label, predicted_value=pred_label)
# Possibly create AUC metric for binary classification.
if binary_classification:
labels = tf.cast(labels, dtype=tf.bool)
predicted_labels = tf.cast(predicted_labels, dtype=tf.bool)
# AUC.
metrics["auc"] = tf.metrics.auc(
labels, predictions, weights=weights, num_thresholds=1000)
def _count_condition(name, labels_value, predicted_value):
"""Creates a counter for given values of predictions and labels."""
count = _metric_variable(name, [], tf.float32)
is_equal = tf.to_float(
tf.logical_and(
tf.equal(labels, labels_value),
tf.equal(predicted_labels, predicted_value)))
update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
return count.read_value(), update_op
# Confusion matrix metrics.
metrics["confusion_matrix/true_positives"] = _count_condition(
"true_positives", labels_value=True, predicted_value=True)
metrics["confusion_matrix/false_positives"] = _count_condition(
"false_positives", labels_value=False, predicted_value=True)
metrics["confusion_matrix/true_negatives"] = _count_condition(
"true_negatives", labels_value=False, predicted_value=False)
metrics["confusion_matrix/false_negatives"] = _count_condition(
"false_negatives", labels_value=True, predicted_value=False)
return metrics
......@@ -130,7 +128,12 @@ def create_metric_fn(model):
}
def metric_fn(labels, predictions, weights, batch_losses):
return _build_metrics(labels, predictions, weights, batch_losses)
return _build_metrics(
labels,
predictions,
weights,
batch_losses,
output_dim=model.hparams.output_dim)
return metric_fn, metric_fn_inputs
......
......@@ -30,15 +30,23 @@ def _unpack_metric_map(names_to_tuples):
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
class _MockHparams(object):
"""Mock Hparams class to support accessing with dot notation."""
pass
class _MockModel(object):
"""Mock model for testing."""
def __init__(self, labels, predictions, weights, batch_losses):
def __init__(self, labels, predictions, weights, batch_losses, output_dim):
self.labels = tf.constant(labels, dtype=tf.int32)
self.predictions = tf.constant(predictions, dtype=tf.float32)
self.weights = None if weights is None else tf.constant(
weights, dtype=tf.float32)
self.batch_losses = tf.constant(batch_losses, dtype=tf.float32)
self.hparams = _MockHparams()
self.hparams.output_dim = output_dim
class MetricsTest(tf.test.TestCase):
......@@ -48,13 +56,13 @@ class MetricsTest(tf.test.TestCase):
predictions = [
[0.7, 0.2, 0.1, 0.0], # Predicted label = 0
[0.2, 0.4, 0.2, 0.2], # Predicted label = 1
[0.0, 0.0, 0.0, 1.0], # Predicted label = 4
[0.1, 0.1, 0.7, 0.1], # Predicted label = 3
[0.0, 0.0, 0.0, 1.0], # Predicted label = 3
[0.1, 0.1, 0.7, 0.1], # Predicted label = 2
]
weights = None
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -68,6 +76,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 2,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"confusion_matrix/label_0_pred_0": 1,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 1,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 1,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -76,6 +100,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 4,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"confusion_matrix/label_0_pred_0": 2,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 2,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 2,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
def testMultiClassificationWithWeights(self):
......@@ -83,13 +123,13 @@ class MetricsTest(tf.test.TestCase):
predictions = [
[0.7, 0.2, 0.1, 0.0], # Predicted label = 0
[0.2, 0.4, 0.2, 0.2], # Predicted label = 1
[0.0, 0.0, 0.0, 1.0], # Predicted label = 4
[0.1, 0.1, 0.7, 0.1], # Predicted label = 3
[0.0, 0.0, 0.0, 1.0], # Predicted label = 3
[0.1, 0.1, 0.7, 0.1], # Predicted label = 2
]
weights = [0, 1, 0, 1]
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -103,6 +143,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 1,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 0,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 1,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -111,6 +167,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 2,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 0,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 2,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
def testBinaryClassificationWithoutWeights(self):
......@@ -124,7 +196,7 @@ class MetricsTest(tf.test.TestCase):
weights = None
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -139,10 +211,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"auc": 0.25,
"confusion_matrix/true_positives": 1,
"confusion_matrix/true_negatives": 1,
"confusion_matrix/false_positives": 1,
"confusion_matrix/false_negatives": 1,
"confusion_matrix/label_0_pred_0": 1,
"confusion_matrix/label_0_pred_1": 1,
"confusion_matrix/label_1_pred_0": 1,
"confusion_matrix/label_1_pred_1": 1,
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -152,10 +224,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"auc": 0.25,
"confusion_matrix/true_positives": 2,
"confusion_matrix/true_negatives": 2,
"confusion_matrix/false_positives": 2,
"confusion_matrix/false_negatives": 2,
"confusion_matrix/label_0_pred_0": 2,
"confusion_matrix/label_0_pred_1": 2,
"confusion_matrix/label_1_pred_0": 2,
"confusion_matrix/label_1_pred_1": 2,
}, sess.run(value_ops))
def testBinaryClassificationWithWeights(self):
......@@ -169,7 +241,7 @@ class MetricsTest(tf.test.TestCase):
weights = [0, 1, 0, 1]
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -184,10 +256,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"auc": 0,
"confusion_matrix/true_positives": 1,
"confusion_matrix/true_negatives": 0,
"confusion_matrix/false_positives": 1,
"confusion_matrix/false_negatives": 0,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 1,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -197,10 +269,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"auc": 0,
"confusion_matrix/true_positives": 2,
"confusion_matrix/true_negatives": 0,
"confusion_matrix/false_positives": 2,
"confusion_matrix/false_negatives": 0,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 2,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
}, sess.run(value_ops))
......
......@@ -47,15 +47,10 @@ def fake_features(feature_spec, batch_size):
Dictionary containing "time_series_features" and "aux_features". Each is a
dictionary of named numpy arrays of shape [batch_size, length].
"""
features = {}
features["time_series_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.items() if spec["is_time_series"]
}
features["aux_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.items() if not spec["is_time_series"]
}
features = {"time_series_features": {}, "aux_features": {}}
for name, spec in feature_spec.items():
ftype = "time_series_features" if spec["is_time_series"] else "aux_features"
features[ftype][name] = np.random.random([batch_size, spec["length"]])
return features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment