Commit 252e2d2e authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Update formatting.

PiperOrigin-RevId: 213963765
parent 9d0f41b7
...@@ -35,8 +35,7 @@ class AstroCNNModelTest(tf.test.TestCase): ...@@ -35,8 +35,7 @@ class AstroCNNModelTest(tf.test.TestCase):
Args: Args:
shape: Numpy array or anything that can be converted to one. shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
be converted to one.
""" """
if isinstance(tensor_or_array, (np.ndarray, np.generic)): if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape) self.assertAllEqual(shape, tensor_or_array.shape)
......
...@@ -35,8 +35,7 @@ class AstroFCModelTest(tf.test.TestCase): ...@@ -35,8 +35,7 @@ class AstroFCModelTest(tf.test.TestCase):
Args: Args:
shape: Numpy array or anything that can be converted to one. shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
be converted to one.
""" """
if isinstance(tensor_or_array, (np.ndarray, np.generic)): if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape) self.assertAllEqual(shape, tensor_or_array.shape)
......
...@@ -73,7 +73,9 @@ class AstroModel(object): ...@@ -73,7 +73,9 @@ class AstroModel(object):
"""A TensorFlow model for classifying astrophysical light curves.""" """A TensorFlow model for classifying astrophysical light curves."""
def __init__(self, features, labels, hparams, mode): def __init__(self, features, labels, hparams, mode):
"""Basic setup. The actual TensorFlow graph is constructed in build(). """Basic setup.
The actual TensorFlow graph is constructed in build().
Args: Args:
features: A dictionary containing "time_series_features" and features: A dictionary containing "time_series_features" and
...@@ -201,8 +203,7 @@ class AstroModel(object): ...@@ -201,8 +203,7 @@ class AstroModel(object):
if len(hidden_layers) == 1: if len(hidden_layers) == 1:
pre_logits_concat = hidden_layers[0][1] pre_logits_concat = hidden_layers[0][1]
else: else:
pre_logits_concat = tf.concat( pre_logits_concat = tf.concat([layer[1] for layer in hidden_layers],
[layer[1] for layer in hidden_layers],
axis=1, axis=1,
name="pre_logits_concat") name="pre_logits_concat")
......
...@@ -35,8 +35,7 @@ class AstroModelTest(tf.test.TestCase): ...@@ -35,8 +35,7 @@ class AstroModelTest(tf.test.TestCase):
Args: Args:
shape: Numpy array or anything that can be converted to one. shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
be converted to one.
""" """
if isinstance(tensor_or_array, (np.ndarray, np.generic)): if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape) self.assertAllEqual(shape, tensor_or_array.shape)
......
...@@ -88,7 +88,6 @@ import tensorflow as tf ...@@ -88,7 +88,6 @@ import tensorflow as tf
from astronet.data import preprocess from astronet.data import preprocess
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
_DR24_TCE_URL = ("https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/" _DR24_TCE_URL = ("https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/"
...@@ -225,10 +224,10 @@ def main(argv): ...@@ -225,10 +224,10 @@ def main(argv):
file_shards.append((train_tces[start:end], filename)) file_shards.append((train_tces[start:end], filename))
# Validation and test sets each have a single shard. # Validation and test sets each have a single shard.
file_shards.append((val_tces, os.path.join(FLAGS.output_dir, file_shards.append((val_tces,
"val-00000-of-00001"))) os.path.join(FLAGS.output_dir, "val-00000-of-00001")))
file_shards.append((test_tces, os.path.join(FLAGS.output_dir, file_shards.append((test_tces,
"test-00000-of-00001"))) os.path.join(FLAGS.output_dir, "test-00000-of-00001")))
num_file_shards = len(file_shards) num_file_shards = len(file_shards)
# Launch subprocesses for the file shards. # Launch subprocesses for the file shards.
......
...@@ -27,8 +27,7 @@ def prepare_feed_dict(model, features, labels=None, is_training=None): ...@@ -27,8 +27,7 @@ def prepare_feed_dict(model, features, labels=None, is_training=None):
Args: Args:
model: An instance of AstroModel. model: An instance of AstroModel.
features: Dictionary containing "time_series_features" and "aux_features". features: Dictionary containing "time_series_features" and "aux_features".
Each is a dictionary of named numpy arrays of shape Each is a dictionary of named numpy arrays of shape [batch_size, length].
[batch_size, length].
labels: (Optional). Numpy array of shape [batch_size]. labels: (Optional). Numpy array of shape [batch_size].
is_training: (Optional). Python boolean to feed to the model.is_training is_training: (Optional). Python boolean to feed to the model.is_training
Tensor (if None, no value is fed). Tensor (if None, no value is fed).
......
...@@ -47,15 +47,10 @@ def fake_features(feature_spec, batch_size): ...@@ -47,15 +47,10 @@ def fake_features(feature_spec, batch_size):
Dictionary containing "time_series_features" and "aux_features". Each is a Dictionary containing "time_series_features" and "aux_features". Each is a
dictionary of named numpy arrays of shape [batch_size, length]. dictionary of named numpy arrays of shape [batch_size, length].
""" """
features = {} features = {"time_series_features": {}, "aux_features": {}}
features["time_series_features"] = { for name, spec in feature_spec.items():
name: np.random.random([batch_size, spec["length"]]) ftype = "time_series_features" if spec["is_time_series"] else "aux_features"
for name, spec in feature_spec.items() if spec["is_time_series"] features[ftype][name] = np.random.random([batch_size, spec["length"]])
}
features["aux_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.items() if not spec["is_time_series"]
}
return features return features
......
...@@ -44,5 +44,5 @@ class MedianFilterTest(absltest.TestCase): ...@@ -44,5 +44,5 @@ class MedianFilterTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected) np.testing.assert_almost_equal(result, expected)
if __name__ == '__main__': if __name__ == "__main__":
absltest.main() absltest.main()
...@@ -66,5 +66,5 @@ class PhaseFoldAndSortLightCurveTest(absltest.TestCase): ...@@ -66,5 +66,5 @@ class PhaseFoldAndSortLightCurveTest(absltest.TestCase):
np.testing.assert_almost_equal(folded_flux, expected_flux) np.testing.assert_almost_equal(folded_flux, expected_flux)
if __name__ == '__main__': if __name__ == "__main__":
absltest.main() absltest.main()
...@@ -76,5 +76,5 @@ class ViewGeneratorTest(absltest.TestCase): ...@@ -76,5 +76,5 @@ class ViewGeneratorTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected) np.testing.assert_almost_equal(result, expected)
if __name__ == '__main__': if __name__ == "__main__":
absltest.main() absltest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment