Unverified Commit 2c181308 authored by Chris Shallue's avatar Chris Shallue Committed by GitHub
Browse files

Merge pull request #5862 from cshallue/master

Move tensorflow_models/research/astronet to google-research/exoplanet-ml
parents caafb6d1 62704f06
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to build an input pipeline that reads from TFRecord files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import collections.abc

import six
import tensorflow as tf
def pad_tensor_to_batch_size(tensor, batch_size):
  """Zero-pads `tensor` along dimension 0 so its batch size is `batch_size`.

  Args:
    tensor: A Tensor of rank >= 1 whose 0-th dimension is the (possibly
      short) batch dimension.
    batch_size: Desired static batch size; must be >= 2.

  Returns:
    A Tensor equal to `tensor` with trailing zeros appended along dimension 0,
    and with a static batch dimension of `batch_size`.

  Raises:
    ValueError: If batch_size < 2, or if `tensor` is 0-dimensional.
  """
  if batch_size < 2:
    raise ValueError("Cannot pad along batch dimension with batch_size < 2.")

  rank = len(tensor.shape)
  if rank < 1:
    raise ValueError("Cannot pad a 0-dimensional Tensor")

  pad_count = batch_size - tf.shape(tensor)[0]
  # Build a [rank, 2] paddings Tensor that is all zeros except entry [0][1]:
  # the number of values appended after the existing batch entries.
  pad_spec = tf.sparse_to_dense(
      sparse_indices=[[0, 1]],
      output_shape=[rank, 2],
      sparse_values=pad_count)
  result = tf.pad(tensor, pad_spec, name=tensor.op.name + "/pad")

  # Record the now statically-known batch dimension.
  static_shape = tensor.shape.as_list()
  static_shape[0] = batch_size
  result.set_shape(static_shape)
  return result
def _recursive_pad_to_batch_size(tensor_or_collection, batch_size):
  """Recursively pads to the batch size in a Tensor or collection of Tensors.

  Args:
    tensor_or_collection: A Tensor, or a dict or iterable of (possibly nested)
      Tensors.
    batch_size: Desired size of the batch (0-th) dimension.

  Returns:
    The same structure with every Tensor padded along dimension 0 to
    batch_size. Non-dict iterables are returned as lists.

  Raises:
    ValueError: If the input is not a Tensor, dict, or iterable.
  """
  if isinstance(tensor_or_collection, tf.Tensor):
    return pad_tensor_to_batch_size(tensor_or_collection, batch_size)
  if isinstance(tensor_or_collection, dict):
    return {
        name: _recursive_pad_to_batch_size(t, batch_size)
        for name, t in tensor_or_collection.items()
    }
  # Use collections.abc: the bare `collections.Iterable` alias was deprecated
  # in Python 3.3 and removed in Python 3.10.
  if isinstance(tensor_or_collection, collections.abc.Iterable):
    return [
        _recursive_pad_to_batch_size(t, batch_size)
        for t in tensor_or_collection
    ]
  raise ValueError("Unknown input type: {}".format(tensor_or_collection))
def pad_dataset_to_batch_size(dataset, batch_size):
  """Pads Tensors in a dataset along the batch dimension to batch_size.

  The output contains a 'weights' Tensor, which is a 0/1 indicator of padded
  elements. If a 'weights' Tensor already exists in the input dataset, then
  that Tensor is padded with zeros. If a 'weights' Tensor does not already
  exist, then the input dataset is assumed to have a 'labels' Tensor which is
  used to construct the weights.

  Args:
    dataset: A tf.data.Dataset of dicts of named Tensors.
    batch_size: Integer batch size.

  Returns:
    A tf.data.Dataset.
  """

  def _pad_batch(tensors):
    """Pads one element's Tensors along the batch dimension to batch_size."""
    if not isinstance(tensors, dict):
      raise ValueError(
          "pad_dataset_to_batch_size requires a dictionary of named Tensors.")
    padded = _recursive_pad_to_batch_size(tensors, batch_size)
    if "weights" not in padded:
      # Real examples get weight 1; the zero-padding added by
      # pad_tensor_to_batch_size gives padded examples weight 0.
      ones = tf.ones_like(tensors["labels"], dtype=tf.float32)
      padded["weights"] = pad_tensor_to_batch_size(ones, batch_size)
    return padded

  return dataset.map(_pad_batch)
def _recursive_set_batch_size(tensor_or_collection, batch_size):
  """Recursively sets the batch size in a Tensor or collection of Tensors.

  Args:
    tensor_or_collection: A Tensor, or a dict or iterable of (possibly nested)
      Tensors.
    batch_size: Static size to record for the batch (0-th) dimension.

  Returns:
    The input structure, with every Tensor's static batch dimension set (in
    place) to batch_size.

  Raises:
    ValueError: If the input is not a Tensor, dict, or iterable.
  """
  if isinstance(tensor_or_collection, tf.Tensor):
    t = tensor_or_collection
    shape = t.shape.as_list()
    shape[0] = batch_size
    # merge_with validates that the new static shape is compatible with the
    # existing one before it is recorded.
    t.set_shape(t.shape.merge_with(shape))
  elif isinstance(tensor_or_collection, dict):
    for t in six.itervalues(tensor_or_collection):
      _recursive_set_batch_size(t, batch_size)
  # Use collections.abc: the bare `collections.Iterable` alias was deprecated
  # in Python 3.3 and removed in Python 3.10.
  elif isinstance(tensor_or_collection, collections.abc.Iterable):
    for t in tensor_or_collection:
      _recursive_set_batch_size(t, batch_size)
  else:
    raise ValueError("Unknown input type: {}".format(tensor_or_collection))
  return tensor_or_collection
def set_batch_size(dataset, batch_size):
  """Sets the batch dimension in all Tensors to batch_size."""

  def _set(element):
    return _recursive_set_batch_size(element, batch_size)

  return dataset.map(_set)
def build_dataset(file_pattern,
                  input_config,
                  batch_size,
                  include_labels=True,
                  reverse_time_series_prob=0,
                  shuffle_filenames=False,
                  shuffle_values_buffer=0,
                  repeat=1,
                  use_tpu=False):
  """Builds an input pipeline that reads a dataset from sharded TFRecord files.

  Args:
    file_pattern: File pattern matching input TFRecord files, e.g.
      "/tmp/train-?????-of-00100". May also be a comma-separated list of file
      patterns.
    input_config: ConfigDict containing feature and label specifications.
    batch_size: The number of examples per batch.
    include_labels: Whether to read labels from the input files.
    reverse_time_series_prob: If > 0, the time series features will be randomly
      reversed with this probability. Within a given example, either all time
      series features will be reversed, or none will be reversed.
    shuffle_filenames: Whether to shuffle the order of TFRecord files between
      epochs.
    shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
    repeat: The number of times to repeat the dataset. If None or -1 the dataset
      will repeat indefinitely.
    use_tpu: Whether to build the dataset for TPU.

  Raises:
    ValueError: If an input file pattern does not match any files, or if the
      label IDs in input_config.label_map are not contiguous integers starting
      at 0.

  Returns:
    A tf.data.Dataset object.
  """
  # Expand each comma-separated pattern, failing fast on any pattern that
  # matches nothing so a missing shard cannot silently skew the data.
  file_patterns = file_pattern.split(",")
  filenames = []
  for p in file_patterns:
    matches = tf.gfile.Glob(p)
    if not matches:
      raise ValueError("Found no input files matching {}".format(p))
    filenames.extend(matches)
  tf.logging.info("Building input pipeline from %d files matching patterns: %s",
                  len(filenames), file_patterns)

  if include_labels:
    # Ensure that the label ids are contiguous integers starting at 0.
    label_ids = set(input_config.label_map.values())
    if label_ids != set(range(len(label_ids))):
      raise ValueError(
          "Label IDs must be contiguous integers starting at 0. Got: {}".format(
              label_ids))

    # Create a HashTable mapping label strings to integer ids. Unknown label
    # strings map to -1 (checked by an assert in the parser below).
    table_initializer = tf.contrib.lookup.KeyValueTensorInitializer(
        keys=list(input_config.label_map.keys()),
        values=list(input_config.label_map.values()),
        key_dtype=tf.string,
        value_dtype=tf.int32)
    label_to_id = tf.contrib.lookup.HashTable(
        table_initializer, default_value=-1)

  def _example_parser(serialized_example):
    """Parses a single tf.Example into feature and label tensors."""
    # Set specifications for parsing the features.
    data_fields = {
        feature_name: tf.FixedLenFeature([feature.length], tf.float32)
        for feature_name, feature in input_config.features.items()
    }
    if include_labels:
      data_fields[input_config.label_feature] = tf.FixedLenFeature([],
                                                                   tf.string)

    # Parse the features.
    parsed_features = tf.parse_single_example(
        serialized_example, features=data_fields)

    if reverse_time_series_prob > 0:
      # Randomly reverse time series features with probability
      # reverse_time_series_prob. One coin flip per example, so either all of
      # an example's time series features are reversed or none are.
      should_reverse = tf.less(
          tf.random_uniform([], 0, 1),
          reverse_time_series_prob,
          name="should_reverse")

    # Reorganize outputs into "labels", "time_series_features" and
    # "aux_features" groups.
    output = {}
    for feature_name, value in parsed_features.items():
      if include_labels and feature_name == input_config.label_feature:
        label_id = label_to_id.lookup(value)
        # Ensure that the label_id is nonnegative to verify a successful hash
        # map lookup.
        assert_known_label = tf.Assert(
            tf.greater_equal(label_id, tf.to_int32(0)),
            ["Unknown label string:", value])
        with tf.control_dependencies([assert_known_label]):
          label_id = tf.identity(label_id)

        # We use the plural name "labels" in the output due to batching.
        output["labels"] = label_id
      elif input_config.features[feature_name].is_time_series:
        # Possibly reverse.
        if reverse_time_series_prob > 0:
          # pylint:disable=cell-var-from-loop
          value = tf.cond(should_reverse, lambda: tf.reverse(value, axis=[0]),
                          lambda: tf.identity(value))
          # pylint:enable=cell-var-from-loop
        if "time_series_features" not in output:
          output["time_series_features"] = {}
        output["time_series_features"][feature_name] = value
      else:
        if "aux_features" not in output:
          output["aux_features"] = {}
        output["aux_features"][feature_name] = value

    return output

  # Create a string dataset of filenames, and possibly shuffle.
  filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
  if len(filenames) > 1 and shuffle_filenames:
    filename_dataset = filename_dataset.shuffle(len(filenames))

  # Read serialized Example protos.
  dataset = filename_dataset.flat_map(tf.data.TFRecordDataset)

  # Possibly shuffle. Note that we shuffle before repeat(), so we only shuffle
  # elements among each "epoch" of data, and not across epochs of data.
  if shuffle_values_buffer > 0:
    dataset = dataset.shuffle(shuffle_values_buffer)

  # Repeat.
  if repeat != 1:
    dataset = dataset.repeat(repeat)

  # Map the parser over the dataset.
  dataset = dataset.map(_example_parser, num_parallel_calls=4)

  # Batch results by up to batch_size.
  dataset = dataset.batch(batch_size)
  if repeat == -1 or repeat is None:
    # The dataset repeats infinitely before batching, so each batch has the
    # maximum number of elements.
    dataset = set_batch_size(dataset, batch_size)
  elif use_tpu:
    # TPU requires all dimensions to be fixed. Since the dataset does not
    # repeat infinitely before batching, the final batch may have fewer than
    # batch_size elements. Therefore we pad to ensure that the final batch has
    # batch_size elements.
    dataset = pad_dataset_to_batch_size(dataset, batch_size)

  # Prefetch a few batches.
  dataset = dataset.prefetch(max(1, int(256 / batch_size)))

  return dataset
This diff is collapsed.
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Operations for feeding input data using TensorFlow placeholders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def prepare_feed_dict(model, features, labels=None, is_training=None):
  """Prepares a feed_dict for sess.run() given a batch of features and labels.

  Args:
    model: An instance of AstroModel.
    features: Dictionary containing "time_series_features" and "aux_features".
      Each is a dictionary of named numpy arrays of shape [batch_size, length].
    labels: (Optional). Numpy array of shape [batch_size].
    is_training: (Optional). Python boolean to feed to the model.is_training
      Tensor (if None, no value is fed).

  Returns:
    feed_dict: A dictionary of input Tensor to numpy array.
  """
  # Map each feature placeholder to the corresponding numpy array by name.
  feed_dict = {
      tensor: features["time_series_features"][name]
      for name, tensor in model.time_series_features.items()
  }
  feed_dict.update(
      (tensor, features["aux_features"][name])
      for name, tensor in model.aux_features.items())
  if labels is not None:
    feed_dict[model.labels] = labels
  if is_training is not None:
    feed_dict[model.is_training] = is_training
  return feed_dict
def build_feature_placeholders(config):
  """Builds tf.Placeholder ops for feeding model features and labels.

  Args:
    config: ConfigDict containing the feature configurations.

  Returns:
    features: A dictionary containing "time_series_features" and
      "aux_features", each of which is a dictionary of tf.Placeholders of
      features from the input configuration. All features have dtype float32
      and shape [batch_size, length].
  """
  features = {"time_series_features": {}, "aux_features": {}}
  for feature_name, feature_spec in config.items():
    # The batch (first) dimension is left as None so it can be specified
    # dynamically at feed time.
    placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[None, feature_spec.length],
        name=feature_name)
    group = ("time_series_features"
             if feature_spec.is_time_series else "aux_features")
    features[group][feature_name] = placeholder
  return features
def build_labels_placeholder():
  """Builds a tf.Placeholder op for feeding model labels.

  Returns:
    labels: An int64 tf.Placeholder with shape [batch_size].
  """
  # The batch dimension is left as None so it can be specified dynamically.
  return tf.placeholder(dtype=tf.int64, shape=[None], name="labels")
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for input_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from astronet.ops import input_ops
from tf_util import configdict
class InputOpsTest(tf.test.TestCase):
  """Tests for input_ops.build_feature_placeholders()."""

  def assertFeatureShapesEqual(self, expected_shapes, features):
    """Asserts that a dict of feature placeholders has the expected shapes.

    Args:
      expected_shapes: Dictionary of expected Tensor shapes, as lists,
        corresponding to the structure of 'features'.
      features: Dictionary of feature placeholders of the format returned by
        input_ops.build_feature_placeholders().
    """
    actual_shapes = {}
    for feature_type in features:
      actual_shapes[feature_type] = {
          feature: tensor.shape.as_list()
          for feature, tensor in features[feature_type].items()
      }
    self.assertDictEqual(expected_shapes, actual_shapes)

  def testBuildFeaturePlaceholders(self):
    """Checks placeholder shapes for several feature configurations."""
    # One time series feature.
    config = configdict.ConfigDict({
        "time_feature_1": {
            "length": 14,
            "is_time_series": True,
        }
    })
    expected_shapes = {
        "time_series_features": {
            "time_feature_1": [None, 14],
        },
        "aux_features": {}
    }
    features = input_ops.build_feature_placeholders(config)
    self.assertFeatureShapesEqual(expected_shapes, features)

    # Two time series features.
    config = configdict.ConfigDict({
        "time_feature_1": {
            "length": 14,
            "is_time_series": True,
        },
        "time_feature_2": {
            "length": 5,
            "is_time_series": True,
        }
    })
    expected_shapes = {
        "time_series_features": {
            "time_feature_1": [None, 14],
            "time_feature_2": [None, 5],
        },
        "aux_features": {}
    }
    features = input_ops.build_feature_placeholders(config)
    self.assertFeatureShapesEqual(expected_shapes, features)

    # One aux feature.
    config = configdict.ConfigDict({
        "time_feature_1": {
            "length": 14,
            "is_time_series": True,
        },
        "aux_feature_1": {
            "length": 1,
            "is_time_series": False,
        }
    })
    expected_shapes = {
        "time_series_features": {
            "time_feature_1": [None, 14],
        },
        "aux_features": {
            "aux_feature_1": [None, 1]
        }
    }
    features = input_ops.build_feature_placeholders(config)
    self.assertFeatureShapesEqual(expected_shapes, features)

    # Two aux features.
    config = configdict.ConfigDict({
        "time_feature_1": {
            "length": 14,
            "is_time_series": True,
        },
        "aux_feature_1": {
            "length": 1,
            "is_time_series": False,
        },
        "aux_feature_2": {
            "length": 6,
            "is_time_series": False,
        },
    })
    expected_shapes = {
        "time_series_features": {
            "time_feature_1": [None, 14],
        },
        "aux_features": {
            "aux_feature_1": [None, 1],
            "aux_feature_2": [None, 6]
        }
    }
    features = input_ops.build_feature_placeholders(config)
    self.assertFeatureShapesEqual(expected_shapes, features)
# Standard TensorFlow test entry point.
if __name__ == "__main__":
  tf.test.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for computing evaluation metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def _metric_variable(name, shape, dtype):
  """Creates a Variable in LOCAL_VARIABLES and METRIC_VARIABLES collections.

  The variable is zero-initialized and non-trainable; placing it in
  LOCAL_VARIABLES means tf.local_variables_initializer() resets it.
  """
  var_collections = [
      tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES
  ]
  zero_init = tf.zeros(shape, dtype)
  return tf.get_variable(
      name, initializer=zero_init, trainable=False,
      collections=var_collections)
def _build_metrics(labels, predictions, weights, batch_losses, output_dim=1):
  """Builds TensorFlow operations to compute model evaluation metrics.

  Args:
    labels: Tensor with shape [batch_size].
    predictions: Tensor with shape [batch_size, output_dim].
    weights: Tensor with shape [batch_size].
    batch_losses: Tensor with shape [batch_size].
    output_dim: Dimension of model output.

  Returns:
    A dictionary {metric_name: (metric_value, update_op)}.
  """
  # Compute the predicted labels.
  assert len(predictions.shape) == 2
  binary_classification = output_dim == 1
  if binary_classification:
    assert predictions.shape[1] == 1
    # Squeeze [batch_size, 1] -> [batch_size], then threshold at 0.5.
    predictions = tf.squeeze(predictions, axis=[1])
    predicted_labels = tf.to_int32(
        tf.greater(predictions, 0.5), name="predicted_labels")
  else:
    predicted_labels = tf.argmax(
        predictions, 1, name="predicted_labels", output_type=tf.int32)

  metrics = {}
  with tf.variable_scope("metrics"):
    # Total number of examples (sum of weights, so padded examples with
    # weight 0 do not count).
    num_examples = _metric_variable("num_examples", [], tf.float32)
    update_num_examples = tf.assign_add(num_examples, tf.reduce_sum(weights))
    metrics["num_examples"] = (num_examples.read_value(), update_num_examples)

    # Accuracy metrics.
    num_correct = _metric_variable("num_correct", [], tf.float32)
    is_correct = weights * tf.to_float(tf.equal(labels, predicted_labels))
    update_num_correct = tf.assign_add(num_correct, tf.reduce_sum(is_correct))
    metrics["accuracy/num_correct"] = (num_correct.read_value(),
                                       update_num_correct)
    # Accuracy is derived from the two counters; its update is a no-op.
    accuracy = tf.div(num_correct, num_examples, name="accuracy")
    metrics["accuracy/accuracy"] = (accuracy, tf.no_op())

    # Weighted cross-entropy loss.
    metrics["losses/weighted_cross_entropy"] = tf.metrics.mean(
        batch_losses, weights=weights, name="cross_entropy_loss")

    def _count_condition(name, labels_value, predicted_value):
      """Creates a counter for given values of predictions and labels."""
      count = _metric_variable(name, [], tf.float32)
      is_equal = tf.to_float(
          tf.logical_and(
              tf.equal(labels, labels_value),
              tf.equal(predicted_labels, predicted_value)))
      update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
      return count.read_value(), update_op

    # Confusion matrix metrics: one weighted counter per (label, prediction)
    # cell.
    num_labels = 2 if binary_classification else output_dim
    for gold_label in range(num_labels):
      for pred_label in range(num_labels):
        metric_name = "confusion_matrix/label_{}_pred_{}".format(
            gold_label, pred_label)
        metrics[metric_name] = _count_condition(
            metric_name, labels_value=gold_label, predicted_value=pred_label)

    # Possibly create AUC metric for binary classification.
    if binary_classification:
      labels = tf.cast(labels, dtype=tf.bool)
      metrics["auc"] = tf.metrics.auc(
          labels, predictions, weights=weights, num_thresholds=1000)

  return metrics
def create_metric_fn(model):
  """Creates a tuple (metric_fn, metric_fn_inputs).

  This function is primarily used for creating a TPUEstimator.

  The result of calling metric_fn(**metric_fn_inputs) is a dictionary
  {metric_name: (metric_value, update_op)}.

  Args:
    model: Instance of AstroModel.

  Returns:
    A tuple (metric_fn, metric_fn_inputs).
  """
  # If the model carries no example weights, treat every example equally.
  weights = (model.weights if model.weights is not None else
             tf.ones_like(model.labels, dtype=tf.float32))
  metric_fn_inputs = dict(
      labels=model.labels,
      predictions=model.predictions,
      weights=weights,
      batch_losses=model.batch_losses)

  def metric_fn(labels, predictions, weights, batch_losses):
    return _build_metrics(
        labels,
        predictions,
        weights,
        batch_losses,
        output_dim=model.hparams.output_dim)

  return metric_fn, metric_fn_inputs
def create_metrics(model):
  """Creates a dictionary {metric_name: (metric_value, update_op)}.

  This function is primarily used for creating an Estimator.

  Args:
    model: Instance of AstroModel.

  Returns:
    A dictionary {metric_name: (metric_value, update_op)}.
  """
  metric_fn, inputs = create_metric_fn(model)
  return metric_fn(**inputs)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for metrics.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from astronet.ops import metrics
def _unpack_metric_map(names_to_tuples):
"""Unpacks {metric_name: (metric_value, update_op)} into separate dicts."""
metric_names = names_to_tuples.keys()
value_ops, update_ops = zip(*names_to_tuples.values())
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
class _MockHparams(object):
  """Mock Hparams class to support accessing with dot notation."""
  # Intentionally empty: attributes (e.g. output_dim) are set by callers.
  pass
class _MockModel(object):
  """Mock model for testing."""

  def __init__(self, labels, predictions, weights, batch_losses, output_dim):
    """Builds constant Tensors mimicking the attributes of AstroModel."""
    self.labels = tf.constant(labels, dtype=tf.int32)
    self.predictions = tf.constant(predictions, dtype=tf.float32)
    # Keep weights as None when absent so callers exercise the unweighted path.
    if weights is None:
      self.weights = None
    else:
      self.weights = tf.constant(weights, dtype=tf.float32)
    self.batch_losses = tf.constant(batch_losses, dtype=tf.float32)
    self.hparams = _MockHparams()
    self.hparams.output_dim = output_dim
class MetricsTest(tf.test.TestCase):
  """Tests for metrics.create_metrics with mock models."""

  def testMultiClassificationWithoutWeights(self):
    """4-class case, all examples weighted equally."""
    labels = [0, 1, 2, 3]
    predictions = [
        [0.7, 0.2, 0.1, 0.0],  # Predicted label = 0
        [0.2, 0.4, 0.2, 0.2],  # Predicted label = 1
        [0.0, 0.0, 0.0, 1.0],  # Predicted label = 3
        [0.1, 0.1, 0.7, 0.1],  # Predicted label = 2
    ]
    weights = None
    batch_losses = [0, 0, 4, 2]
    model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)

    metric_map = metrics.create_metrics(model)
    value_ops, update_ops = _unpack_metric_map(metric_map)
    initializer = tf.local_variables_initializer()

    with self.test_session() as sess:
      sess.run(initializer)

      # First update: metrics reflect one pass over the batch.
      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 4,
          "accuracy/num_correct": 2,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1.5,
          "confusion_matrix/label_0_pred_0": 1,
          "confusion_matrix/label_0_pred_1": 0,
          "confusion_matrix/label_0_pred_2": 0,
          "confusion_matrix/label_0_pred_3": 0,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 1,
          "confusion_matrix/label_1_pred_2": 0,
          "confusion_matrix/label_1_pred_3": 0,
          "confusion_matrix/label_2_pred_0": 0,
          "confusion_matrix/label_2_pred_1": 0,
          "confusion_matrix/label_2_pred_2": 0,
          "confusion_matrix/label_2_pred_3": 1,
          "confusion_matrix/label_3_pred_0": 0,
          "confusion_matrix/label_3_pred_1": 0,
          "confusion_matrix/label_3_pred_2": 1,
          "confusion_matrix/label_3_pred_3": 0
      }, sess.run(value_ops))

      # Second update: counters double, ratios stay the same.
      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 8,
          "accuracy/num_correct": 4,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1.5,
          "confusion_matrix/label_0_pred_0": 2,
          "confusion_matrix/label_0_pred_1": 0,
          "confusion_matrix/label_0_pred_2": 0,
          "confusion_matrix/label_0_pred_3": 0,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 2,
          "confusion_matrix/label_1_pred_2": 0,
          "confusion_matrix/label_1_pred_3": 0,
          "confusion_matrix/label_2_pred_0": 0,
          "confusion_matrix/label_2_pred_1": 0,
          "confusion_matrix/label_2_pred_2": 0,
          "confusion_matrix/label_2_pred_3": 2,
          "confusion_matrix/label_3_pred_0": 0,
          "confusion_matrix/label_3_pred_1": 0,
          "confusion_matrix/label_3_pred_2": 2,
          "confusion_matrix/label_3_pred_3": 0
      }, sess.run(value_ops))

  def testMultiClassificationWithWeights(self):
    """4-class case; examples 0 and 2 are masked out by zero weights."""
    labels = [0, 1, 2, 3]
    predictions = [
        [0.7, 0.2, 0.1, 0.0],  # Predicted label = 0
        [0.2, 0.4, 0.2, 0.2],  # Predicted label = 1
        [0.0, 0.0, 0.0, 1.0],  # Predicted label = 3
        [0.1, 0.1, 0.7, 0.1],  # Predicted label = 2
    ]
    weights = [0, 1, 0, 1]
    batch_losses = [0, 0, 4, 2]
    model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)

    metric_map = metrics.create_metrics(model)
    value_ops, update_ops = _unpack_metric_map(metric_map)
    initializer = tf.local_variables_initializer()

    with self.test_session() as sess:
      sess.run(initializer)

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 2,
          "accuracy/num_correct": 1,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1,
          "confusion_matrix/label_0_pred_0": 0,
          "confusion_matrix/label_0_pred_1": 0,
          "confusion_matrix/label_0_pred_2": 0,
          "confusion_matrix/label_0_pred_3": 0,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 1,
          "confusion_matrix/label_1_pred_2": 0,
          "confusion_matrix/label_1_pred_3": 0,
          "confusion_matrix/label_2_pred_0": 0,
          "confusion_matrix/label_2_pred_1": 0,
          "confusion_matrix/label_2_pred_2": 0,
          "confusion_matrix/label_2_pred_3": 0,
          "confusion_matrix/label_3_pred_0": 0,
          "confusion_matrix/label_3_pred_1": 0,
          "confusion_matrix/label_3_pred_2": 1,
          "confusion_matrix/label_3_pred_3": 0
      }, sess.run(value_ops))

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 4,
          "accuracy/num_correct": 2,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1,
          "confusion_matrix/label_0_pred_0": 0,
          "confusion_matrix/label_0_pred_1": 0,
          "confusion_matrix/label_0_pred_2": 0,
          "confusion_matrix/label_0_pred_3": 0,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 2,
          "confusion_matrix/label_1_pred_2": 0,
          "confusion_matrix/label_1_pred_3": 0,
          "confusion_matrix/label_2_pred_0": 0,
          "confusion_matrix/label_2_pred_1": 0,
          "confusion_matrix/label_2_pred_2": 0,
          "confusion_matrix/label_2_pred_3": 0,
          "confusion_matrix/label_3_pred_0": 0,
          "confusion_matrix/label_3_pred_1": 0,
          "confusion_matrix/label_3_pred_2": 2,
          "confusion_matrix/label_3_pred_3": 0
      }, sess.run(value_ops))

  def testBinaryClassificationWithoutWeights(self):
    """Binary case (output_dim=1), all examples weighted equally."""
    labels = [0, 1, 1, 0]
    predictions = [
        [0.4],  # Predicted label = 0
        [0.6],  # Predicted label = 1
        [0.0],  # Predicted label = 0
        [1.0],  # Predicted label = 1
    ]
    weights = None
    batch_losses = [0, 0, 4, 2]
    model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)

    metric_map = metrics.create_metrics(model)
    value_ops, update_ops = _unpack_metric_map(metric_map)
    initializer = tf.local_variables_initializer()

    with self.test_session() as sess:
      sess.run(initializer)

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 4,
          "accuracy/num_correct": 2,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1.5,
          "auc": 0.25,
          "confusion_matrix/label_0_pred_0": 1,
          "confusion_matrix/label_0_pred_1": 1,
          "confusion_matrix/label_1_pred_0": 1,
          "confusion_matrix/label_1_pred_1": 1,
      }, sess.run(value_ops))

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 8,
          "accuracy/num_correct": 4,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1.5,
          "auc": 0.25,
          "confusion_matrix/label_0_pred_0": 2,
          "confusion_matrix/label_0_pred_1": 2,
          "confusion_matrix/label_1_pred_0": 2,
          "confusion_matrix/label_1_pred_1": 2,
      }, sess.run(value_ops))

  def testBinaryClassificationWithWeights(self):
    """Binary case; examples 0 and 2 are masked out by zero weights."""
    labels = [0, 1, 1, 0]
    predictions = [
        [0.4],  # Predicted label = 0
        [0.6],  # Predicted label = 1
        [0.0],  # Predicted label = 0
        [1.0],  # Predicted label = 1
    ]
    weights = [0, 1, 0, 1]
    batch_losses = [0, 0, 4, 2]
    model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)

    metric_map = metrics.create_metrics(model)
    value_ops, update_ops = _unpack_metric_map(metric_map)
    initializer = tf.local_variables_initializer()

    with self.test_session() as sess:
      sess.run(initializer)

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 2,
          "accuracy/num_correct": 1,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1,
          "auc": 0,
          "confusion_matrix/label_0_pred_0": 0,
          "confusion_matrix/label_0_pred_1": 1,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 1,
      }, sess.run(value_ops))

      sess.run(update_ops)
      self.assertAllClose({
          "num_examples": 4,
          "accuracy/num_correct": 2,
          "accuracy/accuracy": 0.5,
          "losses/weighted_cross_entropy": 1,
          "auc": 0,
          "confusion_matrix/label_0_pred_0": 0,
          "confusion_matrix/label_0_pred_1": 2,
          "confusion_matrix/label_1_pred_0": 0,
          "confusion_matrix/label_1_pred_1": 2,
      }, sess.run(value_ops))
# Standard TensorFlow test entry point.
if __name__ == "__main__":
  tf.test.main()
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Bazel package for the astronet estimator utilities.
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

py_library(
    name = "estimator_util",
    srcs = ["estimator_util.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//astronet/ops:dataset_ops",
        "//astronet/ops:metrics",
        "//astronet/ops:training",
    ],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment