Unverified commit 2c181308, authored by Chris Shallue and committed by GitHub.

Merge pull request #5862 from cshallue/master

Move tensorflow_models/research/astronet to google-research/exoplanet-ml
Parents: caafb6d1, 62704f06
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurations for model building, training and evaluation.
The default base configuration has one "global_view" time series feature per
input example. Additional time series features and auxiliary features can be
added.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def base():
"""Returns the base config for model building, training and evaluation."""
return {
# Configuration for reading input features and labels.
"inputs": {
# Feature specifications.
"features": {
"global_view": {
"length": 2001,
"is_time_series": True,
},
},
# Name of the feature containing training labels.
"label_feature": "av_training_set",
# Label string to integer id.
"label_map": {
"PC": 1, # Planet Candidate.
"AFP": 0, # Astrophysical False Positive.
"NTP": 0, # Non-Transiting Phenomenon.
"SCR1": 0, # TCE from scrambled light curve with SCR1 order.
"INV": 0, # TCE from inverted light curve.
"INJ1": 1, # Injected Planet.
},
},
# Hyperparameters for building and training the model.
"hparams": {
# Number of output dimensions (predictions) for the classification
# task. If >= 2 then a softmax output layer is used. If equal to 1
# then a sigmoid output layer is used.
"output_dim": 1,
# Fully connected layers before the logits layer.
"num_pre_logits_hidden_layers": 0,
"pre_logits_hidden_layer_size": 0,
"pre_logits_dropout_rate": 0.0,
# Number of examples per training batch.
"batch_size": 256,
# Learning rate parameters.
"learning_rate": 2e-4,
"learning_rate_decay_steps": 0,
"learning_rate_decay_factor": 0,
"learning_rate_decay_staircase": True,
# Optimizer for training the model.
"optimizer": "adam",
# If not None, gradient norms will be clipped to this value.
"clip_gradient_norm": None,
}
}
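# As the module docstring notes, additional time series features can be added
# on top of base(). The function below is an illustrative sketch (not part of
# the original configuration module; the name is hypothetical) showing how a
# 201-point "local_view" feature could be registered alongside "global_view".
# The length 201 matches the local view written by the preprocessing script.
def local_global_sketch():
  """Returns base() extended with a "local_view" time series feature."""
  config = base()
  config["inputs"]["features"]["local_view"] = {
      "length": 201,  # Length of the local view vector (see preprocess docs).
      "is_time_series": True,
  }
  return config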
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_binary(
name = "generate_input_records",
srcs = ["generate_input_records.py"],
deps = [":preprocess"],
)
py_library(
name = "preprocess",
srcs = ["preprocess.py"],
deps = [
"//light_curve:kepler_io",
"//light_curve:median_filter",
"//light_curve:util",
"//tf_util:example_util",
"//third_party/kepler_spline",
],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Generates a bash script for downloading light curves.
The input to this script is a CSV file of Kepler targets, for example the DR24
TCE table, which can be downloaded in CSV format from the NASA Exoplanet Archive
at:
https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/nph-tblView?app=ExoTbls&config=q1_q17_dr24_tce
Example usage:
python generate_download_script.py \
--kepler_csv_file=dr24_tce.csv \
--download_dir=${HOME}/astronet/kepler
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import csv
import os
import stat
import sys
parser = argparse.ArgumentParser()
parser.add_argument(
"--kepler_csv_file",
type=str,
required=True,
help="CSV file containing Kepler targets to download. Must contain a "
"'kepid' column.")
parser.add_argument(
"--download_dir",
type=str,
required=True,
help="Directory into which the Kepler data will be downloaded.")
parser.add_argument(
"--output_file",
type=str,
default="get_kepler.sh",
help="Filename of the output download script.")
_WGET_CMD = ("wget -q -nH --cut-dirs=6 -r -l0 -c -N -np -erobots=off "
"-R 'index*' -A _llc.fits")
_BASE_URL = "http://archive.stsci.edu/pub/kepler/lightcurves"
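# For reference, main() below combines _WGET_CMD, _BASE_URL, the zero-padded
# Kepler ID and its 4-digit prefix into one command per target. With
# --download_dir=${HOME}/astronet/kepler (as in the docstring example), the
# generated line for Kepler ID 11442793 would be:
#   wget -q -nH --cut-dirs=6 -r -l0 -c -N -np -erobots=off -R 'index*' \
#     -A _llc.fits -P ${HOME}/astronet/kepler/0114/011442793 \
#     http://archive.stsci.edu/pub/kepler/lightcurves/0114/011442793/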
def main(argv):
del argv # Unused.
# Read Kepler targets.
kepids = set()
with open(FLAGS.kepler_csv_file) as f:
reader = csv.DictReader(row for row in f if not row.startswith("#"))
for row in reader:
kepids.add(row["kepid"])
num_kepids = len(kepids)
# Write wget commands to script file.
with open(FLAGS.output_file, "w") as f:
f.write("#!/bin/sh\n")
f.write("echo 'Downloading {} Kepler targets to {}'\n".format(
num_kepids, FLAGS.download_dir))
for i, kepid in enumerate(kepids):
if i and not i % 10:
f.write("echo 'Downloaded {}/{}'\n".format(i, num_kepids))
kepid = "{0:09d}".format(int(kepid)) # Pad with zeros.
subdir = "{}/{}".format(kepid[0:4], kepid)
download_dir = os.path.join(FLAGS.download_dir, subdir)
url = "{}/{}/".format(_BASE_URL, subdir)
f.write("{} -P {} {}\n".format(_WGET_CMD, download_dir, url))
f.write("echo 'Finished downloading {} Kepler targets to {}'\n".format(
num_kepids, FLAGS.download_dir))
# Make the download script executable.
os.chmod(FLAGS.output_file, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
print("{} Kepler targets will be downloaded to {}".format(
num_kepids, FLAGS.output_file))
print("To start download, run:\n {}".format("./" + FLAGS.output_file
if "/" not in FLAGS.output_file
else FLAGS.output_file))
if __name__ == "__main__":
FLAGS, unparsed = parser.parse_known_args()
main(argv=[sys.argv[0]] + unparsed)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Script to preprocesses data from the Kepler space telescope.
This script produces training, validation and test sets of labeled Kepler
Threshold Crossing Events (TCEs). A TCE is a detected periodic event on a
particular Kepler target star that may or may not be a transiting planet. Each
TCE in the output contains local and global views of its light curve; auxiliary
features such as period and duration; and a label indicating whether the TCE is
consistent with being a transiting planet. The data sets produced by this script
can be used to train and evaluate models that classify Kepler TCEs.
The input TCEs and their associated labels are specified by the DR24 TCE Table,
which can be downloaded in CSV format from the NASA Exoplanet Archive at:
https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/nph-tblView?app=ExoTbls&config=q1_q17_dr24_tce
The downloaded CSV file should contain at least the following column names:
rowid: Integer ID of the row in the TCE table.
kepid: Kepler ID of the target star.
tce_plnt_num: TCE number within the target star.
tce_period: Orbital period of the detected event, in days.
tce_time0bk: The time corresponding to the center of the first detected
transit, in Barycentric Julian Day (BJD) minus a constant offset of
2,454,833.0 days.
tce_duration: Duration of the detected transit, in hours.
av_training_set: Autovetter training set label; one of PC (planet candidate),
AFP (astrophysical false positive), NTP (non-transiting phenomenon),
UNK (unknown).
The Kepler light curves can be downloaded from the Mikulski Archive for Space
Telescopes (MAST) at:
http://archive.stsci.edu/pub/kepler/lightcurves.
The Kepler data is assumed to reside in a directory with the same structure as
the MAST archive. Specifically, the file names for a particular Kepler target
star should have the following format:
.../${kep_id:0:4}/${kep_id}/kplr${kep_id}-${quarter_prefix}_${type}.fits,
where:
kep_id is the Kepler id left-padded with zeros to length 9;
quarter_prefix is the file name quarter prefix;
type is one of "llc" (long cadence light curve) or "slc" (short cadence light
curve).
The output TFRecord file contains one serialized tensorflow.train.Example
protocol buffer for each TCE in the input CSV file. Each Example contains the
following light curve representations:
global_view: Vector of length 2001; the Global View of the TCE.
local_view: Vector of length 201; the Local View of the TCE.
In addition, each Example contains the value of each column in the input TCE CSV
file. Some of these features may be useful as auxiliary features to the model.
The columns include:
rowid: Integer ID of the row in the TCE table.
kepid: Kepler ID of the target star.
tce_plnt_num: TCE number within the target star.
av_training_set: Autovetter training set label.
tce_period: Orbital period of the detected event, in days.
...
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import multiprocessing
import os
import sys
import numpy as np
import pandas as pd
import tensorflow as tf
from astronet.data import preprocess
parser = argparse.ArgumentParser()
_DR24_TCE_URL = ("https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/"
"nph-tblView?app=ExoTbls&config=q1_q17_dr24_tce")
parser.add_argument(
"--input_tce_csv_file",
type=str,
required=True,
help="CSV file containing the Q1-Q17 DR24 Kepler TCE table. Must contain "
"columns: rowid, kepid, tce_plnt_num, tce_period, tce_duration, "
"tce_time0bk. Download from: {}".format(_DR24_TCE_URL))
parser.add_argument(
"--kepler_data_dir",
type=str,
required=True,
help="Base folder containing Kepler data.")
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Directory in which to save the output.")
parser.add_argument(
"--num_train_shards",
type=int,
default=8,
help="Number of file shards to divide the training set into.")
parser.add_argument(
"--num_worker_processes",
type=int,
default=5,
help="Number of subprocesses for processing the TCEs in parallel.")
# Name and values of the column in the input CSV file to use as training labels.
_LABEL_COLUMN = "av_training_set"
_ALLOWED_LABELS = {"PC", "AFP", "NTP"}
def _process_tce(tce):
"""Processes the light curve for a Kepler TCE and returns an Example proto.
Args:
tce: Row of the input TCE table.
Returns:
A tensorflow.train.Example proto containing TCE features.
"""
all_time, all_flux = preprocess.read_light_curve(tce.kepid,
FLAGS.kepler_data_dir)
time, flux = preprocess.process_light_curve(all_time, all_flux)
return preprocess.generate_example_for_tce(time, flux, tce)
def _process_file_shard(tce_table, file_name):
"""Processes a single file shard.
Args:
tce_table: A pandas DataFrame containing the TCEs in the shard.
file_name: The output TFRecord file.
"""
process_name = multiprocessing.current_process().name
shard_name = os.path.basename(file_name)
shard_size = len(tce_table)
tf.logging.info("%s: Processing %d items in shard %s", process_name,
shard_size, shard_name)
with tf.python_io.TFRecordWriter(file_name) as writer:
num_processed = 0
for _, tce in tce_table.iterrows():
example = _process_tce(tce)
if example is not None:
writer.write(example.SerializeToString())
num_processed += 1
if not num_processed % 10:
tf.logging.info("%s: Processed %d/%d items in shard %s", process_name,
num_processed, shard_size, shard_name)
tf.logging.info("%s: Wrote %d items in shard %s", process_name, shard_size,
shard_name)
def main(argv):
del argv # Unused.
# Make the output directory if it doesn't already exist.
tf.gfile.MakeDirs(FLAGS.output_dir)
# Read the CSV file of Kepler TCEs.
tce_table = pd.read_csv(
FLAGS.input_tce_csv_file, index_col="rowid", comment="#")
tce_table["tce_duration"] /= 24 # Convert hours to days.
tf.logging.info("Read TCE CSV file with %d rows.", len(tce_table))
# Filter TCE table to allowed labels.
allowed_tces = tce_table[_LABEL_COLUMN].apply(lambda l: l in _ALLOWED_LABELS)
tce_table = tce_table[allowed_tces]
num_tces = len(tce_table)
tf.logging.info("Filtered to %d TCEs with labels in %s.", num_tces,
list(_ALLOWED_LABELS))
# Randomly shuffle the TCE table.
np.random.seed(123)
tce_table = tce_table.iloc[np.random.permutation(num_tces)]
tf.logging.info("Randomly shuffled TCEs.")
# Partition the TCE table as follows:
# train_tces = 80% of TCEs
# val_tces = 10% of TCEs (for validation during training)
# test_tces = 10% of TCEs (for final evaluation)
train_cutoff = int(0.80 * num_tces)
val_cutoff = int(0.90 * num_tces)
train_tces = tce_table[0:train_cutoff]
val_tces = tce_table[train_cutoff:val_cutoff]
test_tces = tce_table[val_cutoff:]
tf.logging.info(
"Partitioned %d TCEs into training (%d), validation (%d) and test (%d)",
num_tces, len(train_tces), len(val_tces), len(test_tces))
# Further split training TCEs into file shards.
file_shards = [] # List of (tce_table_shard, file_name).
boundaries = np.linspace(0, len(train_tces),
FLAGS.num_train_shards + 1).astype(int)
for i in range(FLAGS.num_train_shards):
start = boundaries[i]
end = boundaries[i + 1]
filename = os.path.join(
FLAGS.output_dir, "train-{:05d}-of-{:05d}".format(
i, FLAGS.num_train_shards))
file_shards.append((train_tces[start:end], filename))
# Validation and test sets each have a single shard.
file_shards.append((val_tces,
os.path.join(FLAGS.output_dir, "val-00000-of-00001")))
file_shards.append((test_tces,
os.path.join(FLAGS.output_dir, "test-00000-of-00001")))
num_file_shards = len(file_shards)
# Launch subprocesses for the file shards.
num_processes = min(num_file_shards, FLAGS.num_worker_processes)
tf.logging.info("Launching %d subprocesses for %d total file shards",
num_processes, num_file_shards)
pool = multiprocessing.Pool(processes=num_processes)
async_results = [
pool.apply_async(_process_file_shard, file_shard)
for file_shard in file_shards
]
pool.close()
# Instead of pool.join(), we call async_result.get() to ensure any exceptions
# raised by the worker processes are also raised here.
for async_result in async_results:
async_result.get()
tf.logging.info("Finished processing %d total file shards", num_file_shards)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for reading and preprocessing light curves."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from light_curve import kepler_io
from light_curve import median_filter
from light_curve import util
from tf_util import example_util
from third_party.kepler_spline import kepler_spline
def read_light_curve(kepid, kepler_data_dir):
"""Reads a Kepler light curve.
Args:
kepid: Kepler id of the target star.
kepler_data_dir: Base directory containing Kepler data. See
kepler_io.kepler_filenames().
Returns:
all_time: A list of numpy arrays; the time values of the raw light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in
all_time.
Raises:
IOError: If the light curve files for this Kepler ID cannot be found.
"""
# Read the Kepler light curve.
file_names = kepler_io.kepler_filenames(kepler_data_dir, kepid)
if not file_names:
raise IOError("Failed to find .fits files in {} for Kepler ID {}".format(
kepler_data_dir, kepid))
return kepler_io.read_kepler_light_curve(file_names)
def process_light_curve(all_time, all_flux):
"""Removes low-frequency variability from a light curve.
Args:
all_time: A list of numpy arrays; the time values of the raw light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in
all_time.
Returns:
time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve.
"""
# Split on gaps.
all_time, all_flux = util.split(all_time, all_flux, gap_width=0.75)
# Fit a piecewise-cubic spline with default arguments.
spline = kepler_spline.fit_kepler_spline(all_time, all_flux, verbose=False)[0]
# Concatenate the piecewise light curve and spline.
time = np.concatenate(all_time)
flux = np.concatenate(all_flux)
spline = np.concatenate(spline)
# In rare cases the piecewise spline contains NaNs in places the spline could
# not be fit. We can't normalize those points if the spline isn't defined
# there. Instead we just remove them.
finite_i = np.isfinite(spline)
if not np.all(finite_i):
time = time[finite_i]
flux = flux[finite_i]
spline = spline[finite_i]
# "Flatten" the light curve (remove low-frequency variability) by dividing by
# the spline.
flux /= spline
return time, flux
def phase_fold_and_sort_light_curve(time, flux, period, t0):
"""Phase folds a light curve and sorts by ascending time.
Args:
time: 1D NumPy array of time values.
flux: 1D NumPy array of flux values.
period: A positive real scalar; the period to fold over.
t0: The center of the resulting folded vector; this value is mapped to 0.
Returns:
folded_time: 1D NumPy array of phase folded time values in
[-period / 2, period / 2), where 0 corresponds to t0 in the original
time array. Values are sorted in ascending order.
folded_flux: 1D NumPy array. Values are the same as the original input
array, but sorted by folded_time.
"""
# Phase fold time.
time = util.phase_fold_time(time, period, t0)
# Sort by ascending time.
sorted_i = np.argsort(time)
time = time[sorted_i]
flux = flux[sorted_i]
return time, flux
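# A minimal usage sketch with synthetic values (illustrative only):
#   t = np.arange(0.0, 30.0, 0.02)   # 30 days of evenly spaced time samples.
#   f = np.ones_like(t)              # Flat flux.
#   ft, ff = phase_fold_and_sort_light_curve(t, f, period=3.5, t0=1.0)
#   # ft now lies in [-1.75, 1.75), with 0 corresponding to t0, sorted
#   # ascending; ff holds the same flux values reordered to match ft.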
def generate_view(time, flux, num_bins, bin_width, t_min, t_max,
normalize=True):
"""Generates a view of a phase-folded light curve using a median filter.
Args:
time: 1D array of time values, sorted in ascending order.
flux: 1D array of flux values.
num_bins: The number of intervals to divide the time axis into.
bin_width: The width of each bin on the time axis.
t_min: The inclusive leftmost value to consider on the time axis.
t_max: The exclusive rightmost value to consider on the time axis.
normalize: Whether to center the median at 0 and minimum value at -1.
Returns:
1D NumPy array of size num_bins containing the median flux values of
uniformly spaced bins on the phase-folded time axis.
"""
view = median_filter.median_filter(time, flux, num_bins, bin_width, t_min,
t_max)
if normalize:
view -= np.median(view)
view /= np.abs(np.min(view))
return view
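# Worked normalization example: if the binned medians are [1.0, 1.0, 0.5, 1.0],
# subtracting the median (1.0) gives [0, 0, -0.5, 0], and dividing by the
# absolute minimum (0.5) yields [0, 0, -1, 0]: median at 0, minimum at -1.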
def global_view(time, flux, period, num_bins=2001, bin_width_factor=1 / 2001):
"""Generates a 'global view' of a phase folded light curve.
See Section 3.3 of Shallue & Vanderburg, 2018, The Astronomical Journal.
http://iopscience.iop.org/article/10.3847/1538-3881/aa9e09/meta
Args:
time: 1D array of time values, sorted in ascending order.
flux: 1D array of flux values.
period: The period of the event (in days).
num_bins: The number of intervals to divide the time axis into.
bin_width_factor: Width of the bins, as a fraction of period.
Returns:
1D NumPy array of size num_bins containing the median flux values of
uniformly spaced bins on the phase-folded time axis.
"""
return generate_view(
time,
flux,
num_bins=num_bins,
bin_width=period * bin_width_factor,
t_min=-period / 2,
t_max=period / 2)
def local_view(time,
flux,
period,
duration,
num_bins=201,
bin_width_factor=0.16,
num_durations=4):
"""Generates a 'local view' of a phase folded light curve.
See Section 3.3 of Shallue & Vanderburg, 2018, The Astronomical Journal.
http://iopscience.iop.org/article/10.3847/1538-3881/aa9e09/meta
Args:
time: 1D array of time values, sorted in ascending order.
flux: 1D array of flux values.
period: The period of the event (in days).
duration: The duration of the event (in days).
num_bins: The number of intervals to divide the time axis into.
bin_width_factor: Width of the bins, as a fraction of duration.
num_durations: The number of durations to consider on either side of 0 (the
event is assumed to be centered at 0).
Returns:
1D NumPy array of size num_bins containing the median flux values of
uniformly spaced bins on the phase-folded time axis.
"""
return generate_view(
time,
flux,
num_bins=num_bins,
bin_width=duration * bin_width_factor,
t_min=max(-period / 2, -duration * num_durations),
t_max=min(period / 2, duration * num_durations))
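# Concrete comparison of the two views (values are illustrative): for a TCE
# with period = 3.5 days and duration = 0.1 days, global_view uses bins of
# width 3.5 / 2001 ~= 0.0017 days spanning the full phase [-1.75, 1.75), while
# local_view uses bins of width 0.16 * 0.1 = 0.016 days spanning only
# [-0.4, 0.4), i.e. 4 transit durations on either side of the event.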
def generate_example_for_tce(time, flux, tce):
"""Generates a tf.train.Example representing an input TCE.
Args:
time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve.
tce: Dict-like object containing at least 'tce_period', 'tce_duration', and
'tce_time0bk'. Additional items are included as features in the output.
Returns:
A tf.train.Example containing features 'global_view', 'local_view', and all
values present in `tce`.
"""
period = tce["tce_period"]
duration = tce["tce_duration"]
t0 = tce["tce_time0bk"]
time, flux = phase_fold_and_sort_light_curve(time, flux, period, t0)
# Make output proto.
ex = tf.train.Example()
# Set time series features.
example_util.set_float_feature(ex, "global_view",
global_view(time, flux, period))
example_util.set_float_feature(ex, "local_view",
local_view(time, flux, period, duration))
# Set other features in `tce`.
for name, value in tce.items():
example_util.set_feature(ex, name, [value])
return ex
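# Sketch of reading one written record back for inspection (tf.python_io is
# the same TF 1.x API used elsewhere in this codebase; the file name is
# illustrative):
#   for record in tf.python_io.tf_record_iterator("train-00000-of-00008"):
#     ex = tf.train.Example.FromString(record)
#     global_view = ex.features.feature["global_view"].float_list.value
#     assert len(global_view) == 2001
#     break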
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for evaluating an AstroNet model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from astronet import models
from astronet.util import estimator_util
from tf_util import config_util
from tf_util import configdict
from tf_util import estimator_runner
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, required=True, help="Name of the model class.")
parser.add_argument(
"--config_name",
type=str,
help="Name of the model and training configuration. Exactly one of "
"--config_name or --config_json is required.")
parser.add_argument(
"--config_json",
type=str,
help="JSON string or JSON file containing the model and training "
"configuration. Exactly one of --config_name or --config_json is required.")
parser.add_argument(
"--eval_files",
type=str,
required=True,
help="Comma-separated list of file patterns matching the TFRecord files in "
"the evaluation dataset.")
parser.add_argument(
"--model_dir",
type=str,
required=True,
help="Directory containing a model checkpoint.")
parser.add_argument(
"--eval_name", type=str, default="test", help="Name of the evaluation set.")
def main(_):
model_class = models.get_model_class(FLAGS.model)
# Look up the model configuration.
assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
"Exactly one of --config_name or --config_json is required.")
config = (
models.get_model_config(FLAGS.model, FLAGS.config_name)
if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))
config = configdict.ConfigDict(config)
# Create the estimator.
estimator = estimator_util.create_estimator(
model_class, config.hparams, model_dir=FLAGS.model_dir)
# Create an input function that reads the evaluation dataset.
input_fn = estimator_util.create_input_fn(
file_pattern=FLAGS.eval_files,
input_config=config.inputs,
mode=tf.estimator.ModeKeys.EVAL)
# Run evaluation. This will log the result to stderr and also write a summary
# file in the model_dir.
eval_steps = None # Evaluate over all examples in the file.
eval_args = {FLAGS.eval_name: (input_fn, eval_steps)}
estimator_runner.evaluate(estimator, eval_args)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
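# Example invocation (the model name matches an entry in models._MODELS; the
# config name and paths are illustrative):
#   python evaluate.py \
#     --model=AstroCNNModel \
#     --config_name=base \
#     --eval_files=${HOME}/astronet/tfrecord/test-00000-of-00001 \
#     --model_dir=${HOME}/astronet/model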
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library of AstroNet models and configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from astronet.astro_cnn_model import astro_cnn_model
from astronet.astro_cnn_model import configurations as astro_cnn_configurations
from astronet.astro_fc_model import astro_fc_model
from astronet.astro_fc_model import configurations as astro_fc_configurations
from astronet.astro_model import astro_model
from astronet.astro_model import configurations as astro_configurations
# Dictionary of model name to (model_class, configuration_module).
_MODELS = {
"AstroModel": (astro_model.AstroModel, astro_configurations),
"AstroFCModel": (astro_fc_model.AstroFCModel, astro_fc_configurations),
"AstroCNNModel": (astro_cnn_model.AstroCNNModel, astro_cnn_configurations),
}
def get_model_class(model_name):
"""Looks up a model class by name.
Args:
model_name: Name of the model class.
Returns:
model_class: The requested model class.
Raises:
ValueError: If model_name is unrecognized.
"""
if model_name not in _MODELS:
raise ValueError("Unrecognized model name: {}".format(model_name))
return _MODELS[model_name][0]
def get_model_config(model_name, config_name):
"""Looks up a model configuration by name.
Args:
model_name: Name of the model class.
config_name: Name of a configuration-builder function from the model's
configurations module.
Returns:
config: The requested configuration.
Raises:
ValueError: If model_name or config_name is unrecognized.
"""
if model_name not in _MODELS:
raise ValueError("Unrecognized model name: {}".format(model_name))
config_module = _MODELS[model_name][1]
try:
return getattr(config_module, config_name)()
except AttributeError:
raise ValueError(
"Config name '{}' not found in configuration module: {}".format(
config_name, config_module.__name__))
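# Usage sketch (keys refer to entries in _MODELS above; "base" is the
# configuration-builder function defined in the corresponding configurations
# module):
#   model_class = get_model_class("AstroModel")
#   config = get_model_config("AstroModel", "base")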
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "input_ops",
srcs = ["input_ops.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "input_ops_test",
size = "small",
srcs = ["input_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":input_ops",
"//tf_util:configdict",
],
)
py_library(
name = "dataset_ops",
srcs = ["dataset_ops.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "dataset_ops_test",
size = "small",
srcs = ["dataset_ops_test.py"],
data = ["test_data/test_dataset.tfrecord"],
srcs_version = "PY2AND3",
deps = [
":dataset_ops",
"//tf_util:configdict",
],
)
py_library(
name = "training",
srcs = ["training.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "testing",
srcs = ["testing.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "metrics",
srcs = ["metrics.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "metrics_test",
size = "small",
srcs = ["metrics_test.py"],
srcs_version = "PY2AND3",
deps = [":metrics"],
)