Unverified Commit 6571d16d authored by Lukasz Kaiser, committed by GitHub

Merge pull request #3544 from cshallue/master

Add AstroNet to tensorflow/models
parents 92083555 6c891bc3
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_binary(
name = "generate_input_records",
srcs = ["generate_input_records.py"],
deps = [":preprocess"],
)
py_binary(
name = "preprocess",
srcs = ["preprocess.py"],
deps = [
"//light_curve_util:kepler_io",
"//light_curve_util:median_filter",
"//light_curve_util:util",
"//third_party/kepler_spline",
],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Generates a bash script for downloading light curves.
The input to this script is a CSV file of Kepler targets, for example the DR24
TCE table, which can be downloaded in CSV format from the NASA Exoplanet Archive
at:
https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/nph-tblView?app=ExoTbls&config=q1_q17_dr24_tce
Example usage:
python generate_download_script.py \
--kepler_csv_file=dr24_tce.csv \
--download_dir=${HOME}/astronet/kepler
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import csv
import os
import sys
parser = argparse.ArgumentParser()
parser.add_argument(
"--kepler_csv_file",
type=str,
required=True,
help="CSV file containing Kepler targets to download. Must contain a "
"'kepid' column.")
parser.add_argument(
"--download_dir",
type=str,
required=True,
help="Directory into which the Kepler data will be downloaded.")
parser.add_argument(
"--output_file",
type=str,
default="get_kepler.sh",
help="Filename of the output download script.")
_WGET_CMD = ("wget -q -nH --cut-dirs=6 -r -l0 -c -N -np -erobots=off "
"-R 'index*' -A _llc.fits")
_BASE_URL = "http://archive.stsci.edu/pub/kepler/lightcurves"
def main(argv):
del argv # Unused.
# Read Kepler targets.
kepids = set()
with open(FLAGS.kepler_csv_file) as f:
reader = csv.DictReader(row for row in f if not row.startswith("#"))
for row in reader:
kepids.add(row["kepid"])
num_kepids = len(kepids)
# Write wget commands to script file.
with open(FLAGS.output_file, "w") as f:
f.write("#!/bin/sh\n")
f.write("echo 'Downloading {} Kepler targets to {}'\n".format(
num_kepids, FLAGS.download_dir))
for i, kepid in enumerate(kepids):
if i and not i % 10:
f.write("echo 'Downloaded {}/{}'\n".format(i, num_kepids))
kepid = "{0:09d}".format(int(kepid)) # Pad with zeros.
subdir = "{}/{}".format(kepid[0:4], kepid)
download_dir = os.path.join(FLAGS.download_dir, subdir)
url = "{}/{}/".format(_BASE_URL, subdir)
f.write("{} -P {} {}\n".format(_WGET_CMD, download_dir, url))
f.write("echo 'Finished downloading {} Kepler targets to {}'\n".format(
num_kepids, FLAGS.download_dir))
os.chmod(FLAGS.output_file, 0o744)  # Make the download script executable.
print("{} Kepler targets will be downloaded to {}".format(
num_kepids, FLAGS.download_dir))
print("To start download, run:\n {}".format("./" + FLAGS.output_file
if "/" not in FLAGS.output_file
else FLAGS.output_file))
if __name__ == "__main__":
FLAGS, unparsed = parser.parse_known_args()
main(argv=[sys.argv[0]] + unparsed)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Script to preprocesses data from the Kepler space telescope.
This script produces training, validation and test sets of labeled Kepler
Threshold Crossing Events (TCEs). A TCE is a detected periodic event on a
particular Kepler target star that may or may not be a transiting planet. Each
TCE in the output contains local and global views of its light curve; auxiliary
features such as period and duration; and a label indicating whether the TCE is
consistent with being a transiting planet. The data sets produced by this script
can be used to train and evaluate models that classify Kepler TCEs.
The input TCEs and their associated labels are specified by the DR24 TCE Table,
which can be downloaded in CSV format from the NASA Exoplanet Archive at:
https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/nph-tblView?app=ExoTbls&config=q1_q17_dr24_tce
The downloaded CSV file should contain at least the following column names:
rowid: Integer ID of the row in the TCE table.
kepid: Kepler ID of the target star.
tce_plnt_num: TCE number within the target star.
tce_period: Orbital period of the detected event, in days.
tce_time0bk: The time corresponding to the center of the first detected
transit in Barycentric Julian Day (BJD) minus a constant offset of
2,454,833.0 days.
tce_duration: Duration of the detected transit, in hours.
av_training_set: Autovetter training set label; one of PC (planet candidate),
AFP (astrophysical false positive), NTP (non-transiting phenomenon),
UNK (unknown).
The Kepler light curves can be downloaded from the Mikulski Archive for Space
Telescopes (MAST) at:
http://archive.stsci.edu/pub/kepler/lightcurves.
The Kepler data is assumed to reside in a directory with the same structure as
the MAST archive. Specifically, the file names for a particular Kepler target
star should have the following format:
.../${kep_id:0:4}/${kep_id}/kplr${kep_id}-${quarter_prefix}_${type}.fits,
where:
kep_id is the Kepler id left-padded with zeros to length 9;
quarter_prefix is the file name quarter prefix;
type is one of "llc" (long cadence light curve) or "slc" (short cadence light
curve).
The output TFRecord file contains one serialized tensorflow.train.Example
protocol buffer for each TCE in the input CSV file. Each Example contains the
following light curve representations:
global_view: Vector of length 2001; the Global View of the TCE.
local_view: Vector of length 201; the Local View of the TCE.
In addition, each Example contains the value of each column in the input TCE CSV
file. Some of these features may be useful as auxiliary features to the model.
The columns include:
rowid: Integer ID of the row in the TCE table.
kepid: Kepler ID of the target star.
tce_plnt_num: TCE number within the target star.
av_training_set: Autovetter training set label.
tce_period: Orbital period of the detected event, in days.
...
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import multiprocessing
import os
import sys
import numpy as np
import pandas as pd
import tensorflow as tf
from astronet.data import preprocess
parser = argparse.ArgumentParser()
_DR24_TCE_URL = ("https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/"
"nph-tblView?app=ExoTbls&config=q1_q17_dr24_tce")
parser.add_argument(
"--input_tce_csv_file",
type=str,
required=True,
help="CSV file containing the Q1-Q17 DR24 Kepler TCE table. Must contain "
"columns: rowid, kepid, tce_plnt_num, tce_period, tce_duration, "
"tce_time0bk. Download from: %s" % _DR24_TCE_URL)
parser.add_argument(
"--kepler_data_dir",
type=str,
required=True,
help="Base folder containing Kepler data.")
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Directory in which to save the output.")
parser.add_argument(
"--num_train_shards",
type=int,
default=8,
help="Number of file shards to divide the training set into.")
parser.add_argument(
"--num_worker_processes",
type=int,
default=5,
help="Number of subprocesses for processing the TCEs in parallel.")
# Name and values of the column in the input CSV file to use as training labels.
_LABEL_COLUMN = "av_training_set"
_ALLOWED_LABELS = {"PC", "AFP", "NTP"}
def _set_float_feature(ex, name, value):
"""Sets the value of a float feature in a tensorflow.train.Example proto."""
assert name not in ex.features.feature, "Duplicate feature: %s" % name
ex.features.feature[name].float_list.value.extend([float(v) for v in value])
def _set_bytes_feature(ex, name, value):
"""Sets the value of a bytes feature in a tensorflow.train.Example proto."""
assert name not in ex.features.feature, "Duplicate feature: %s" % name
ex.features.feature[name].bytes_list.value.extend([str(v) for v in value])
def _set_int64_feature(ex, name, value):
"""Sets the value of an int64 feature in a tensorflow.train.Example proto."""
assert name not in ex.features.feature, "Duplicate feature: %s" % name
ex.features.feature[name].int64_list.value.extend([int(v) for v in value])
def _process_tce(tce):
"""Processes the light curve for a Kepler TCE and returns an Example proto.
Args:
tce: Row of the input TCE table.
Returns:
A tensorflow.train.Example proto containing TCE features.
Raises:
IOError: If the light curve files for this Kepler ID cannot be found.
"""
# Read and process the light curve.
time, flux = preprocess.read_and_process_light_curve(tce.kepid,
FLAGS.kepler_data_dir)
time, flux = preprocess.phase_fold_and_sort_light_curve(
time, flux, tce.tce_period, tce.tce_time0bk)
# Generate the local and global views.
global_view = preprocess.global_view(time, flux, tce.tce_period)
local_view = preprocess.local_view(time, flux, tce.tce_period,
tce.tce_duration)
# Make output proto.
ex = tf.train.Example()
# Set time series features.
_set_float_feature(ex, "global_view", global_view)
_set_float_feature(ex, "local_view", local_view)
# Set other columns.
for col_name, value in tce.iteritems():
if np.issubdtype(type(value), np.integer):
_set_int64_feature(ex, col_name, [value])
else:
try:
_set_float_feature(ex, col_name, [float(value)])
except ValueError:
_set_bytes_feature(ex, col_name, [str(value)])
return ex
def _process_file_shard(tce_table, file_name):
"""Processes a single file shard.
Args:
tce_table: A pandas DataFrame containing the TCEs in the shard.
file_name: The output TFRecord file.
"""
process_name = multiprocessing.current_process().name
shard_name = os.path.basename(file_name)
shard_size = len(tce_table)
tf.logging.info("%s: Processing %d items in shard %s", process_name,
shard_size, shard_name)
with tf.python_io.TFRecordWriter(file_name) as writer:
num_processed = 0
for _, tce in tce_table.iterrows():
example = _process_tce(tce)
if example is not None:
writer.write(example.SerializeToString())
num_processed += 1
if not num_processed % 10:
tf.logging.info("%s: Processed %d/%d items in shard %s", process_name,
num_processed, shard_size, shard_name)
tf.logging.info("%s: Wrote %d items in shard %s", process_name, shard_size,
shard_name)
def main(argv):
del argv # Unused.
# Make the output directory if it doesn't already exist.
tf.gfile.MakeDirs(FLAGS.output_dir)
# Read the CSV file of Kepler TCEs.
tce_table = pd.read_csv(
FLAGS.input_tce_csv_file, index_col="rowid", comment="#")
tce_table["tce_duration"] /= 24 # Convert hours to days.
tf.logging.info("Read TCE CSV file with %d rows.", len(tce_table))
# Filter TCE table to allowed labels.
allowed_tces = tce_table[_LABEL_COLUMN].apply(lambda l: l in _ALLOWED_LABELS)
tce_table = tce_table[allowed_tces]
num_tces = len(tce_table)
tf.logging.info("Filtered to %d TCEs with labels in %s.", num_tces,
list(_ALLOWED_LABELS))
# Randomly shuffle the TCE table.
np.random.seed(123)
tce_table = tce_table.iloc[np.random.permutation(num_tces)]
tf.logging.info("Randomly shuffled TCEs.")
# Partition the TCE table as follows:
# train_tces = 80% of TCEs
# val_tces = 10% of TCEs (for validation during training)
# test_tces = 10% of TCEs (for final evaluation)
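# For example (hypothetical count): with num_tces = 15000, train_cutoff = 12000
# and val_cutoff = 13500, giving 12000 / 1500 / 1500 TCEs respectively.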
train_cutoff = int(0.80 * num_tces)
val_cutoff = int(0.90 * num_tces)
train_tces = tce_table[0:train_cutoff]
val_tces = tce_table[train_cutoff:val_cutoff]
test_tces = tce_table[val_cutoff:]
tf.logging.info(
"Partitioned %d TCEs into training (%d), validation (%d) and test (%d)",
num_tces, len(train_tces), len(val_tces), len(test_tces))
# Further split training TCEs into file shards.
file_shards = [] # List of (tce_table_shard, file_name).
boundaries = np.linspace(0, len(train_tces),
FLAGS.num_train_shards + 1).astype(np.int)
for i in range(FLAGS.num_train_shards):
start = boundaries[i]
end = boundaries[i + 1]
file_shards.append((train_tces[start:end], os.path.join(
FLAGS.output_dir, "train-%.5d-of-%.5d" % (i, FLAGS.num_train_shards))))
# Validation and test sets each have a single shard.
file_shards.append((val_tces, os.path.join(FLAGS.output_dir,
"val-00000-of-00001")))
file_shards.append((test_tces, os.path.join(FLAGS.output_dir,
"test-00000-of-00001")))
num_file_shards = len(file_shards)
# Launch subprocesses for the file shards.
num_processes = min(num_file_shards, FLAGS.num_worker_processes)
tf.logging.info("Launching %d subprocesses for %d total file shards",
num_processes, num_file_shards)
pool = multiprocessing.Pool(processes=num_processes)
async_results = [
pool.apply_async(_process_file_shard, file_shard)
for file_shard in file_shards
]
pool.close()
# Instead of pool.join(), we call async_result.get() to ensure any exceptions
# raised by the worker processes are also raised here.
for async_result in async_results:
async_result.get()
tf.logging.info("Finished processing %d total file shards", num_file_shards)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for reading and preprocessing light curves."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from light_curve_util import kepler_io
from light_curve_util import median_filter
from light_curve_util import util
from third_party.kepler_spline import kepler_spline
def read_and_process_light_curve(kepid, kepler_data_dir, max_gap_width=0.75):
"""Reads a light curve, fits a B-spline and divides the curve by the spline.
Args:
kepid: Kepler id of the target star.
kepler_data_dir: Base directory containing Kepler data. See
kepler_io.kepler_filenames().
max_gap_width: Gap size (in days) above which the light curve is split for
the fitting of B-splines.
Returns:
time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve.
Raises:
IOError: If the light curve files for this Kepler ID cannot be found.
ValueError: If the spline could not be fit.
"""
# Read the Kepler light curve.
file_names = kepler_io.kepler_filenames(kepler_data_dir, kepid)
if not file_names:
raise IOError("Failed to find .fits files in %s for Kepler ID %s" %
(kepler_data_dir, kepid))
all_time, all_flux = kepler_io.read_kepler_light_curve(file_names)
# Split on gaps.
all_time, all_flux = util.split(all_time, all_flux, gap_width=max_gap_width)
# Logarithmically sample candidate break point spacings between 0.5 and 20
# days.
bkspaces = np.logspace(np.log10(0.5), np.log10(20), num=20)
# Generate spline.
spline = kepler_spline.choose_kepler_spline(
all_time, all_flux, bkspaces, penalty_coeff=1.0, verbose=False)[0]
if spline is None:
raise ValueError("Failed to fit spline with Kepler ID %s", kepid)
# Concatenate the piecewise light curve and spline.
time = np.concatenate(all_time)
flux = np.concatenate(all_flux)
spline = np.concatenate(spline)
# In rare cases the piecewise spline contains NaNs in places the spline could
# not be fit. We can't normalize those points if the spline isn't defined
# there. Instead we just remove them.
finite_i = np.isfinite(spline)
if not np.all(finite_i):
tf.logging.warn("Incomplete spline with Kepler ID %s", kepid)
time = time[finite_i]
flux = flux[finite_i]
spline = spline[finite_i]
# "Flatten" the light curve (remove low-frequency variability) by dividing by
# the spline.
flux /= spline
return time, flux
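# Example usage (a sketch; the Kepler ID and data directory are hypothetical):
#   time, flux = read_and_process_light_curve(11442793, "/path/to/kepler-data")
#   # flux is divided by the fitted spline, so values lie roughly around 1.0
#   # away from transits.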
def phase_fold_and_sort_light_curve(time, flux, period, t0):
"""Phase folds a light curve and sorts by ascending time.
Args:
time: 1D NumPy array of time values.
flux: 1D NumPy array of flux values.
period: A positive real scalar; the period to fold over.
t0: The center of the resulting folded vector; this value is mapped to 0.
Returns:
folded_time: 1D NumPy array of phase folded time values in
[-period / 2, period / 2), where 0 corresponds to t0 in the original
time array. Values are sorted in ascending order.
folded_flux: 1D NumPy array. Values are the same as the original input
array, but sorted by folded_time.
"""
# Phase fold time.
time = util.phase_fold_time(time, period, t0)
# Sort by ascending time.
sorted_i = np.argsort(time)
time = time[sorted_i]
flux = flux[sorted_i]
return time, flux
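# A small worked example (sketch; assumes the half-open interval convention
# documented above): with period=2.0 and t0=1.0, input times
# [0.0, 1.0, 2.0, 3.0] fold to [-1.0, 0.0, -1.0, 0.0], and sorting gives
# folded_time = [-1.0, -1.0, 0.0, 0.0] with flux reordered the same way.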
def generate_view(time, flux, num_bins, bin_width, t_min, t_max,
normalize=True):
"""Generates a view of a phase-folded light curve using a median filter.
Args:
time: 1D array of time values, sorted in ascending order.
flux: 1D array of flux values.
num_bins: The number of intervals to divide the time axis into.
bin_width: The width of each bin on the time axis.
t_min: The inclusive leftmost value to consider on the time axis.
t_max: The exclusive rightmost value to consider on the time axis.
normalize: Whether to center the median at 0 and minimum value at -1.
Returns:
1D NumPy array of size num_bins containing the median flux values of
uniformly spaced bins on the phase-folded time axis.
"""
view = median_filter.median_filter(time, flux, num_bins, bin_width, t_min,
t_max)
if normalize:
view -= np.median(view)
view /= np.abs(np.min(view))
return view
def global_view(time, flux, period, num_bins=2001, bin_width_factor=1 / 2001):
"""Generates a 'global view' of a phase folded light curve.
See Section 3.3 of Shallue & Vanderburg, 2018, The Astronomical Journal.
http://iopscience.iop.org/article/10.3847/1538-3881/aa9e09/meta
Args:
time: 1D array of time values, sorted in ascending order.
flux: 1D array of flux values.
period: The period of the event (in days).
num_bins: The number of intervals to divide the time axis into.
bin_width_factor: Width of the bins, as a fraction of period.
Returns:
1D NumPy array of size num_bins containing the median flux values of
uniformly spaced bins on the phase-folded time axis.
"""
return generate_view(
time,
flux,
num_bins=num_bins,
bin_width=period * bin_width_factor,
t_min=-period / 2,
t_max=period / 2)
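# For instance (sketch): with period = 10 days and the defaults above, the
# global view spans [-5, 5) days with 2001 bins of width 10 / 2001, or about
# 0.005 days, each.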
def local_view(time,
flux,
period,
duration,
num_bins=201,
bin_width_factor=0.16,
num_durations=4):
"""Generates a 'local view' of a phase folded light curve.
See Section 3.3 of Shallue & Vanderburg, 2018, The Astronomical Journal.
http://iopscience.iop.org/article/10.3847/1538-3881/aa9e09/meta
Args:
time: 1D array of time values, sorted in ascending order.
flux: 1D array of flux values.
period: The period of the event (in days).
duration: The duration of the event (in days).
num_bins: The number of intervals to divide the time axis into.
bin_width_factor: Width of the bins, as a fraction of duration.
num_durations: The number of durations to consider on either side of 0 (the
event is assumed to be centered at 0).
Returns:
1D NumPy array of size num_bins containing the median flux values of
uniformly spaced bins on the phase-folded time axis.
"""
return generate_view(
time,
flux,
num_bins=num_bins,
bin_width=duration * bin_width_factor,
t_min=max(-period / 2, -duration * num_durations),
t_max=min(period / 2, duration * num_durations))
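# For instance (sketch): with period = 10 days, duration = 0.2 days and the
# defaults above, the local view spans [-0.8, 0.8) days (4 durations on each
# side of the event) with 201 bins of width 0.2 * 0.16 = 0.032 days each.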
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for evaluating an AstroNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_util
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, required=True, help="Name of the model class.")
parser.add_argument(
"--config_name",
type=str,
help="Name of the model and training configuration. Exactly one of "
"--config_name or --config_json is required.")
parser.add_argument(
"--config_json",
type=str,
help="JSON string or JSON file containing the model and training "
"configuration. Exactly one of --config_name or --config_json is required.")
parser.add_argument(
"--eval_files",
type=str,
required=True,
help="Comma-separated list of file patterns matching the TFRecord files in "
"the evaluation dataset.")
parser.add_argument(
"--model_dir",
type=str,
required=True,
help="Directory containing a model checkpoint.")
parser.add_argument(
"--eval_name", type=str, default="test", help="Name of the evaluation set.")
def main(_):
model_class = models.get_model_class(FLAGS.model)
# Look up the model configuration.
assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
"Exactly one of --config_name or --config_json is required.")
config = (
models.get_model_config(FLAGS.model, FLAGS.config_name)
if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
config = configdict.ConfigDict(config)
# Create the estimator.
estimator = estimator_util.create_estimator(
model_class, config.hparams, model_dir=FLAGS.model_dir)
# Create an input function that reads the evaluation dataset.
input_fn = estimator_util.create_input_fn(
file_pattern=FLAGS.eval_files,
input_config=config.inputs,
mode=tf.estimator.ModeKeys.EVAL)
# Run evaluation. This will log the result to stderr and also write a summary
# file in the model_dir.
estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library of AstroNet models and configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from astronet.astro_cnn_model import astro_cnn_model
from astronet.astro_cnn_model import configurations as astro_cnn_configurations
from astronet.astro_fc_model import astro_fc_model
from astronet.astro_fc_model import configurations as astro_fc_configurations
from astronet.astro_model import astro_model
from astronet.astro_model import configurations as astro_configurations
# Dictionary of model name to (model_class, configuration_module).
_MODELS = {
"AstroModel": (astro_model.AstroModel, astro_configurations),
"AstroFCModel": (astro_fc_model.AstroFCModel, astro_fc_configurations),
"AstroCNNModel": (astro_cnn_model.AstroCNNModel, astro_cnn_configurations),
}
def get_model_class(model_name):
"""Looks up a model class by name.
Args:
model_name: Name of the model class.
Returns:
model_class: The requested model class.
Raises:
ValueError: If model_name is unrecognized.
"""
if model_name not in _MODELS:
raise ValueError("Unrecognized model name: %s" % model_name)
return _MODELS[model_name][0]
def get_model_config(model_name, config_name):
"""Looks up a model configuration by name.
Args:
model_name: Name of the model class.
config_name: Name of a configuration-builder function from the model's
configurations module.
Returns:
model_class: The requested model class.
config: The requested configuration.
Raises:
ValueError: If model_name or config_name is unrecognized.
"""
if model_name not in _MODELS:
raise ValueError("Unrecognized model name: %s" % model_name)
config_module = _MODELS[model_name][1]
try:
return getattr(config_module, config_name)()
except AttributeError:
raise ValueError("Config name '%s' not found in configuration module: %s" %
(config_name, config_module.__name__))
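# Example usage (a sketch; "base" is a hypothetical configuration-builder name
# that must exist in the model's configurations module):
#   model_class = get_model_class("AstroCNNModel")
#   config = get_model_config("AstroCNNModel", "base")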
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "input_ops",
srcs = ["input_ops.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "input_ops_test",
size = "small",
srcs = ["input_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":input_ops",
"//astronet/util:configdict",
],
)
py_library(
name = "dataset_ops",
srcs = ["dataset_ops.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "dataset_ops_test",
size = "small",
srcs = ["dataset_ops_test.py"],
data = ["test_data/test_dataset.tfrecord"],
srcs_version = "PY2AND3",
deps = [
":dataset_ops",
"//astronet/util:configdict",
],
)
py_library(
name = "training",
srcs = ["training.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "testing",
srcs = ["testing.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "metrics",
srcs = ["metrics.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "metrics_test",
size = "small",
srcs = ["metrics_test.py"],
srcs_version = "PY2AND3",
deps = [":metrics"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to build an input pipeline that reads from TFRecord files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import six
import tensorflow as tf
def pad_tensor_to_batch_size(tensor, batch_size):
"""Pads a Tensor along the batch dimension to the desired batch size."""
if batch_size < 2:
raise ValueError("Cannot pad along batch dimension with batch_size < 2.")
ndims = len(tensor.shape)
if ndims < 1:
raise ValueError("Cannot pad a 0-dimensional Tensor")
num_pad_examples = batch_size - tf.shape(tensor)[0]
# paddings is a 2D Tensor with shape [ndims, 2]. Every element is zero except
# for paddings[0][1], which is the number of values to add along the 0-th
# dimension (the batch dimension) after the contents of the input tensor.
paddings = tf.sparse_to_dense(
sparse_indices=[[0, 1]],
output_shape=[ndims, 2],
sparse_values=num_pad_examples)
padded_tensor = tf.pad(tensor, paddings, name=tensor.op.name + "/pad")
# Set the new shape.
output_shape = tensor.shape.as_list()
output_shape[0] = batch_size
padded_tensor.set_shape(output_shape)
return padded_tensor
def _recursive_pad_to_batch_size(tensor_or_collection, batch_size):
"""Recursively pads to the batch size in a Tensor or collection of Tensors."""
if isinstance(tensor_or_collection, tf.Tensor):
return pad_tensor_to_batch_size(tensor_or_collection, batch_size)
if isinstance(tensor_or_collection, dict):
return {
name: _recursive_pad_to_batch_size(t, batch_size)
for name, t in tensor_or_collection.iteritems()
}
if isinstance(tensor_or_collection, collections.Iterable):
return [
_recursive_pad_to_batch_size(t, batch_size)
for t in tensor_or_collection
]
raise ValueError("Unknown input type: %s" % tensor_or_collection)
def pad_dataset_to_batch_size(dataset, batch_size):
"""Pads Tensors in a dataset along the batch dimension to batch_size.
The output contains a 'weights' Tensor, which is 1 for real elements and 0 for
padded elements. If a 'weights' Tensor already exists in the input dataset, that
Tensor is padded with zeros. If a 'weights' Tensor does not already exist,
then the input dataset is assumed to have a 'labels' Tensor which is used to
construct the weights.
Args:
dataset: A tf.data.Dataset.
batch_size: Integer batch size.
Returns:
A tf.data.Dataset.
"""
def map_fn(tensors):
"""Pads Tensors along the batch dimension to the desired batch size."""
if not isinstance(tensors, dict):
raise ValueError(
"pad_dataset_to_batch_size requires a dictionary of named Tensors.")
outputs = _recursive_pad_to_batch_size(tensors, batch_size)
if "weights" not in outputs:
weights = tf.ones_like(tensors["labels"], dtype=tf.float32)
outputs["weights"] = pad_tensor_to_batch_size(weights, batch_size)
return outputs
return dataset.map(map_fn)
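# For example (a sketch mirroring dataset_ops_test.py): a dataset of
# {"labels": [0, ..., 9]} batched by 4 has a final batch of 2 elements;
# padding to batch_size=4 yields labels [8, 9, 0, 0] and weights [1, 1, 0, 0].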
def _recursive_set_batch_size(tensor_or_collection, batch_size):
"""Recursively sets the batch size in a Tensor or collection of Tensors."""
if isinstance(tensor_or_collection, tf.Tensor):
t = tensor_or_collection
shape = t.shape.as_list()
shape[0] = batch_size
t.set_shape(t.shape.merge_with(shape))
elif isinstance(tensor_or_collection, dict):
for t in six.itervalues(tensor_or_collection):
_recursive_set_batch_size(t, batch_size)
elif isinstance(tensor_or_collection, collections.Iterable):
for t in tensor_or_collection:
_recursive_set_batch_size(t, batch_size)
else:
raise ValueError("Unknown input type: %s" % tensor_or_collection)
return tensor_or_collection
def set_batch_size(dataset, batch_size):
"""Sets the batch dimension in all Tensors to batch_size."""
return dataset.map(lambda t: _recursive_set_batch_size(t, batch_size))
def build_dataset(file_pattern,
input_config,
batch_size,
include_labels=True,
reverse_time_series_prob=0,
shuffle_filenames=False,
shuffle_values_buffer=0,
repeat=1,
use_tpu=False):
"""Builds an input pipeline that reads a dataset from sharded TFRecord files.
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
input_config: ConfigDict containing feature and label specifications.
batch_size: The number of examples per batch.
include_labels: Whether to read labels from the input files.
reverse_time_series_prob: If > 0, the time series features will be randomly
reversed with this probability. Within a given example, either all time
series features will be reversed, or none will be reversed.
shuffle_filenames: Whether to shuffle the order of TFRecord files between
epochs.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the dataset
will repeat indefinitely.
use_tpu: Whether to build the dataset for TPU.
Raises:
ValueError: If an input file pattern does not match any files, or if the
label IDs in input_config.label_map are not contiguous integers starting
at 0.
Returns:
A tf.data.Dataset object.
"""
file_patterns = file_pattern.split(",")
filenames = []
for p in file_patterns:
matches = tf.gfile.Glob(p)
if not matches:
raise ValueError("Found no input files matching %s" % p)
filenames.extend(matches)
tf.logging.info("Building input pipeline from %d files matching patterns: %s",
len(filenames), file_patterns)
if include_labels:
# Ensure that the label ids are contiguous integers starting at 0.
label_ids = set(input_config.label_map.values())
if label_ids != set(range(len(label_ids))):
raise ValueError(
"Label IDs must be contiguous integers starting at 0. Got: %s" %
label_ids)
# Create a HashTable mapping label strings to integer ids.
table_initializer = tf.contrib.lookup.KeyValueTensorInitializer(
keys=input_config.label_map.keys(),
values=input_config.label_map.values(),
key_dtype=tf.string,
value_dtype=tf.int32)
label_to_id = tf.contrib.lookup.HashTable(
table_initializer, default_value=-1)
def _example_parser(serialized_example):
"""Parses a single tf.Example into image and label tensors."""
# Set specifications for parsing the features.
data_fields = {
feature_name: tf.FixedLenFeature([feature.length], tf.float32)
for feature_name, feature in input_config.features.iteritems()
}
if include_labels:
data_fields[input_config.label_feature] = tf.FixedLenFeature([],
tf.string)
# Parse the features.
parsed_features = tf.parse_single_example(
serialized_example, features=data_fields)
if reverse_time_series_prob > 0:
# Randomly reverse time series features with probability
# reverse_time_series_prob.
should_reverse = tf.less(
tf.random_uniform([], 0, 1),
reverse_time_series_prob,
name="should_reverse")
# Reorganize outputs.
output = {}
for feature_name, value in parsed_features.iteritems():
if include_labels and feature_name == input_config.label_feature:
label_id = label_to_id.lookup(value)
# Ensure that the label_id is nonnegative to verify a successful hash
# map lookup.
assert_known_label = tf.Assert(
tf.greater_equal(label_id, tf.to_int32(0)),
["Unknown label string:", value])
with tf.control_dependencies([assert_known_label]):
label_id = tf.identity(label_id)
# We use the plural name "labels" in the output due to batching.
output["labels"] = label_id
elif input_config.features[feature_name].is_time_series:
# Possibly reverse.
if reverse_time_series_prob > 0:
# pylint:disable=cell-var-from-loop
value = tf.cond(should_reverse, lambda: tf.reverse(value, axis=[0]),
lambda: tf.identity(value))
# pylint:enable=cell-var-from-loop
if "time_series_features" not in output:
output["time_series_features"] = {}
output["time_series_features"][feature_name] = value
else:
if "aux_features" not in output:
output["aux_features"] = {}
output["aux_features"][feature_name] = value
return output
# Create a string dataset of filenames, and possibly shuffle.
filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
if len(filenames) > 1 and shuffle_filenames:
filename_dataset = filename_dataset.shuffle(len(filenames))
# Read serialized Example protos.
dataset = filename_dataset.flat_map(tf.data.TFRecordDataset)
# Possibly shuffle. Note that we shuffle before repeat(), so we only shuffle
# elements within each "epoch" of data, and not across epochs of data.
if shuffle_values_buffer > 0:
dataset = dataset.shuffle(shuffle_values_buffer)
# Repeat.
if repeat != 1:
dataset = dataset.repeat(repeat)
# Map the parser over the dataset.
dataset = dataset.map(_example_parser, num_parallel_calls=4)
# Batch results by up to batch_size.
dataset = dataset.batch(batch_size)
if repeat == -1 or repeat is None:
# The dataset repeats infinitely before batching, so each batch has the
# maximum number of elements.
dataset = set_batch_size(dataset, batch_size)
elif use_tpu:
# TPU requires all dimensions to be fixed. Since the dataset does not repeat
# infinitely before batching, the final batch may have fewer than batch_size
# elements. Therefore we pad to ensure that the final batch has batch_size
# elements.
dataset = pad_dataset_to_batch_size(dataset, batch_size)
# Prefetch a few batches.
dataset = dataset.prefetch(max(1, int(256 / batch_size)))
return dataset
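# Example usage (a sketch; the file pattern is hypothetical and input_config is
# a ConfigDict structured like the one in dataset_ops_test.py):
#   dataset = build_dataset(
#       file_pattern="/tmp/astronet/tfrecord/train-*",
#       input_config=input_config,
#       batch_size=64,
#       shuffle_values_buffer=1000,
#       repeat=None)
#   # When labels are included, an initializable iterator is required because
#   # of the stateful label-id hash table.
#   iterator = dataset.make_initializable_iterator()
#   inputs = iterator.get_next()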
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dataset_ops.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
from absl import flags
import numpy as np
import tensorflow as tf
from astronet.ops import dataset_ops
from astronet.util import configdict
FLAGS = flags.FLAGS
flags.DEFINE_string("test_srcdir", "", "Test source directory.")
_TEST_TFRECORD_FILE = "astronet/ops/test_data/test_dataset.tfrecord"
class DatasetOpsTest(tf.test.TestCase):
def testPadTensorToBatchSize(self):
with self.test_session():
# Cannot pad a 0-dimensional Tensor.
tensor_0d = tf.constant(1)
with self.assertRaises(ValueError):
dataset_ops.pad_tensor_to_batch_size(tensor_0d, 10)
# 1-dimensional Tensor. Un-padded batch size is 5.
tensor_1d = tf.range(5, dtype=tf.int32)
self.assertEqual([5], tensor_1d.shape)
self.assertAllEqual([0, 1, 2, 3, 4], tensor_1d.eval())
# Invalid to pad Tensor with batch size 5 to batch size 3.
tensor_1d_pad3 = dataset_ops.pad_tensor_to_batch_size(tensor_1d, 3)
with self.assertRaises(tf.errors.InvalidArgumentError):
tensor_1d_pad3.eval()
tensor_1d_pad5 = dataset_ops.pad_tensor_to_batch_size(tensor_1d, 5)
self.assertEqual([5], tensor_1d_pad5.shape)
self.assertAllEqual([0, 1, 2, 3, 4], tensor_1d_pad5.eval())
tensor_1d_pad8 = dataset_ops.pad_tensor_to_batch_size(tensor_1d, 8)
self.assertEqual([8], tensor_1d_pad8.shape)
self.assertAllEqual([0, 1, 2, 3, 4, 0, 0, 0], tensor_1d_pad8.eval())
# 2-dimensional Tensor. Un-padded batch size is 3.
tensor_2d = tf.reshape(tf.range(9, dtype=tf.int32), [3, 3])
self.assertEqual([3, 3], tensor_2d.shape)
self.assertAllEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], tensor_2d.eval())
tensor_2d_pad2 = dataset_ops.pad_tensor_to_batch_size(tensor_2d, 2)
# Invalid to pad Tensor with batch size 3 to batch size 2.
with self.assertRaises(tf.errors.InvalidArgumentError):
tensor_2d_pad2.eval()
tensor_2d_pad3 = dataset_ops.pad_tensor_to_batch_size(tensor_2d, 3)
self.assertEqual([3, 3], tensor_2d_pad3.shape)
self.assertAllEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]],
tensor_2d_pad3.eval())
tensor_2d_pad4 = dataset_ops.pad_tensor_to_batch_size(tensor_2d, 4)
self.assertEqual([4, 3], tensor_2d_pad4.shape)
self.assertAllEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8], [0, 0, 0]],
tensor_2d_pad4.eval())
def testPadDatasetToBatchSizeNoWeights(self):
values = {"labels": np.arange(10, dtype=np.int32)}
dataset = tf.data.Dataset.from_tensor_slices(values).batch(4)
self.assertItemsEqual(["labels"], dataset.output_shapes.keys())
self.assertFalse(dataset.output_shapes["labels"].is_fully_defined())
dataset_pad = dataset_ops.pad_dataset_to_batch_size(dataset, 4)
self.assertItemsEqual(["labels", "weights"],
dataset_pad.output_shapes.keys())
self.assertEqual([4], dataset_pad.output_shapes["labels"])
self.assertEqual([4], dataset_pad.output_shapes["weights"])
next_batch = dataset_pad.make_one_shot_iterator().get_next()
next_labels = next_batch["labels"]
next_weights = next_batch["weights"]
with self.test_session() as sess:
labels, weights = sess.run([next_labels, next_weights])
self.assertAllEqual([0, 1, 2, 3], labels)
self.assertAllClose([1, 1, 1, 1], weights)
labels, weights = sess.run([next_labels, next_weights])
self.assertAllEqual([4, 5, 6, 7], labels)
self.assertAllClose([1, 1, 1, 1], weights)
labels, weights = sess.run([next_labels, next_weights])
self.assertAllEqual([8, 9, 0, 0], labels)
self.assertAllClose([1, 1, 0, 0], weights)
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run([next_labels, next_weights])
def testPadDatasetToBatchSizeWithWeights(self):
values = {
"labels": np.arange(10, dtype=np.int32),
"weights": 100 + np.arange(10, dtype=np.int32)
}
dataset = tf.data.Dataset.from_tensor_slices(values).batch(4)
self.assertItemsEqual(["labels", "weights"], dataset.output_shapes.keys())
self.assertFalse(dataset.output_shapes["labels"].is_fully_defined())
self.assertFalse(dataset.output_shapes["weights"].is_fully_defined())
dataset_pad = dataset_ops.pad_dataset_to_batch_size(dataset, 4)
self.assertItemsEqual(["labels", "weights"],
dataset_pad.output_shapes.keys())
self.assertEqual([4], dataset_pad.output_shapes["labels"])
self.assertEqual([4], dataset_pad.output_shapes["weights"])
next_batch = dataset_pad.make_one_shot_iterator().get_next()
next_labels = next_batch["labels"]
next_weights = next_batch["weights"]
with self.test_session() as sess:
labels, weights = sess.run([next_labels, next_weights])
self.assertAllEqual([0, 1, 2, 3], labels)
self.assertAllEqual([100, 101, 102, 103], weights)
labels, weights = sess.run([next_labels, next_weights])
self.assertAllEqual([4, 5, 6, 7], labels)
self.assertAllEqual([104, 105, 106, 107], weights)
labels, weights = sess.run([next_labels, next_weights])
self.assertAllEqual([8, 9, 0, 0], labels)
self.assertAllEqual([108, 109, 0, 0], weights)
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run([next_labels, next_weights])
def testSetBatchSizeSingleTensor1d(self):
dataset = tf.data.Dataset.range(4).batch(2)
self.assertFalse(dataset.output_shapes.is_fully_defined())
dataset = dataset_ops.set_batch_size(dataset, 2)
self.assertEqual([2], dataset.output_shapes)
next_batch = dataset.make_one_shot_iterator().get_next()
with self.test_session() as sess:
batch_value = sess.run(next_batch)
self.assertAllEqual([0, 1], batch_value)
batch_value = sess.run(next_batch)
self.assertAllEqual([2, 3], batch_value)
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(next_batch)
def testSetBatchSizeSingleTensor2d(self):
values = np.arange(12, dtype=np.int32).reshape([4, 3])
dataset = tf.data.Dataset.from_tensor_slices(values).batch(2)
self.assertFalse(dataset.output_shapes.is_fully_defined())
dataset = dataset_ops.set_batch_size(dataset, 2)
self.assertEqual([2, 3], dataset.output_shapes)
next_batch = dataset.make_one_shot_iterator().get_next()
with self.test_session() as sess:
batch_value = sess.run(next_batch)
self.assertAllEqual([[0, 1, 2], [3, 4, 5]], batch_value)
batch_value = sess.run(next_batch)
self.assertAllEqual([[6, 7, 8], [9, 10, 11]], batch_value)
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(next_batch)
def testSetBatchSizeNested(self):
values = {
"a": 100 + np.arange(4, dtype=np.int32),
"nest": {
"b": np.arange(12, dtype=np.int32).reshape([4, 3]),
"c": np.arange(4, dtype=np.int32)
}
}
dataset = tf.data.Dataset.from_tensor_slices(values).batch(2)
self.assertItemsEqual(["a", "nest"], dataset.output_shapes.keys())
self.assertItemsEqual(["b", "c"], dataset.output_shapes["nest"].keys())
self.assertFalse(dataset.output_shapes["a"].is_fully_defined())
self.assertFalse(dataset.output_shapes["nest"]["b"].is_fully_defined())
self.assertFalse(dataset.output_shapes["nest"]["c"].is_fully_defined())
dataset = dataset_ops.set_batch_size(dataset, 2)
self.assertItemsEqual(["a", "nest"], dataset.output_shapes.keys())
self.assertItemsEqual(["b", "c"], dataset.output_shapes["nest"].keys())
self.assertEqual([2], dataset.output_shapes["a"])
self.assertEqual([2, 3], dataset.output_shapes["nest"]["b"])
self.assertEqual([2], dataset.output_shapes["nest"]["c"])
next_batch = dataset.make_one_shot_iterator().get_next()
next_a = next_batch["a"]
next_b = next_batch["nest"]["b"]
next_c = next_batch["nest"]["c"]
with self.test_session() as sess:
a, b, c = sess.run([next_a, next_b, next_c])
self.assertAllEqual([100, 101], a)
self.assertAllEqual([[0, 1, 2], [3, 4, 5]], b)
self.assertAllEqual([0, 1], c)
a, b, c = sess.run([next_a, next_b, next_c])
self.assertAllEqual([102, 103], a)
self.assertAllEqual([[6, 7, 8], [9, 10, 11]], b)
self.assertAllEqual([2, 3], c)
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(next_batch)
class BuildDatasetTest(tf.test.TestCase):
def setUp(self):
super(BuildDatasetTest, self).setUp()
# The test dataset contains 10 tensorflow.Example protocol buffers. The i-th
# Example contains the following features:
# global_view = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
# local_view = [0.0, 1.0, 2.0, 3.0]
# aux_feature = 100 + i
# label_str = "PC" if i % 3 == 0 else "AFP" if i % 3 == 1 else "NTP"
self._file_pattern = os.path.join(FLAGS.test_srcdir, _TEST_TFRECORD_FILE)
self._input_config = configdict.ConfigDict({
"features": {
"global_view": {
"is_time_series": True,
"length": 8
},
"local_view": {
"is_time_series": True,
"length": 4
},
"aux_feature": {
"is_time_series": False,
"length": 1
}
}
})
def testNonExistentFileRaisesValueError(self):
with self.assertRaises(ValueError):
dataset_ops.build_dataset(
file_pattern="nonexistent",
input_config=self._input_config,
batch_size=4)
def testBuildWithoutLabels(self):
dataset = dataset_ops.build_dataset(
file_pattern=self._file_pattern,
input_config=self._input_config,
batch_size=4,
include_labels=False)
# We can use a one-shot iterator without labels because we don't have the
# stateful hash map for label ids.
iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()
# Expect features only.
self.assertItemsEqual(["time_series_features", "aux_features"],
features.keys())
with self.test_session() as sess:
# Batch 1.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[100], [101], [102], [103]],
f["aux_features"]["aux_feature"])
# Batch 2.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[104], [105], [106], [107]],
f["aux_features"]["aux_feature"])
# Batch 3.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[0, 1, 2, 3],
[0, 1, 2, 3],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[108], [109]],
f["aux_features"]["aux_feature"])
# No more batches.
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(features)
def testLabels1(self):
self._input_config["label_feature"] = "label_str"
self._input_config["label_map"] = {"PC": 0, "AFP": 1, "NTP": 2}
dataset = dataset_ops.build_dataset(
file_pattern=self._file_pattern,
input_config=self._input_config,
batch_size=4)
# We need an initializable iterator when using labels because of the
# stateful label id hash table.
iterator = dataset.make_initializable_iterator()
inputs = iterator.get_next()
init_op = tf.tables_initializer()
# Expect features and labels.
self.assertItemsEqual(["time_series_features", "aux_features", "labels"],
inputs.keys())
labels = inputs["labels"]
with self.test_session() as sess:
sess.run([init_op, iterator.initializer])
# Fetch 3 batches.
np.testing.assert_array_equal([0, 1, 2, 0], sess.run(labels))
np.testing.assert_array_equal([1, 2, 0, 1], sess.run(labels))
np.testing.assert_array_equal([2, 0], sess.run(labels))
# No more batches.
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(labels)
def testLabels2(self):
self._input_config["label_feature"] = "label_str"
self._input_config["label_map"] = {"PC": 1, "AFP": 0, "NTP": 0}
dataset = dataset_ops.build_dataset(
file_pattern=self._file_pattern,
input_config=self._input_config,
batch_size=4)
# We need an initializable iterator when using labels because of the
# stateful label id hash table.
iterator = dataset.make_initializable_iterator()
inputs = iterator.get_next()
init_op = tf.tables_initializer()
# Expect features and labels.
self.assertItemsEqual(["time_series_features", "aux_features", "labels"],
inputs.keys())
labels = inputs["labels"]
with self.test_session() as sess:
sess.run([init_op, iterator.initializer])
# Fetch 3 batches.
np.testing.assert_array_equal([1, 0, 0, 1], sess.run(labels))
np.testing.assert_array_equal([0, 0, 1, 0], sess.run(labels))
np.testing.assert_array_equal([0, 1], sess.run(labels))
# No more batches.
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(labels)
def testBadLabelIdsRaisesValueError(self):
self._input_config["label_feature"] = "label_str"
# Label ids should be contiguous integers starting at 0.
self._input_config["label_map"] = {"PC": 1, "AFP": 2, "NTP": 3}
with self.assertRaises(ValueError):
dataset_ops.build_dataset(
file_pattern=self._file_pattern,
input_config=self._input_config,
batch_size=4)
def testUnknownLabel(self):
self._input_config["label_feature"] = "label_str"
# label_map does not include "NTP".
self._input_config["label_map"] = {"PC": 1, "AFP": 0}
dataset = dataset_ops.build_dataset(
file_pattern=self._file_pattern,
input_config=self._input_config,
batch_size=4)
# We need an initializable iterator when using labels because of the
# stateful label id hash table.
iterator = dataset.make_initializable_iterator()
inputs = iterator.get_next()
init_op = tf.tables_initializer()
# Expect features and labels.
self.assertItemsEqual(["time_series_features", "aux_features", "labels"],
inputs.keys())
labels = inputs["labels"]
with self.test_session() as sess:
sess.run([init_op, iterator.initializer])
# Unknown label "NTP".
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run(labels)
def testReverseTimeSeries(self):
dataset = dataset_ops.build_dataset(
file_pattern=self._file_pattern,
input_config=self._input_config,
batch_size=4,
reverse_time_series_prob=1,
include_labels=False)
# We can use a one-shot iterator without labels because we don't have the
# stateful hash map for label ids.
iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()
# Expect features only.
self.assertItemsEqual(["time_series_features", "aux_features"],
features.keys())
with self.test_session() as sess:
# Batch 1.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[7, 6, 5, 4, 3, 2, 1, 0],
[7, 6, 5, 4, 3, 2, 1, 0],
[7, 6, 5, 4, 3, 2, 1, 0],
[7, 6, 5, 4, 3, 2, 1, 0],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[3, 2, 1, 0],
[3, 2, 1, 0],
[3, 2, 1, 0],
[3, 2, 1, 0],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[100], [101], [102], [103]],
f["aux_features"]["aux_feature"])
# Batch 2.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[7, 6, 5, 4, 3, 2, 1, 0],
[7, 6, 5, 4, 3, 2, 1, 0],
[7, 6, 5, 4, 3, 2, 1, 0],
[7, 6, 5, 4, 3, 2, 1, 0],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[3, 2, 1, 0],
[3, 2, 1, 0],
[3, 2, 1, 0],
[3, 2, 1, 0],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[104], [105], [106], [107]],
f["aux_features"]["aux_feature"])
# Batch 3.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[7, 6, 5, 4, 3, 2, 1, 0],
[7, 6, 5, 4, 3, 2, 1, 0],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[3, 2, 1, 0],
[3, 2, 1, 0],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[108], [109]],
f["aux_features"]["aux_feature"])
# No more batches.
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(features)
def testRepeat(self):
dataset = dataset_ops.build_dataset(
file_pattern=self._file_pattern,
input_config=self._input_config,
batch_size=4,
include_labels=False)
# We can use a one-shot iterator without labels because we don't have the
# stateful hash map for label ids.
iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()
# Expect features only.
self.assertItemsEqual(["time_series_features", "aux_features"],
features.keys())
with self.test_session() as sess:
# Batch 1.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[100], [101], [102], [103]],
f["aux_features"]["aux_feature"])
# Batch 2.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[104], [105], [106], [107]],
f["aux_features"]["aux_feature"])
# Batch 3.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[0, 1, 2, 3],
[0, 1, 2, 3],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[108], [109]],
f["aux_features"]["aux_feature"])
# No more batches.
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(features)
def testTPU(self):
dataset = dataset_ops.build_dataset(
file_pattern=self._file_pattern,
input_config=self._input_config,
batch_size=4,
include_labels=False)
# We can use a one-shot iterator without labels because we don't have the
# stateful hash map for label ids.
iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()
# Expect features only.
self.assertItemsEqual(["time_series_features", "aux_features"],
features.keys())
with self.test_session() as sess:
# Batch 1.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[100], [101], [102], [103]],
f["aux_features"]["aux_feature"])
# Batch 2.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
[0, 1, 2, 3],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[104], [105], [106], [107]],
f["aux_features"]["aux_feature"])
# Batch 3.
f = sess.run(features)
np.testing.assert_array_almost_equal([
[0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7],
], f["time_series_features"]["global_view"])
np.testing.assert_array_almost_equal([
[0, 1, 2, 3],
[0, 1, 2, 3],
], f["time_series_features"]["local_view"])
np.testing.assert_array_almost_equal([[108], [109]],
f["aux_features"]["aux_feature"])
# No more batches.
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(features)
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Operations for feeding input data using TensorFlow placeholders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def prepare_feed_dict(model, features, labels=None, is_training=None):
"""Prepares a feed_dict for sess.run() given a batch of features and labels.
Args:
model: An instance of AstroModel.
features: Dictionary containing "time_series_features" and "aux_features".
Each is a dictionary of named numpy arrays of shape
[batch_size, length].
labels: (Optional). Numpy array of shape [batch_size].
is_training: (Optional). Python boolean to feed to the model.is_training
Tensor (if None, no value is fed).
Returns:
feed_dict: A dictionary mapping input Tensors to numpy arrays.
"""
feed_dict = {}
for feature, tensor in model.time_series_features.iteritems():
feed_dict[tensor] = features["time_series_features"][feature]
for feature, tensor in model.aux_features.iteritems():
feed_dict[tensor] = features["aux_features"][feature]
if labels is not None:
feed_dict[model.labels] = labels
if is_training is not None:
feed_dict[model.is_training] = is_training
return feed_dict
def build_feature_placeholders(config):
"""Builds tf.Placeholder ops for feeding model features and labels.
Args:
config: ConfigDict containing the feature configurations.
Returns:
features: A dictionary containing "time_series_features" and "aux_features",
each of which is a dictionary of tf.Placeholders of features from the
input configuration. All features have dtype float32 and shape
[batch_size, length].
"""
batch_size = None # Batch size will be dynamically specified.
features = {"time_series_features": {}, "aux_features": {}}
for feature_name, feature_spec in config.iteritems():
placeholder = tf.placeholder(
dtype=tf.float32,
shape=[batch_size, feature_spec.length],
name=feature_name)
if feature_spec.is_time_series:
features["time_series_features"][feature_name] = placeholder
else:
features["aux_features"][feature_name] = placeholder
return features
def build_labels_placeholder():
"""Builds a tf.Placeholder op for feeding model labels.
Returns:
labels: An int64 tf.Placeholder with shape [batch_size].
"""
batch_size = None # Batch size will be dynamically specified.
return tf.placeholder(dtype=tf.int64, shape=[batch_size], name="labels")
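# The following sketch is illustrative only and not part of the library: it
# wires build_feature_placeholders(), build_labels_placeholder() and
# prepare_feed_dict() together for a hypothetical two-feature configuration.
# The feature names, lengths and batch size are assumptions for the example.
def _example_feed_dict():
  import collections
  import numpy as np
  from astronet.util import configdict

  config = configdict.ConfigDict({
      "time_feature_1": {"length": 14, "is_time_series": True},
      "aux_feature_1": {"length": 1, "is_time_series": False},
  })
  placeholders = build_feature_placeholders(config)
  labels = build_labels_placeholder()

  # Lightweight stand-in exposing only the attributes prepare_feed_dict() reads.
  MockModel = collections.namedtuple(
      "MockModel", ["time_series_features", "aux_features", "labels"])
  model = MockModel(placeholders["time_series_features"],
                    placeholders["aux_features"], labels)

  batch_size = 4
  features = {
      "time_series_features": {
          "time_feature_1": np.random.random([batch_size, 14]),
      },
      "aux_features": {
          "aux_feature_1": np.random.random([batch_size, 1]),
      },
  }
  feed_dict = prepare_feed_dict(
      model, features, labels=np.zeros([batch_size], dtype=np.int64))
  with tf.Session() as sess:
    # Echo one placeholder back through the graph to check its shape: (4, 14).
    print(sess.run(placeholders["time_series_features"]["time_feature_1"],
                   feed_dict=feed_dict).shape)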
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for input_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from astronet.ops import input_ops
from astronet.util import configdict
class InputOpsTest(tf.test.TestCase):
def assertFeatureShapesEqual(self, expected_shapes, features):
"""Asserts that a dict of feature placeholders has the expected shapes.
Args:
expected_shapes: Dictionary of expected Tensor shapes, as lists,
corresponding to the structure of 'features'.
features: Dictionary of feature placeholders of the format returned by
input_ops.build_feature_placeholders().
"""
actual_shapes = {}
for feature_type in features:
actual_shapes[feature_type] = {
feature: tensor.shape.as_list()
for feature, tensor in features[feature_type].iteritems()
}
self.assertDictEqual(expected_shapes, actual_shapes)
def testBuildFeaturePlaceholders(self):
# One time series feature.
config = configdict.ConfigDict({
"time_feature_1": {
"length": 14,
"is_time_series": True,
}
})
expected_shapes = {
"time_series_features": {
"time_feature_1": [None, 14],
},
"aux_features": {}
}
features = input_ops.build_feature_placeholders(config)
self.assertFeatureShapesEqual(expected_shapes, features)
# Two time series features.
config = configdict.ConfigDict({
"time_feature_1": {
"length": 14,
"is_time_series": True,
},
"time_feature_2": {
"length": 5,
"is_time_series": True,
}
})
expected_shapes = {
"time_series_features": {
"time_feature_1": [None, 14],
"time_feature_2": [None, 5],
},
"aux_features": {}
}
features = input_ops.build_feature_placeholders(config)
self.assertFeatureShapesEqual(expected_shapes, features)
# One aux feature.
config = configdict.ConfigDict({
"time_feature_1": {
"length": 14,
"is_time_series": True,
},
"aux_feature_1": {
"length": 1,
"is_time_series": False,
}
})
expected_shapes = {
"time_series_features": {
"time_feature_1": [None, 14],
},
"aux_features": {
"aux_feature_1": [None, 1]
}
}
features = input_ops.build_feature_placeholders(config)
self.assertFeatureShapesEqual(expected_shapes, features)
# Two aux features.
config = configdict.ConfigDict({
"time_feature_1": {
"length": 14,
"is_time_series": True,
},
"aux_feature_1": {
"length": 1,
"is_time_series": False,
},
"aux_feature_2": {
"length": 6,
"is_time_series": False,
},
})
expected_shapes = {
"time_series_features": {
"time_feature_1": [None, 14],
},
"aux_features": {
"aux_feature_1": [None, 1],
"aux_feature_2": [None, 6]
}
}
features = input_ops.build_feature_placeholders(config)
self.assertFeatureShapesEqual(expected_shapes, features)
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for computing evaluation metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def _metric_variable(name, shape, dtype):
"""Creates a Variable in LOCAL_VARIABLES and METRIC_VARIABLES collections."""
return tf.get_variable(
name,
initializer=tf.zeros(shape, dtype),
trainable=False,
collections=[tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES])
def _build_metrics(labels, predictions, weights, batch_losses):
"""Builds TensorFlow operations to compute model evaluation metrics.
Args:
labels: Tensor with shape [batch_size].
predictions: Tensor with shape [batch_size, output_dim].
weights: Tensor with shape [batch_size].
batch_losses: Tensor with shape [batch_size].
Returns:
A dictionary {metric_name: (metric_value, update_op)}.
"""
# Compute the predicted labels.
assert len(predictions.shape) == 2
binary_classification = (predictions.shape[1] == 1)
if binary_classification:
predictions = tf.squeeze(predictions, axis=[1])
predicted_labels = tf.to_int32(
tf.greater(predictions, 0.5), name="predicted_labels")
else:
predicted_labels = tf.argmax(
predictions, 1, name="predicted_labels", output_type=tf.int32)
metrics = {}
with tf.variable_scope("metrics"):
# Total number of examples.
num_examples = _metric_variable("num_examples", [], tf.float32)
update_num_examples = tf.assign_add(num_examples, tf.reduce_sum(weights))
metrics["num_examples"] = (num_examples.read_value(), update_num_examples)
# Accuracy metrics.
num_correct = _metric_variable("num_correct", [], tf.float32)
is_correct = weights * tf.to_float(tf.equal(labels, predicted_labels))
update_num_correct = tf.assign_add(num_correct, tf.reduce_sum(is_correct))
metrics["accuracy/num_correct"] = (num_correct.read_value(),
update_num_correct)
accuracy = tf.div(num_correct, num_examples, name="accuracy")
metrics["accuracy/accuracy"] = (accuracy, tf.no_op())
# Weighted cross-entropy loss.
metrics["losses/weighted_cross_entropy"] = tf.metrics.mean(
batch_losses, weights=weights, name="cross_entropy_loss")
# Possibly create additional metrics for binary classification.
if binary_classification:
labels = tf.cast(labels, dtype=tf.bool)
predicted_labels = tf.cast(predicted_labels, dtype=tf.bool)
# AUC.
metrics["auc"] = tf.metrics.auc(
labels, predictions, weights=weights, num_thresholds=1000)
def _count_condition(name, labels_value, predicted_value):
"""Creates a counter for given values of predictions and labels."""
count = _metric_variable(name, [], tf.float32)
is_equal = tf.to_float(
tf.logical_and(
tf.equal(labels, labels_value),
tf.equal(predicted_labels, predicted_value)))
update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
return count.read_value(), update_op
# Confusion matrix metrics.
metrics["confusion_matrix/true_positives"] = _count_condition(
"true_positives", labels_value=True, predicted_value=True)
metrics["confusion_matrix/false_positives"] = _count_condition(
"false_positives", labels_value=False, predicted_value=True)
metrics["confusion_matrix/true_negatives"] = _count_condition(
"true_negatives", labels_value=False, predicted_value=False)
metrics["confusion_matrix/false_negatives"] = _count_condition(
"false_negatives", labels_value=True, predicted_value=False)
return metrics
def create_metric_fn(model):
"""Creates a tuple (metric_fn, metric_fn_inputs).
This function is primarily used for creating a TPUEstimator.
The result of calling metric_fn(**metric_fn_inputs) is a dictionary
{metric_name: (metric_value, update_op)}.
Args:
model: Instance of AstroModel.
Returns:
A tuple (metric_fn, metric_fn_inputs).
"""
weights = model.weights
if weights is None:
weights = tf.ones_like(model.labels, dtype=tf.float32)
metric_fn_inputs = {
"labels": model.labels,
"predictions": model.predictions,
"weights": weights,
"batch_losses": model.batch_losses,
}
def metric_fn(labels, predictions, weights, batch_losses):
return _build_metrics(labels, predictions, weights, batch_losses)
return metric_fn, metric_fn_inputs
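# Illustrative sketch (not part of the library) of how the returned pair is
# typically consumed: a TPUEstimatorSpec accepts (metric_fn, inputs) directly
# as its eval_metrics argument. The `model` object and its total_loss are
# assumed to come from the surrounding model_fn.
def _example_tpu_eval_spec(model):
  metric_fn, metric_fn_inputs = create_metric_fn(model)
  return tf.contrib.tpu.TPUEstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      loss=model.total_loss,
      eval_metrics=(metric_fn, metric_fn_inputs))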
def create_metrics(model):
"""Creates a dictionary {metric_name: (metric_value, update_op)}.
This function is primarily used for creating an Estimator.
Args:
model: Instance of AstroModel.
Returns:
A dictionary {metric_name: (metric_value, update_op)}.
"""
metric_fn, metric_fn_inputs = create_metric_fn(model)
return metric_fn(**metric_fn_inputs)
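# Corresponding sketch for a standard (non-TPU) Estimator: create_metrics()
# returns exactly the {name: (value_op, update_op)} dictionary expected by the
# eval_metric_ops argument of tf.estimator.EstimatorSpec. The `model` object is
# assumed to come from the surrounding model_fn.
def _example_eval_spec(model):
  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      loss=model.total_loss,
      eval_metric_ops=create_metrics(model))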
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for metrics.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from astronet.ops import metrics
def _unpack_metric_map(names_to_tuples):
"""Unpacks {metric_name: (metric_value, update_op)} into separate dicts."""
metric_names = names_to_tuples.keys()
value_ops, update_ops = zip(*names_to_tuples.values())
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
class _MockModel(object):
"""Mock model for testing."""
def __init__(self, labels, predictions, weights, batch_losses):
self.labels = tf.constant(labels, dtype=tf.int32)
self.predictions = tf.constant(predictions, dtype=tf.float32)
self.weights = None if weights is None else tf.constant(
weights, dtype=tf.float32)
self.batch_losses = tf.constant(batch_losses, dtype=tf.float32)
class MetricsTest(tf.test.TestCase):
def testMultiClassificationWithoutWeights(self):
labels = [0, 1, 2, 3]
predictions = [
[0.7, 0.2, 0.1, 0.0], # Predicted label = 0
[0.2, 0.4, 0.2, 0.2], # Predicted label = 1
[0.0, 0.0, 0.0, 1.0], # Predicted label = 3
[0.1, 0.1, 0.7, 0.1], # Predicted label = 2
]
weights = None
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
with self.test_session() as sess:
sess.run(initializer)
sess.run(update_ops)
self.assertAllClose({
"num_examples": 4,
"accuracy/num_correct": 2,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
}, sess.run(value_ops))
sess.run(update_ops)
self.assertAllClose({
"num_examples": 8,
"accuracy/num_correct": 4,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
}, sess.run(value_ops))
def testMultiClassificationWithWeights(self):
labels = [0, 1, 2, 3]
predictions = [
[0.7, 0.2, 0.1, 0.0], # Predicted label = 0
[0.2, 0.4, 0.2, 0.2], # Predicted label = 1
[0.0, 0.0, 0.0, 1.0], # Predicted label = 3
[0.1, 0.1, 0.7, 0.1], # Predicted label = 2
]
weights = [0, 1, 0, 1]
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
with self.test_session() as sess:
sess.run(initializer)
sess.run(update_ops)
self.assertAllClose({
"num_examples": 2,
"accuracy/num_correct": 1,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
}, sess.run(value_ops))
sess.run(update_ops)
self.assertAllClose({
"num_examples": 4,
"accuracy/num_correct": 2,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
}, sess.run(value_ops))
def testBinaryClassificationWithoutWeights(self):
labels = [0, 1, 1, 0]
predictions = [
[0.4], # Predicted label = 0
[0.6], # Predicted label = 1
[0.0], # Predicted label = 0
[1.0], # Predicted label = 1
]
weights = None
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
with self.test_session() as sess:
sess.run(initializer)
sess.run(update_ops)
self.assertAllClose({
"num_examples": 4,
"accuracy/num_correct": 2,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"auc": 0.25,
"confusion_matrix/true_positives": 1,
"confusion_matrix/true_negatives": 1,
"confusion_matrix/false_positives": 1,
"confusion_matrix/false_negatives": 1,
}, sess.run(value_ops))
sess.run(update_ops)
self.assertAllClose({
"num_examples": 8,
"accuracy/num_correct": 4,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"auc": 0.25,
"confusion_matrix/true_positives": 2,
"confusion_matrix/true_negatives": 2,
"confusion_matrix/false_positives": 2,
"confusion_matrix/false_negatives": 2,
}, sess.run(value_ops))
def testBinaryClassificationWithWeights(self):
labels = [0, 1, 1, 0]
predictions = [
[0.4], # Predicted label = 0
[0.6], # Predicted label = 1
[0.0], # Predicted label = 0
[1.0], # Predicted label = 1
]
weights = [0, 1, 0, 1]
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
with self.test_session() as sess:
sess.run(initializer)
sess.run(update_ops)
self.assertAllClose({
"num_examples": 2,
"accuracy/num_correct": 1,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"auc": 0,
"confusion_matrix/true_positives": 1,
"confusion_matrix/true_negatives": 0,
"confusion_matrix/false_positives": 1,
"confusion_matrix/false_negatives": 0,
}, sess.run(value_ops))
sess.run(update_ops)
self.assertAllClose({
"num_examples": 4,
"accuracy/num_correct": 2,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"auc": 0,
"confusion_matrix/true_positives": 2,
"confusion_matrix/true_negatives": 0,
"confusion_matrix/false_positives": 2,
"confusion_matrix/false_negatives": 0,
}, sess.run(value_ops))
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow utilities for unit tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def get_variable_by_name(name, scope=""):
"""Gets a tf.Variable by name.
Args:
name: Name of the Variable within the specified scope.
scope: Variable scope; use the empty string for top-level scope.
Returns:
The matching tf.Variable object.
"""
with tf.variable_scope(scope, reuse=True):
return tf.get_variable(name)
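# Small illustrative sketch (scope and variable names are assumptions): fetch a
# variable that was created under a variable scope elsewhere in the graph.
def _example_get_variable_by_name():
  with tf.variable_scope("dense"):
    tf.get_variable("kernel", shape=[3, 2], dtype=tf.float32)
  fetched = get_variable_by_name("kernel", scope="dense")
  assert fetched.name == "dense/kernel:0"  # Resolves to the variable above.
  return fetched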
def fake_features(feature_spec, batch_size):
"""Creates random numpy arrays representing input features for unit testing.
Args:
feature_spec: Dictionary containing the feature specifications.
batch_size: Integer batch size.
Returns:
Dictionary containing "time_series_features" and "aux_features". Each is a
dictionary of named numpy arrays of shape [batch_size, length].
"""
features = {}
features["time_series_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.iteritems() if spec["is_time_series"]
}
features["aux_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.iteritems() if not spec["is_time_series"]
}
return features
def fake_labels(output_dim, batch_size):
"""Creates a radom numpy array representing labels for unit testing.
Args:
output_dim: Number of output units in the classification model.
batch_size: Integer batch size.
Returns:
Numpy array of shape [batch_size].
"""
# Binary classification is denoted by output_dim == 1. In that case there are
# 2 label classes even though there is only 1 output prediction by the model.
# Otherwise, the classification task is multi-class with output_dim classes.
num_labels = 2 if output_dim == 1 else output_dim
return np.random.randint(num_labels, size=batch_size)
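# Illustrative sketch (feature names and lengths are assumptions): generate a
# random batch matching a simple feature spec, e.g. for driving
# input_ops.prepare_feed_dict() in a unit test.
def _example_fake_batch():
  feature_spec = {
      "time_feature_1": {"length": 14, "is_time_series": True},
      "aux_feature_1": {"length": 1, "is_time_series": False},
  }
  features = fake_features(feature_spec, batch_size=8)
  labels = fake_labels(output_dim=1, batch_size=8)
  print(features["time_series_features"]["time_feature_1"].shape)  # (8, 14)
  print(features["aux_features"]["aux_feature_1"].shape)  # (8, 1)
  print(labels.shape)  # (8,)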
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for training an AstroNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def create_learning_rate(hparams, global_step):
"""Creates a learning rate Tensor.
Args:
hparams: ConfigDict containing the learning rate configuration.
global_step: The global step Tensor.
Returns:
A learning rate Tensor.
"""
if hparams.get("learning_rate_decay_factor"):
learning_rate = tf.train.exponential_decay(
learning_rate=float(hparams.learning_rate),
global_step=global_step,
decay_steps=hparams.learning_rate_decay_steps,
decay_rate=hparams.learning_rate_decay_factor,
staircase=hparams.learning_rate_decay_staircase)
else:
learning_rate = tf.constant(hparams.learning_rate)
return learning_rate
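# Illustrative hyperparameter values (assumptions, not shipped defaults) that
# exercise the exponential-decay branch above; omitting
# "learning_rate_decay_factor" falls back to a constant learning rate. These
# would be wrapped in a ConfigDict before being passed to create_learning_rate.
_EXAMPLE_DECAY_HPARAMS = {
    "learning_rate": 1e-3,
    "learning_rate_decay_factor": 0.98,
    "learning_rate_decay_steps": 1000,
    "learning_rate_decay_staircase": True,
}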
def create_optimizer(hparams, learning_rate, use_tpu=False):
"""Creates a TensorFlow Optimizer.
Args:
hparams: ConfigDict containing the optimizer configuration.
learning_rate: A Python float or a scalar Tensor.
use_tpu: If True, the returned optimizer is wrapped in a
CrossShardOptimizer.
Returns:
A TensorFlow optimizer.
Raises:
ValueError: If hparams.optimizer is unrecognized.
"""
optimizer_name = hparams.optimizer.lower()
if optimizer_name == "momentum":
optimizer = tf.train.MomentumOptimizer(
learning_rate,
momentum=hparams.get("momentum", 0.9),
use_nesterov=hparams.get("use_nesterov", False))
elif optimizer_name == "sgd":
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
elif optimizer_name == "adagrad":
optimizer = tf.train.AdagradOptimizer(learning_rate)
elif optimizer_name == "adam":
optimizer = tf.train.AdamOptimizer(learning_rate)
elif optimizer_name == "rmsprop":
optimizer = tf.train.RMSPropOptimizer(learning_rate)
else:
raise ValueError("Unknown optimizer: %s" % hparams.optimizer)
if use_tpu:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
return optimizer
def create_train_op(model, optimizer):
"""Creates a Tensor to train the model.
Args:
model: Instance of AstroModel.
optimizer: Instance of tf.train.Optimizer.
Returns:
A Tensor that runs a single training step and returns model.total_loss.
"""
# Maybe clip gradient norms.
transform_grads_fn = None
if model.hparams.get("clip_gradient_norm"):
transform_grads_fn = tf.contrib.training.clip_gradient_norms_fn(
model.hparams.clip_gradient_norm)
# Create train op.
return tf.contrib.training.create_train_op(
total_loss=model.total_loss,
optimizer=optimizer,
global_step=model.global_step,
transform_grads_fn=transform_grads_fn)
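# Minimal wiring sketch (illustrative only): build the learning rate, optimizer
# and train op for a constructed model. The hparams values are assumptions;
# `model` is assumed to be an AstroModel built in TRAIN mode.
def _example_build_train_op(model):
  from astronet.util import configdict
  hparams = configdict.ConfigDict({
      "learning_rate": 1e-5,
      "optimizer": "adam",
  })
  learning_rate = create_learning_rate(hparams, model.global_step)
  optimizer = create_optimizer(hparams, learning_rate)
  return create_train_op(model, optimizer)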
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate predictions for a Threshold Crossing Event using a trained model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from astronet import models
from astronet.data import preprocess
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_util
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, required=True, help="Name of the model class.")
parser.add_argument(
"--config_name",
type=str,
help="Name of the model and training configuration. Exactly one of "
"--config_name or --config_json is required.")
parser.add_argument(
"--config_json",
type=str,
help="JSON string or JSON file containing the model and training "
"configuration. Exactly one of --config_name or --config_json is required.")
parser.add_argument(
"--model_dir",
type=str,
required=True,
help="Directory containing a model checkpoint.")
parser.add_argument(
"--kepler_data_dir",
type=str,
required=True,
help="Base folder containing Kepler data.")
parser.add_argument(
"--kepler_id",
type=int,
required=True,
help="Kepler ID of the target star.")
parser.add_argument(
"--period", type=float, required=True, help="Period of the TCE, in days.")
parser.add_argument("--t0", type=float, required=True, help="Epoch of the TCE.")
parser.add_argument(
"--duration",
type=float,
required=True,
help="Duration of the TCE, in days.")
parser.add_argument(
"--output_image_file",
type=str,
help="If specified, path to an output image file containing feature plots. "
"Must end in a valid image extension, e.g. png.")
def _process_tce(feature_config):
"""Reads and process the input features of a Threshold Crossing Event.
Args:
feature_config: ConfigDict containing the feature configurations.
Returns:
A dictionary of processed light curve features.
Raises:
ValueError: If feature_config contains features other than 'global_view'
and 'local_view'.
"""
if not {"global_view", "local_view"}.issuperset(feature_config.keys()):
raise ValueError(
"Only 'global_view' and 'local_view' features are supported.")
# Read and process the light curve.
time, flux = preprocess.read_and_process_light_curve(FLAGS.kepler_id,
FLAGS.kepler_data_dir)
time, flux = preprocess.phase_fold_and_sort_light_curve(
time, flux, FLAGS.period, FLAGS.t0)
# Generate the local and global views.
features = {}
if "global_view" in feature_config:
global_view = preprocess.global_view(time, flux, FLAGS.period)
# Add a batch dimension.
features["global_view"] = np.expand_dims(global_view, 0)
if "local_view" in feature_config:
local_view = preprocess.local_view(time, flux, FLAGS.period, FLAGS.duration)
# Add a batch dimension.
features["local_view"] = np.expand_dims(local_view, 0)
# Possibly save plots.
if FLAGS.output_image_file:
ncols = len(features)
fig, axes = plt.subplots(1, ncols, figsize=(10 * ncols, 5), squeeze=False)
for i, name in enumerate(sorted(features)):
ax = axes[0][i]
ax.plot(features[name][0], ".")
ax.set_title(name)
ax.set_xlabel("Bucketized Time (days)")
ax.set_ylabel("Normalized Flux")
fig.tight_layout()
fig.savefig(FLAGS.output_image_file, bbox_inches="tight")
return features
def main(_):
model_class = models.get_model_class(FLAGS.model)
# Look up the model configuration.
assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
"Exactly one of --config_name or --config_json is required.")
config = (
models.get_model_config(FLAGS.model, FLAGS.config_name)
if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
config = configdict.ConfigDict(config)
# Create the estimator.
estimator = estimator_util.create_estimator(
model_class, config.hparams, model_dir=FLAGS.model_dir)
# Read and process the input features.
features = _process_tce(config.inputs.features)
# Create an input function.
def input_fn():
return {
"time_series_features":
tf.estimator.inputs.numpy_input_fn(
features, batch_size=1, shuffle=False, queue_capacity=1)()
}
# Generate the predictions.
for predictions in estimator.predict(input_fn):
assert len(predictions) == 1
print("Prediction:", predictions[0])
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training an AstroNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from astronet import models
from astronet.util import config_util
from astronet.util import configdict
from astronet.util import estimator_util
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, required=True, help="Name of the model class.")
parser.add_argument(
"--config_name",
type=str,
help="Name of the model and training configuration. Exactly one of "
"--config_name or --config_json is required.")
parser.add_argument(
"--config_json",
type=str,
help="JSON string or JSON file containing the model and training "
"configuration. Exactly one of --config_name or --config_json is required.")
parser.add_argument(
"--train_files",
type=str,
required=True,
help="Comma-separated list of file patterns matching the TFRecord files in "
"the training dataset.")
parser.add_argument(
"--eval_files",
type=str,
help="Comma-separated list of file patterns matching the TFRecord files in "
"the validation dataset.")
parser.add_argument(
"--model_dir",
type=str,
required=True,
help="Directory for model checkpoints and summaries.")
parser.add_argument(
"--train_steps",
type=int,
default=10000,
help="Total number of steps to train the model for.")
parser.add_argument(
"--shuffle_buffer_size",
type=int,
default=15000,
help="Size of the shuffle buffer for the training dataset.")
def main(_):
model_class = models.get_model_class(FLAGS.model)
# Look up the model configuration.
assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
"Exactly one of --config_name or --config_json is required.")
config = (
models.get_model_config(FLAGS.model, FLAGS.config_name)
if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
config = configdict.ConfigDict(config)
config_util.log_and_save_config(config, FLAGS.model_dir)
# Create the estimator.
run_config = tf.estimator.RunConfig(keep_checkpoint_max=1)
estimator = estimator_util.create_estimator(model_class, config.hparams,
run_config, FLAGS.model_dir)
# Create an input function that reads the training dataset. If we are
# alternating training with evaluation, each call iterates through the dataset
# once (repeat=1); otherwise the dataset repeats indefinitely.
train_input_fn = estimator_util.create_input_fn(
file_pattern=FLAGS.train_files,
input_config=config.inputs,
mode=tf.estimator.ModeKeys.TRAIN,
shuffle_values_buffer=FLAGS.shuffle_buffer_size,
repeat=1 if FLAGS.eval_files else None)
if not FLAGS.eval_files:
estimator.train(train_input_fn, max_steps=FLAGS.train_steps)
else:
eval_input_fn = estimator_util.create_input_fn(
file_pattern=FLAGS.eval_files,
input_config=config.inputs,
mode=tf.estimator.ModeKeys.EVAL)
for _ in estimator_util.continuous_train_and_eval(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=FLAGS.train_steps):
# continuous_train_and_eval() yields evaluation metrics after each
# training epoch. We don't do anything here.
pass
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)