Commit 252e2d2e authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Update formatting.

PiperOrigin-RevId: 213963765
parent 9d0f41b7
......@@ -35,8 +35,7 @@ class AstroCNNModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -35,8 +35,7 @@ class AstroFCModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -73,17 +73,19 @@ class AstroModel(object):
"""A TensorFlow model for classifying astrophysical light curves."""
def __init__(self, features, labels, hparams, mode):
"""Basic setup. The actual TensorFlow graph is constructed in build().
"""Basic setup.
The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "time_series_features" and
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
"aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length].
labels: An int64 Tensor with shape [batch_size]. May be None if mode is
tf.estimator.ModeKeys.PREDICT.
tf.estimator.ModeKeys.PREDICT.
hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction.
for training, evaluation or prediction.
Raises:
ValueError: If mode is invalid.
......@@ -201,10 +203,9 @@ class AstroModel(object):
if len(hidden_layers) == 1:
pre_logits_concat = hidden_layers[0][1]
else:
pre_logits_concat = tf.concat(
[layer[1] for layer in hidden_layers],
axis=1,
name="pre_logits_concat")
pre_logits_concat = tf.concat([layer[1] for layer in hidden_layers],
axis=1,
name="pre_logits_concat")
net = pre_logits_concat
with tf.variable_scope("pre_logits_hidden"):
......
......@@ -35,8 +35,7 @@ class AstroModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -88,7 +88,6 @@ import tensorflow as tf
from astronet.data import preprocess
parser = argparse.ArgumentParser()
_DR24_TCE_URL = ("https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/"
......@@ -225,10 +224,10 @@ def main(argv):
file_shards.append((train_tces[start:end], filename))
# Validation and test sets each have a single shard.
file_shards.append((val_tces, os.path.join(FLAGS.output_dir,
"val-00000-of-00001")))
file_shards.append((test_tces, os.path.join(FLAGS.output_dir,
"test-00000-of-00001")))
file_shards.append((val_tces,
os.path.join(FLAGS.output_dir, "val-00000-of-00001")))
file_shards.append((test_tces,
os.path.join(FLAGS.output_dir, "test-00000-of-00001")))
num_file_shards = len(file_shards)
# Launch subprocesses for the file shards.
......
......@@ -34,7 +34,7 @@ def read_light_curve(kepid, kepler_data_dir):
Args:
kepid: Kepler id of the target star.
kepler_data_dir: Base directory containing Kepler data. See
kepler_io.kepler_filenames().
kepler_io.kepler_filenames().
Returns:
all_time: A list of numpy arrays; the time values of the raw light curve.
......@@ -59,7 +59,7 @@ def process_light_curve(all_time, all_flux):
Args:
all_time: A list of numpy arrays; the time values of the raw light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in
all_time.
all_time.
Returns:
time: 1D NumPy array; the time values of the light curve.
......@@ -192,7 +192,7 @@ def local_view(time,
num_bins: The number of intervals to divide the time axis into.
bin_width_factor: Width of the bins, as a fraction of duration.
num_durations: The number of durations to consider on either side of 0 (the
event is assumed to be centered at 0).
event is assumed to be centered at 0).
Returns:
1D NumPy array of size num_bins containing the median flux values of
......@@ -214,7 +214,7 @@ def generate_example_for_tce(time, flux, tce):
time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve.
tce: Dict-like object containing at least 'tce_period', 'tce_duration', and
'tce_time0bk'. Additional items are included as features in the output.
'tce_time0bk'. Additional items are included as features in the output.
Returns:
A tf.train.Example containing features 'global_view', 'local_view', and all
......
......@@ -57,7 +57,7 @@ def get_model_config(model_name, config_name):
Args:
model_name: Name of the model class.
config_name: Name of a configuration-builder function from the model's
configurations module.
configurations module.
Returns:
model_class: The requested model class.
......
......@@ -142,19 +142,19 @@ def build_dataset(file_pattern,
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
input_config: ConfigDict containing feature and label specifications.
batch_size: The number of examples per batch.
include_labels: Whether to read labels from the input files.
reverse_time_series_prob: If > 0, the time series features will be randomly
reversed with this probability. Within a given example, either all time
series features will be reversed, or none will be reversed.
reversed with this probability. Within a given example, either all time
series features will be reversed, or none will be reversed.
shuffle_filenames: Whether to shuffle the order of TFRecord files between
epochs.
epochs.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the dataset
will repeat indefinitely.
will repeat indefinitely.
use_tpu: Whether to build the dataset for TPU.
Raises:
......
......@@ -27,11 +27,10 @@ def prepare_feed_dict(model, features, labels=None, is_training=None):
Args:
model: An instance of AstroModel.
features: Dictionary containing "time_series_features" and "aux_features".
Each is a dictionary of named numpy arrays of shape
[batch_size, length].
Each is a dictionary of named numpy arrays of shape [batch_size, length].
labels: (Optional). Numpy array of shape [batch_size].
is_training: (Optional). Python boolean to feed to the model.is_training
Tensor (if None, no value is fed).
Tensor (if None, no value is fed).
Returns:
feed_dict: A dictionary of input Tensor to numpy array.
......
......@@ -31,9 +31,9 @@ class InputOpsTest(tf.test.TestCase):
Args:
expected_shapes: Dictionary of expected Tensor shapes, as lists,
corresponding to the structure of 'features'.
corresponding to the structure of 'features'.
features: Dictionary of feature placeholders of the format returned by
input_ops.build_feature_placeholders().
input_ops.build_feature_placeholders().
"""
actual_shapes = {}
for feature_type in features:
......
......@@ -47,15 +47,10 @@ def fake_features(feature_spec, batch_size):
Dictionary containing "time_series_features" and "aux_features". Each is a
dictionary of named numpy arrays of shape [batch_size, length].
"""
features = {}
features["time_series_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.items() if spec["is_time_series"]
}
features["aux_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.items() if not spec["is_time_series"]
}
features = {"time_series_features": {}, "aux_features": {}}
for name, spec in feature_spec.items():
ftype = "time_series_features" if spec["is_time_series"] else "aux_features"
features[ftype][name] = np.random.random([batch_size, spec["length"]])
return features
......
......@@ -51,7 +51,7 @@ def create_optimizer(hparams, learning_rate, use_tpu=False):
hparams: ConfigDict containing the optimizer configuration.
learning_rate: A Python float or a scalar Tensor.
use_tpu: If True, the returned optimizer is wrapped in a
CrossShardOptimizer.
CrossShardOptimizer.
Returns:
A TensorFlow optimizer.
......
......@@ -105,7 +105,7 @@ def unflatten(flat_config):
Args:
flat_config: A dictionary with strings as keys where nested configuration
parameters are represented with period-separated names.
parameters are represented with period-separated names.
Returns:
A dictionary nested according to the keys of the input dictionary.
......
......@@ -203,10 +203,10 @@ def create_estimator(model_class,
hparams: ConfigDict of configuration parameters for building the model.
run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
model_dir: Optional directory for saving the model. If not passed
explicitly, it must be specified in run_config.
explicitly, it must be specified in run_config.
eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
Returns:
An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
......
......@@ -28,7 +28,7 @@ def get_feature(ex, name, kind=None, strict=True):
ex: A tf.train.Example.
name: Name of the feature to look up.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
not specified.
strict: Whether to raise a KeyError if there is no such feature.
Returns:
......@@ -93,7 +93,7 @@ def set_feature(ex,
name: Name of the feature to set.
value: Feature value to set. Must be a sequence.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
not specified.
allow_overwrite: Whether to overwrite the existing value of the feature.
bytes_encoding: Codec for encoding strings when kind = 'bytes_list'.
......
......@@ -44,5 +44,5 @@ class MedianFilterTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected)
if __name__ == '__main__':
if __name__ == "__main__":
absltest.main()
......@@ -66,5 +66,5 @@ class PhaseFoldAndSortLightCurveTest(absltest.TestCase):
np.testing.assert_almost_equal(folded_flux, expected_flux)
if __name__ == '__main__':
if __name__ == "__main__":
absltest.main()
......@@ -76,5 +76,5 @@ class ViewGeneratorTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected)
if __name__ == '__main__':
if __name__ == "__main__":
absltest.main()
......@@ -105,15 +105,15 @@ def kepler_filenames(base_dir,
Args:
base_dir: Base directory containing Kepler data.
kep_id: Id of the Kepler target star. May be an int or a possibly zero-
padded string.
padded string.
long_cadence: Whether to read a long cadence (~29.4 min / measurement) light
curve as opposed to a short cadence (~1 min / measurement) light curve.
curve as opposed to a short cadence (~1 min / measurement) light curve.
quarters: Optional list of integers in [0, 17]; the quarters of the Kepler
mission to return.
mission to return.
injected_group: Optional string indicating injected light curves. One of
"inj1", "inj2", "inj3".
"inj1", "inj2", "inj3".
check_existence: If True, only return filenames corresponding to files that
exist (not all stars have data for all quarters).
exist (not all stars have data for all quarters).
Returns:
A list of filenames.
......
......@@ -32,16 +32,16 @@ def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None):
Args:
x: 1D array of x-coordinates sorted in ascending order. Must have at least 2
elements, and all elements cannot be the same value.
elements, and all elements cannot be the same value.
y: 1D array of y-coordinates with the same size as x.
num_bins: The number of intervals to divide the x-axis into. Must be at
least 2.
least 2.
bin_width: The width of each bin on the x-axis. Must be positive, and less
than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
x_min: The inclusive leftmost value to consider on the x-axis. Must be less
than or equal to the largest value of x. Defaults to min(x).
than or equal to the largest value of x. Defaults to min(x).
x_max: The exclusive rightmost value to consider on the x-axis. Must be
greater than x_min. Defaults to max(x).
greater than x_min. Defaults to max(x).
Returns:
1D NumPy array of size num_bins containing the median y-values of uniformly
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment