Commit 252e2d2e authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Update formatting.

PiperOrigin-RevId: 213963765
parent 9d0f41b7
......@@ -35,8 +35,7 @@ class AstroCNNModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -35,8 +35,7 @@ class AstroFCModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -73,7 +73,9 @@ class AstroModel(object):
"""A TensorFlow model for classifying astrophysical light curves."""
def __init__(self, features, labels, hparams, mode):
"""Basic setup. The actual TensorFlow graph is constructed in build().
"""Basic setup.
The actual TensorFlow graph is constructed in build().
Args:
features: A dictionary containing "time_series_features" and
......@@ -201,8 +203,7 @@ class AstroModel(object):
if len(hidden_layers) == 1:
pre_logits_concat = hidden_layers[0][1]
else:
pre_logits_concat = tf.concat(
[layer[1] for layer in hidden_layers],
pre_logits_concat = tf.concat([layer[1] for layer in hidden_layers],
axis=1,
name="pre_logits_concat")
......
......@@ -35,8 +35,7 @@ class AstroModelTest(tf.test.TestCase):
Args:
shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can
be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
"""
if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape)
......
......@@ -88,7 +88,6 @@ import tensorflow as tf
from astronet.data import preprocess
parser = argparse.ArgumentParser()
_DR24_TCE_URL = ("https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/"
......@@ -225,10 +224,10 @@ def main(argv):
file_shards.append((train_tces[start:end], filename))
# Validation and test sets each have a single shard.
file_shards.append((val_tces, os.path.join(FLAGS.output_dir,
"val-00000-of-00001")))
file_shards.append((test_tces, os.path.join(FLAGS.output_dir,
"test-00000-of-00001")))
file_shards.append((val_tces,
os.path.join(FLAGS.output_dir, "val-00000-of-00001")))
file_shards.append((test_tces,
os.path.join(FLAGS.output_dir, "test-00000-of-00001")))
num_file_shards = len(file_shards)
# Launch subprocesses for the file shards.
......
......@@ -27,8 +27,7 @@ def prepare_feed_dict(model, features, labels=None, is_training=None):
Args:
model: An instance of AstroModel.
features: Dictionary containing "time_series_features" and "aux_features".
Each is a dictionary of named numpy arrays of shape
[batch_size, length].
Each is a dictionary of named numpy arrays of shape [batch_size, length].
labels: (Optional). Numpy array of shape [batch_size].
is_training: (Optional). Python boolean to feed to the model.is_training
Tensor (if None, no value is fed).
......
......@@ -47,15 +47,10 @@ def fake_features(feature_spec, batch_size):
Dictionary containing "time_series_features" and "aux_features". Each is a
dictionary of named numpy arrays of shape [batch_size, length].
"""
features = {}
features["time_series_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.items() if spec["is_time_series"]
}
features["aux_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.items() if not spec["is_time_series"]
}
features = {"time_series_features": {}, "aux_features": {}}
for name, spec in feature_spec.items():
ftype = "time_series_features" if spec["is_time_series"] else "aux_features"
features[ftype][name] = np.random.random([batch_size, spec["length"]])
return features
......
......@@ -44,5 +44,5 @@ class MedianFilterTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected)
if __name__ == '__main__':
if __name__ == "__main__":
absltest.main()
......@@ -66,5 +66,5 @@ class PhaseFoldAndSortLightCurveTest(absltest.TestCase):
np.testing.assert_almost_equal(folded_flux, expected_flux)
if __name__ == '__main__':
if __name__ == "__main__":
absltest.main()
......@@ -76,5 +76,5 @@ class ViewGeneratorTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected)
if __name__ == '__main__':
if __name__ == "__main__":
absltest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment