"...git@developer.sourcefind.cn:modelzoo/dbnet_pytorch.git" did not exist on "97243508511f3a4922b313f831a2038084d752a9"
Commit 252e2d2e authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Update formatting.

PiperOrigin-RevId: 213963765
parent 9d0f41b7
...@@ -35,8 +35,7 @@ class AstroCNNModelTest(tf.test.TestCase): ...@@ -35,8 +35,7 @@ class AstroCNNModelTest(tf.test.TestCase):
Args: Args:
shape: Numpy array or anything that can be converted to one. shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
be converted to one.
""" """
if isinstance(tensor_or_array, (np.ndarray, np.generic)): if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape) self.assertAllEqual(shape, tensor_or_array.shape)
......
...@@ -35,8 +35,7 @@ class AstroFCModelTest(tf.test.TestCase): ...@@ -35,8 +35,7 @@ class AstroFCModelTest(tf.test.TestCase):
Args: Args:
shape: Numpy array or anything that can be converted to one. shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
be converted to one.
""" """
if isinstance(tensor_or_array, (np.ndarray, np.generic)): if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape) self.assertAllEqual(shape, tensor_or_array.shape)
......
...@@ -73,17 +73,19 @@ class AstroModel(object): ...@@ -73,17 +73,19 @@ class AstroModel(object):
"""A TensorFlow model for classifying astrophysical light curves.""" """A TensorFlow model for classifying astrophysical light curves."""
def __init__(self, features, labels, hparams, mode): def __init__(self, features, labels, hparams, mode):
"""Basic setup. The actual TensorFlow graph is constructed in build(). """Basic setup.
The actual TensorFlow graph is constructed in build().
Args: Args:
features: A dictionary containing "time_series_features" and features: A dictionary containing "time_series_features" and
"aux_features", each of which is a dictionary of named input Tensors. "aux_features", each of which is a dictionary of named input Tensors.
All features have dtype float32 and shape [batch_size, length]. All features have dtype float32 and shape [batch_size, length].
labels: An int64 Tensor with shape [batch_size]. May be None if mode is labels: An int64 Tensor with shape [batch_size]. May be None if mode is
tf.estimator.ModeKeys.PREDICT. tf.estimator.ModeKeys.PREDICT.
hparams: A ConfigDict of hyperparameters for building the model. hparams: A ConfigDict of hyperparameters for building the model.
mode: A tf.estimator.ModeKeys to specify whether the graph should be built mode: A tf.estimator.ModeKeys to specify whether the graph should be built
for training, evaluation or prediction. for training, evaluation or prediction.
Raises: Raises:
ValueError: If mode is invalid. ValueError: If mode is invalid.
...@@ -201,10 +203,9 @@ class AstroModel(object): ...@@ -201,10 +203,9 @@ class AstroModel(object):
if len(hidden_layers) == 1: if len(hidden_layers) == 1:
pre_logits_concat = hidden_layers[0][1] pre_logits_concat = hidden_layers[0][1]
else: else:
pre_logits_concat = tf.concat( pre_logits_concat = tf.concat([layer[1] for layer in hidden_layers],
[layer[1] for layer in hidden_layers], axis=1,
axis=1, name="pre_logits_concat")
name="pre_logits_concat")
net = pre_logits_concat net = pre_logits_concat
with tf.variable_scope("pre_logits_hidden"): with tf.variable_scope("pre_logits_hidden"):
......
...@@ -35,8 +35,7 @@ class AstroModelTest(tf.test.TestCase): ...@@ -35,8 +35,7 @@ class AstroModelTest(tf.test.TestCase):
Args: Args:
shape: Numpy array or anything that can be converted to one. shape: Numpy array or anything that can be converted to one.
tensor_or_array: tf.Tensor, tf.Variable, Numpy array or anything that can tensor_or_array: tf.Tensor, tf.Variable, or Numpy array.
be converted to one.
""" """
if isinstance(tensor_or_array, (np.ndarray, np.generic)): if isinstance(tensor_or_array, (np.ndarray, np.generic)):
self.assertAllEqual(shape, tensor_or_array.shape) self.assertAllEqual(shape, tensor_or_array.shape)
......
...@@ -88,7 +88,6 @@ import tensorflow as tf ...@@ -88,7 +88,6 @@ import tensorflow as tf
from astronet.data import preprocess from astronet.data import preprocess
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
_DR24_TCE_URL = ("https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/" _DR24_TCE_URL = ("https://exoplanetarchive.ipac.caltech.edu/cgi-bin/TblView/"
...@@ -225,10 +224,10 @@ def main(argv): ...@@ -225,10 +224,10 @@ def main(argv):
file_shards.append((train_tces[start:end], filename)) file_shards.append((train_tces[start:end], filename))
# Validation and test sets each have a single shard. # Validation and test sets each have a single shard.
file_shards.append((val_tces, os.path.join(FLAGS.output_dir, file_shards.append((val_tces,
"val-00000-of-00001"))) os.path.join(FLAGS.output_dir, "val-00000-of-00001")))
file_shards.append((test_tces, os.path.join(FLAGS.output_dir, file_shards.append((test_tces,
"test-00000-of-00001"))) os.path.join(FLAGS.output_dir, "test-00000-of-00001")))
num_file_shards = len(file_shards) num_file_shards = len(file_shards)
# Launch subprocesses for the file shards. # Launch subprocesses for the file shards.
......
...@@ -34,7 +34,7 @@ def read_light_curve(kepid, kepler_data_dir): ...@@ -34,7 +34,7 @@ def read_light_curve(kepid, kepler_data_dir):
Args: Args:
kepid: Kepler id of the target star. kepid: Kepler id of the target star.
kepler_data_dir: Base directory containing Kepler data. See kepler_data_dir: Base directory containing Kepler data. See
kepler_io.kepler_filenames(). kepler_io.kepler_filenames().
Returns: Returns:
all_time: A list of numpy arrays; the time values of the raw light curve. all_time: A list of numpy arrays; the time values of the raw light curve.
...@@ -59,7 +59,7 @@ def process_light_curve(all_time, all_flux): ...@@ -59,7 +59,7 @@ def process_light_curve(all_time, all_flux):
Args: Args:
all_time: A list of numpy arrays; the time values of the raw light curve. all_time: A list of numpy arrays; the time values of the raw light curve.
all_flux: A list of numpy arrays corresponding to the time arrays in all_flux: A list of numpy arrays corresponding to the time arrays in
all_time. all_time.
Returns: Returns:
time: 1D NumPy array; the time values of the light curve. time: 1D NumPy array; the time values of the light curve.
...@@ -192,7 +192,7 @@ def local_view(time, ...@@ -192,7 +192,7 @@ def local_view(time,
num_bins: The number of intervals to divide the time axis into. num_bins: The number of intervals to divide the time axis into.
bin_width_factor: Width of the bins, as a fraction of duration. bin_width_factor: Width of the bins, as a fraction of duration.
num_durations: The number of durations to consider on either side of 0 (the num_durations: The number of durations to consider on either side of 0 (the
event is assumed to be centered at 0). event is assumed to be centered at 0).
Returns: Returns:
1D NumPy array of size num_bins containing the median flux values of 1D NumPy array of size num_bins containing the median flux values of
...@@ -214,7 +214,7 @@ def generate_example_for_tce(time, flux, tce): ...@@ -214,7 +214,7 @@ def generate_example_for_tce(time, flux, tce):
time: 1D NumPy array; the time values of the light curve. time: 1D NumPy array; the time values of the light curve.
flux: 1D NumPy array; the normalized flux values of the light curve. flux: 1D NumPy array; the normalized flux values of the light curve.
tce: Dict-like object containing at least 'tce_period', 'tce_duration', and tce: Dict-like object containing at least 'tce_period', 'tce_duration', and
'tce_time0bk'. Additional items are included as features in the output. 'tce_time0bk'. Additional items are included as features in the output.
Returns: Returns:
A tf.train.Example containing features 'global_view', 'local_view', and all A tf.train.Example containing features 'global_view', 'local_view', and all
......
...@@ -57,7 +57,7 @@ def get_model_config(model_name, config_name): ...@@ -57,7 +57,7 @@ def get_model_config(model_name, config_name):
Args: Args:
model_name: Name of the model class. model_name: Name of the model class.
config_name: Name of a configuration-builder function from the model's config_name: Name of a configuration-builder function from the model's
configurations module. configurations module.
Returns: Returns:
model_class: The requested model class. model_class: The requested model class.
......
...@@ -142,19 +142,19 @@ def build_dataset(file_pattern, ...@@ -142,19 +142,19 @@ def build_dataset(file_pattern,
Args: Args:
file_pattern: File pattern matching input TFRecord files, e.g. file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file "/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns. patterns.
input_config: ConfigDict containing feature and label specifications. input_config: ConfigDict containing feature and label specifications.
batch_size: The number of examples per batch. batch_size: The number of examples per batch.
include_labels: Whether to read labels from the input files. include_labels: Whether to read labels from the input files.
reverse_time_series_prob: If > 0, the time series features will be randomly reverse_time_series_prob: If > 0, the time series features will be randomly
reversed with this probability. Within a given example, either all time reversed with this probability. Within a given example, either all time
series features will be reversed, or none will be reversed. series features will be reversed, or none will be reversed.
shuffle_filenames: Whether to shuffle the order of TFRecord files between shuffle_filenames: Whether to shuffle the order of TFRecord files between
epochs. epochs.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size. shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the dataset repeat: The number of times to repeat the dataset. If None or -1 the dataset
will repeat indefinitely. will repeat indefinitely.
use_tpu: Whether to build the dataset for TPU. use_tpu: Whether to build the dataset for TPU.
Raises: Raises:
......
...@@ -27,11 +27,10 @@ def prepare_feed_dict(model, features, labels=None, is_training=None): ...@@ -27,11 +27,10 @@ def prepare_feed_dict(model, features, labels=None, is_training=None):
Args: Args:
model: An instance of AstroModel. model: An instance of AstroModel.
features: Dictionary containing "time_series_features" and "aux_features". features: Dictionary containing "time_series_features" and "aux_features".
Each is a dictionary of named numpy arrays of shape Each is a dictionary of named numpy arrays of shape [batch_size, length].
[batch_size, length].
labels: (Optional). Numpy array of shape [batch_size]. labels: (Optional). Numpy array of shape [batch_size].
is_training: (Optional). Python boolean to feed to the model.is_training is_training: (Optional). Python boolean to feed to the model.is_training
Tensor (if None, no value is fed). Tensor (if None, no value is fed).
Returns: Returns:
feed_dict: A dictionary of input Tensor to numpy array. feed_dict: A dictionary of input Tensor to numpy array.
......
...@@ -31,9 +31,9 @@ class InputOpsTest(tf.test.TestCase): ...@@ -31,9 +31,9 @@ class InputOpsTest(tf.test.TestCase):
Args: Args:
expected_shapes: Dictionary of expected Tensor shapes, as lists, expected_shapes: Dictionary of expected Tensor shapes, as lists,
corresponding to the structure of 'features'. corresponding to the structure of 'features'.
features: Dictionary of feature placeholders of the format returned by features: Dictionary of feature placeholders of the format returned by
input_ops.build_feature_placeholders(). input_ops.build_feature_placeholders().
""" """
actual_shapes = {} actual_shapes = {}
for feature_type in features: for feature_type in features:
......
...@@ -47,15 +47,10 @@ def fake_features(feature_spec, batch_size): ...@@ -47,15 +47,10 @@ def fake_features(feature_spec, batch_size):
Dictionary containing "time_series_features" and "aux_features". Each is a Dictionary containing "time_series_features" and "aux_features". Each is a
dictionary of named numpy arrays of shape [batch_size, length]. dictionary of named numpy arrays of shape [batch_size, length].
""" """
features = {} features = {"time_series_features": {}, "aux_features": {}}
features["time_series_features"] = { for name, spec in feature_spec.items():
name: np.random.random([batch_size, spec["length"]]) ftype = "time_series_features" if spec["is_time_series"] else "aux_features"
for name, spec in feature_spec.items() if spec["is_time_series"] features[ftype][name] = np.random.random([batch_size, spec["length"]])
}
features["aux_features"] = {
name: np.random.random([batch_size, spec["length"]])
for name, spec in feature_spec.items() if not spec["is_time_series"]
}
return features return features
......
...@@ -51,7 +51,7 @@ def create_optimizer(hparams, learning_rate, use_tpu=False): ...@@ -51,7 +51,7 @@ def create_optimizer(hparams, learning_rate, use_tpu=False):
hparams: ConfigDict containing the optimizer configuration. hparams: ConfigDict containing the optimizer configuration.
learning_rate: A Python float or a scalar Tensor. learning_rate: A Python float or a scalar Tensor.
use_tpu: If True, the returned optimizer is wrapped in a use_tpu: If True, the returned optimizer is wrapped in a
CrossShardOptimizer. CrossShardOptimizer.
Returns: Returns:
A TensorFlow optimizer. A TensorFlow optimizer.
......
...@@ -105,7 +105,7 @@ def unflatten(flat_config): ...@@ -105,7 +105,7 @@ def unflatten(flat_config):
Args: Args:
flat_config: A dictionary with strings as keys where nested configuration flat_config: A dictionary with strings as keys where nested configuration
parameters are represented with period-separated names. parameters are represented with period-separated names.
Returns: Returns:
A dictionary nested according to the keys of the input dictionary. A dictionary nested according to the keys of the input dictionary.
......
...@@ -203,10 +203,10 @@ def create_estimator(model_class, ...@@ -203,10 +203,10 @@ def create_estimator(model_class,
hparams: ConfigDict of configuration parameters for building the model. hparams: ConfigDict of configuration parameters for building the model.
run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig. run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
model_dir: Optional directory for saving the model. If not passed model_dir: Optional directory for saving the model. If not passed
explicitly, it must be specified in run_config. explicitly, it must be specified in run_config.
eval_batch_size: Optional batch size for evaluation on TPU. Only applicable eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
if run_config is a tf.contrib.tpu.RunConfig. Defaults to if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size. hparams.batch_size.
Returns: Returns:
An Estimator object if run_config is None or a tf.estimator.RunConfig, or a An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
......
...@@ -28,7 +28,7 @@ def get_feature(ex, name, kind=None, strict=True): ...@@ -28,7 +28,7 @@ def get_feature(ex, name, kind=None, strict=True):
ex: A tf.train.Example. ex: A tf.train.Example.
name: Name of the feature to look up. name: Name of the feature to look up.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified. not specified.
strict: Whether to raise a KeyError if there is no such feature. strict: Whether to raise a KeyError if there is no such feature.
Returns: Returns:
...@@ -93,7 +93,7 @@ def set_feature(ex, ...@@ -93,7 +93,7 @@ def set_feature(ex,
name: Name of the feature to set. name: Name of the feature to set.
value: Feature value to set. Must be a sequence. value: Feature value to set. Must be a sequence.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified. not specified.
allow_overwrite: Whether to overwrite the existing value of the feature. allow_overwrite: Whether to overwrite the existing value of the feature.
bytes_encoding: Codec for encoding strings when kind = 'bytes_list'. bytes_encoding: Codec for encoding strings when kind = 'bytes_list'.
......
...@@ -44,5 +44,5 @@ class MedianFilterTest(absltest.TestCase): ...@@ -44,5 +44,5 @@ class MedianFilterTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected) np.testing.assert_almost_equal(result, expected)
if __name__ == '__main__': if __name__ == "__main__":
absltest.main() absltest.main()
...@@ -66,5 +66,5 @@ class PhaseFoldAndSortLightCurveTest(absltest.TestCase): ...@@ -66,5 +66,5 @@ class PhaseFoldAndSortLightCurveTest(absltest.TestCase):
np.testing.assert_almost_equal(folded_flux, expected_flux) np.testing.assert_almost_equal(folded_flux, expected_flux)
if __name__ == '__main__': if __name__ == "__main__":
absltest.main() absltest.main()
...@@ -76,5 +76,5 @@ class ViewGeneratorTest(absltest.TestCase): ...@@ -76,5 +76,5 @@ class ViewGeneratorTest(absltest.TestCase):
np.testing.assert_almost_equal(result, expected) np.testing.assert_almost_equal(result, expected)
if __name__ == '__main__': if __name__ == "__main__":
absltest.main() absltest.main()
...@@ -105,15 +105,15 @@ def kepler_filenames(base_dir, ...@@ -105,15 +105,15 @@ def kepler_filenames(base_dir,
Args: Args:
base_dir: Base directory containing Kepler data. base_dir: Base directory containing Kepler data.
kep_id: Id of the Kepler target star. May be an int or a possibly zero- kep_id: Id of the Kepler target star. May be an int or a possibly zero-
padded string. padded string.
long_cadence: Whether to read a long cadence (~29.4 min / measurement) light long_cadence: Whether to read a long cadence (~29.4 min / measurement) light
curve as opposed to a short cadence (~1 min / measurement) light curve. curve as opposed to a short cadence (~1 min / measurement) light curve.
quarters: Optional list of integers in [0, 17]; the quarters of the Kepler quarters: Optional list of integers in [0, 17]; the quarters of the Kepler
mission to return. mission to return.
injected_group: Optional string indicating injected light curves. One of injected_group: Optional string indicating injected light curves. One of
"inj1", "inj2", "inj3". "inj1", "inj2", "inj3".
check_existence: If True, only return filenames corresponding to files that check_existence: If True, only return filenames corresponding to files that
exist (not all stars have data for all quarters). exist (not all stars have data for all quarters).
Returns: Returns:
A list of filenames. A list of filenames.
......
...@@ -32,16 +32,16 @@ def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None): ...@@ -32,16 +32,16 @@ def median_filter(x, y, num_bins, bin_width=None, x_min=None, x_max=None):
Args: Args:
x: 1D array of x-coordinates sorted in ascending order. Must have at least 2 x: 1D array of x-coordinates sorted in ascending order. Must have at least 2
elements, and all elements cannot be the same value. elements, and all elements cannot be the same value.
y: 1D array of y-coordinates with the same size as x. y: 1D array of y-coordinates with the same size as x.
num_bins: The number of intervals to divide the x-axis into. Must be at num_bins: The number of intervals to divide the x-axis into. Must be at
least 2. least 2.
bin_width: The width of each bin on the x-axis. Must be positive, and less bin_width: The width of each bin on the x-axis. Must be positive, and less
than x_max - x_min. Defaults to (x_max - x_min) / num_bins. than x_max - x_min. Defaults to (x_max - x_min) / num_bins.
x_min: The inclusive leftmost value to consider on the x-axis. Must be less x_min: The inclusive leftmost value to consider on the x-axis. Must be less
than or equal to the largest value of x. Defaults to min(x). than or equal to the largest value of x. Defaults to min(x).
x_max: The exclusive rightmost value to consider on the x-axis. Must be x_max: The exclusive rightmost value to consider on the x-axis. Must be
greater than x_min. Defaults to max(x). greater than x_min. Defaults to max(x).
Returns: Returns:
1D NumPy array of size num_bins containing the median y-values of uniformly 1D NumPy array of size num_bins containing the median y-values of uniformly
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment