"examples/pytorch/mock_sparse/gcn/README.md" did not exist on "e557ed89ff0aefeb2a5ffbc601b39ee88b16f1e3"
Unverified Commit 6571d16d authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #3544 from cshallue/master

Add AstroNet to tensorflow/models
parents 92083555 6c891bc3
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "configdict",
srcs = ["configdict.py"],
srcs_version = "PY2AND3",
deps = [
],
)
py_test(
name = "configdict_test",
size = "small",
srcs = ["configdict_test.py"],
srcs_version = "PY2AND3",
deps = [":configdict"],
)
py_library(
name = "config_util",
srcs = ["config_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "config_util_test",
size = "small",
srcs = ["config_util_test.py"],
srcs_version = "PY2AND3",
deps = [":config_util"],
)
py_library(
name = "estimator_util",
srcs = ["estimator_util.py"],
srcs_version = "PY2AND3",
deps = [
"//astronet/ops:dataset_ops",
"//astronet/ops:metrics",
"//astronet/ops:training",
],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os.path
import tensorflow as tf
def parse_json(json_string_or_file):
"""Parses values from a JSON string or JSON file.
This function is useful for command line flags containing configuration
overrides. Using this function, the flag can be passed either as a JSON string
(e.g. '{"learning_rate": 1.0}') or the path to a JSON configuration file.
Args:
json_string_or_file: A JSON serialized string OR the path to a JSON file.
Returns:
A dictionary; the parsed JSON.
Raises:
ValueError: If the JSON could not be parsed.
"""
# First, attempt to parse the string as a JSON dict.
try:
json_dict = json.loads(json_string_or_file)
except ValueError as literal_json_parsing_error:
try:
# Otherwise, try to use it as a path to a JSON file.
with tf.gfile.Open(json_string_or_file) as f:
json_dict = json.load(f)
except ValueError as json_file_parsing_error:
raise ValueError("Unable to parse the content of the json file %s. "
"Parsing error: %s." % (json_string_or_file,
json_file_parsing_error.message))
except tf.gfile.FileError:
message = ("Unable to parse the input parameter neither as literal "
"JSON nor as the name of a file that exists.\n"
"JSON parsing error: %s\n\n Input parameter:\n%s." %
(literal_json_parsing_error.message, json_string_or_file))
raise ValueError(message)
return json_dict
def log_and_save_config(config, output_dir):
"""Logs and writes a JSON-serializable configuration object.
Args:
config: A JSON-serializable object.
output_dir: Destination directory.
"""
if hasattr(config, "to_json") and callable(config.to_json):
config_json = config.to_json(indent=2)
else:
config_json = json.dumps(config, indent=2)
tf.logging.info("config: %s", config_json)
tf.gfile.MakeDirs(output_dir)
with tf.gfile.Open(os.path.join(output_dir, "config.json"), "w") as f:
f.write(config_json)
def unflatten(flat_config):
"""Transforms a flat configuration dictionary into a nested dictionary.
Example:
{
"a": 1,
"b.c": 2,
"b.d.e": 3,
"b.d.f": 4,
}
would be transformed to:
{
"a": 1,
"b": {
"c": 2,
"d": {
"e": 3,
"f": 4,
}
}
}
Args:
flat_config: A dictionary with strings as keys where nested configuration
parameters are represented with period-separated names.
Returns:
A dictionary nested according to the keys of the input dictionary.
"""
config = {}
for path, value in flat_config.iteritems():
path = path.split(".")
final_key = path.pop()
nested_config = config
for key in path:
nested_config = nested_config.setdefault(key, {})
nested_config[final_key] = value
return config
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for config_util.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from astronet.util import config_util
class ConfigUtilTest(tf.test.TestCase):
def testUnflatten(self):
# Empty dict.
self.assertDictEqual(config_util.unflatten({}), {})
# Already flat dict.
self.assertDictEqual(
config_util.unflatten({
"a": 1,
"b": 2
}), {
"a": 1,
"b": 2
})
# Nested dict.
self.assertDictEqual(
config_util.unflatten({
"a": 1,
"b.c": 2,
"b.d.e": 3,
"b.d.f": 4,
}), {
"a": 1,
"b": {
"c": 2,
"d": {
"e": 3,
"f": 4,
}
}
})
if __name__ == "__main__":
tf.test.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration container for TensorFlow models.
A ConfigDict is simply a dict whose values can be accessed via both dot syntax
(config.key) and dict syntax (config['key']).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def _maybe_convert_dict(value):
if isinstance(value, dict):
return ConfigDict(value)
return value
class ConfigDict(dict):
"""Configuration container class."""
def __init__(self, initial_dictionary=None):
"""Creates an instance of ConfigDict.
Args:
initial_dictionary: Optional dictionary or ConfigDict containing initial
parameters.
"""
if initial_dictionary:
for field, value in initial_dictionary.iteritems():
initial_dictionary[field] = _maybe_convert_dict(value)
super(ConfigDict, self).__init__(initial_dictionary)
def __setattr__(self, attribute, value):
self[attribute] = _maybe_convert_dict(value)
def __getattr__(self, attribute):
try:
return self[attribute]
except KeyError as e:
raise AttributeError(e)
def __delattr__(self, attribute):
try:
del self[attribute]
except KeyError as e:
raise AttributeError(e)
def __setitem__(self, key, value):
super(ConfigDict, self).__setitem__(key, _maybe_convert_dict(value))
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for config_util.configdict."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
from astronet.util import configdict
class ConfigDictTest(absltest.TestCase):
def setUp(self):
super(ConfigDictTest, self).setUp()
self._config = configdict.ConfigDict({
"int": 1,
"float": 2.0,
"bool": True,
"str": "hello",
"nested": {
"int": 3,
},
"double_nested": {
"a": {
"int": 3,
},
"b": {
"float": 4.0,
}
}
})
def testAccess(self):
# Simple types.
self.assertEqual(1, self._config.int)
self.assertEqual(1, self._config["int"])
self.assertEqual(2.0, self._config.float)
self.assertEqual(2.0, self._config["float"])
self.assertTrue(self._config.bool)
self.assertTrue(self._config["bool"])
self.assertEqual("hello", self._config.str)
self.assertEqual("hello", self._config["str"])
# Single nested config.
self.assertEqual(3, self._config.nested.int)
self.assertEqual(3, self._config["nested"].int)
self.assertEqual(3, self._config.nested["int"])
self.assertEqual(3, self._config["nested"]["int"])
# Double nested config.
self.assertEqual(3, self._config["double_nested"].a.int)
self.assertEqual(3, self._config["double_nested"]["a"].int)
self.assertEqual(3, self._config["double_nested"].a["int"])
self.assertEqual(3, self._config["double_nested"]["a"]["int"])
self.assertEqual(4.0, self._config.double_nested.b.float)
self.assertEqual(4.0, self._config.double_nested["b"].float)
self.assertEqual(4.0, self._config.double_nested.b["float"])
self.assertEqual(4.0, self._config.double_nested["b"]["float"])
# Nonexistent parameters.
with self.assertRaises(AttributeError):
_ = self._config.nonexistent
with self.assertRaises(KeyError):
_ = self._config["nonexistent"]
def testSetAttribut(self):
# Overwrite existing simple type.
self._config.int = 40
self.assertEqual(40, self._config.int)
# Overwrite existing nested simple type.
self._config.nested.int = 40
self.assertEqual(40, self._config.nested.int)
# Overwrite existing nested config.
self._config.double_nested.a = {"float": 50.0}
self.assertIsInstance(self._config.double_nested.a, configdict.ConfigDict)
self.assertEqual(50.0, self._config.double_nested.a.float)
self.assertNotIn("int", self._config.double_nested.a)
# Set new simple type.
self._config.int_2 = 10
self.assertEqual(10, self._config.int_2)
# Set new nested simple type.
self._config.nested.int_2 = 20
self.assertEqual(20, self._config.nested.int_2)
# Set new nested config.
self._config.double_nested.c = {"int": 30}
self.assertIsInstance(self._config.double_nested.c, configdict.ConfigDict)
self.assertEqual(30, self._config.double_nested.c.int)
def testSetItem(self):
# Overwrite existing simple type.
self._config["int"] = 40
self.assertEqual(40, self._config.int)
# Overwrite existing nested simple type.
self._config["nested"].int = 40
self.assertEqual(40, self._config.nested.int)
self._config.nested["int"] = 50
self.assertEqual(50, self._config.nested.int)
# Overwrite existing nested config.
self._config.double_nested["a"] = {"float": 50.0}
self.assertIsInstance(self._config.double_nested.a, configdict.ConfigDict)
self.assertEqual(50.0, self._config.double_nested.a.float)
self.assertNotIn("int", self._config.double_nested.a)
# Set new simple type.
self._config["int_2"] = 10
self.assertEqual(10, self._config.int_2)
# Set new nested simple type.
self._config.nested["int_2"] = 20
self.assertEqual(20, self._config.nested.int_2)
self._config.nested["int_3"] = 30
self.assertEqual(30, self._config.nested.int_3)
# Set new nested config.
self._config.double_nested["c"] = {"int": 30}
self.assertIsInstance(self._config.double_nested.c, configdict.ConfigDict)
self.assertEqual(30, self._config.double_nested.c.int)
def testDelete(self):
# Simple types.
self.assertEqual(1, self._config.int)
del self._config.int
with self.assertRaises(AttributeError):
_ = self._config.int
with self.assertRaises(KeyError):
_ = self._config["int"]
self.assertEqual(2.0, self._config["float"])
del self._config["float"]
with self.assertRaises(AttributeError):
_ = self._config.float
with self.assertRaises(KeyError):
_ = self._config["float"]
# Nested config.
self.assertEqual(3, self._config.nested.int)
del self._config.nested
with self.assertRaises(AttributeError):
_ = self._config.nested
with self.assertRaises(KeyError):
_ = self._config["nested"]
if __name__ == "__main__":
absltest.main()
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for training models with the TensorFlow Estimator API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
from astronet.ops import dataset_ops
from astronet.ops import metrics
from astronet.ops import training
def create_input_fn(file_pattern,
input_config,
mode,
shuffle_values_buffer=0,
repeat=1):
"""Creates an input_fn that reads a dataset from sharded TFRecord files.
Args:
file_pattern: File pattern matching input TFRecord files, e.g.
"/tmp/train-?????-of-00100". May also be a comma-separated list of file
patterns.
input_config: ConfigDict containing feature and label specifications.
mode: A tf.estimator.ModeKeys.
shuffle_values_buffer: If > 0, shuffle examples using a buffer of this size.
repeat: The number of times to repeat the dataset. If None or -1 the
elements will be repeated indefinitely.
Returns:
A callable that builds an input pipeline and returns (features, labels).
"""
include_labels = (
mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL])
reverse_time_series_prob = 0.5 if mode == tf.estimator.ModeKeys.TRAIN else 0
shuffle_filenames = (mode == tf.estimator.ModeKeys.TRAIN)
def input_fn(config, params):
"""Builds an input pipeline that reads a dataset from TFRecord files."""
# Infer whether this input_fn was called by Estimator or TPUEstimator using
# the config type.
use_tpu = isinstance(config, tf.contrib.tpu.RunConfig)
dataset = dataset_ops.build_dataset(
file_pattern=file_pattern,
input_config=input_config,
batch_size=params["batch_size"],
include_labels=include_labels,
reverse_time_series_prob=reverse_time_series_prob,
shuffle_filenames=shuffle_filenames,
shuffle_values_buffer=shuffle_values_buffer,
repeat=repeat,
use_tpu=use_tpu)
# We must use an initializable iterator, rather than a one-shot iterator,
# because the input pipeline contains a stateful table that requires
# initialization. We add the initializer to the TABLE_INITIALIZERS
# collection to ensure it is run during initialization.
iterator = dataset.make_initializable_iterator()
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
inputs = iterator.get_next()
return inputs, inputs.pop("labels", None)
return input_fn
def create_model_fn(model_class, hparams, use_tpu=False):
"""Wraps model_class as an Estimator or TPUEstimator model_fn.
Args:
model_class: AstroModel or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
use_tpu: If True, a TPUEstimator model_fn is returned. Otherwise an
Estimator model_fn is returned.
Returns:
model_fn: A callable that constructs the model and returns a
TPUEstimatorSpec if use_tpu is True, otherwise an EstimatorSpec.
"""
hparams = copy.deepcopy(hparams)
def model_fn(features, labels, mode, params):
"""Builds the model and returns an EstimatorSpec or TPUEstimatorSpec."""
# For TPUEstimator, params contains the batch size per TPU core.
if "batch_size" in params:
hparams.batch_size = params["batch_size"]
model = model_class(features, labels, hparams, mode)
model.build()
# Possibly create train_op.
train_op = None
if mode == tf.estimator.ModeKeys.TRAIN:
learning_rate = training.create_learning_rate(hparams, model.global_step)
optimizer = training.create_optimizer(hparams, learning_rate, use_tpu)
train_op = training.create_train_op(model, optimizer)
# Possibly create evaluation metrics.
eval_metrics = None
if mode == tf.estimator.ModeKeys.EVAL:
eval_metrics = (
metrics.create_metric_fn(model)
if use_tpu else metrics.create_metrics(model))
if use_tpu:
estimator = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode,
predictions=model.predictions,
loss=model.total_loss,
train_op=train_op,
eval_metrics=eval_metrics)
else:
estimator = tf.estimator.EstimatorSpec(
mode=mode,
predictions=model.predictions,
loss=model.total_loss,
train_op=train_op,
eval_metric_ops=eval_metrics)
return estimator
return model_fn
def create_estimator(model_class,
hparams,
run_config=None,
model_dir=None,
eval_batch_size=None):
"""Wraps model_class as an Estimator or TPUEstimator.
If run_config is None or a tf.estimator.RunConfig, an Estimator is returned.
If run_config is a tf.contrib.tpu.RunConfig, a TPUEstimator is returned.
Args:
model_class: AstroModel or a subclass.
hparams: ConfigDict of configuration parameters for building the model.
run_config: Optional tf.estimator.RunConfig or tf.contrib.tpu.RunConfig.
model_dir: Optional directory for saving the model. If not passed
explicitly, it must be specified in run_config.
eval_batch_size: Optional batch size for evaluation on TPU. Only applicable
if run_config is a tf.contrib.tpu.RunConfig. Defaults to
hparams.batch_size.
Returns:
An Estimator object if run_config is None or a tf.estimator.RunConfig, or a
TPUEstimator object if run_config is a tf.contrib.tpu.RunConfig.
Raises:
ValueError:
If model_dir is not passed explicitly or in run_config.model_dir, or if
eval_batch_size is specified and run_config is not a
tf.contrib.tpu.RunConfig.
"""
if run_config is None:
run_config = tf.estimator.RunConfig()
else:
run_config = copy.deepcopy(run_config)
if not model_dir and not run_config.model_dir:
raise ValueError(
"model_dir must be passed explicitly or specified in run_config")
use_tpu = isinstance(run_config, tf.contrib.tpu.RunConfig)
model_fn = create_model_fn(model_class, hparams, use_tpu)
if use_tpu:
eval_batch_size = eval_batch_size or hparams.batch_size
estimator = tf.contrib.tpu.TPUEstimator(
model_fn=model_fn,
model_dir=model_dir,
config=run_config,
train_batch_size=hparams.batch_size,
eval_batch_size=eval_batch_size)
else:
if eval_batch_size is not None:
raise ValueError("eval_batch_size can only be specified for TPU.")
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=model_dir,
config=run_config,
params={"batch_size": hparams.batch_size})
return estimator
def evaluate(estimator, input_fn, eval_steps=None, eval_name="val"):
"""Runs evaluation on the latest model checkpoint.
Args:
estimator: Instance of tf.Estimator.
input_fn: Input function returning a tuple (features, labels).
eval_steps: The number of steps for which to evaluate the model. If None,
evaluates until input_fn raises an end-of-input exception.
eval_name: Name of the evaluation set, e.g. "train" or "val".
Returns:
A dict of metric values from the evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
values = {} # Default return value if evaluation fails.
latest_checkpoint = tf.train.latest_checkpoint(estimator.model_dir)
if not latest_checkpoint:
# This is expected if the training job has not yet saved a checkpoint.
return values
tf.logging.info("Starting evaluation on checkpoint %s", latest_checkpoint)
try:
values = estimator.evaluate(input_fn, steps=eval_steps, name=eval_name)
except tf.errors.NotFoundError:
# Expected under some conditions, e.g. TPU worker does not finish
# initializing until long after the CPU job tells it to start evaluating
# and the checkpoint file is deleted already.
tf.logging.info("Checkpoint %s no longer exists, skipping evaluation",
latest_checkpoint)
return values
def continuous_eval(estimator,
input_fn,
train_steps=None,
eval_steps=None,
eval_name="val"):
"""Runs evaluation whenever there's a new checkpoint.
Args:
estimator: Instance of tf.Estimator.
input_fn: Input function returning a tuple (features, labels).
train_steps: The number of steps the model will train for. This function
will terminate once the model has finished training. If None, this
function will run forever.
eval_steps: The number of steps for which to evaluate the model. If None,
evaluates until input_fn raises an end-of-input exception.
eval_name: Name of the evaluation set, e.g. "train" or "val".
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
for _ in tf.contrib.training.checkpoints_iterator(estimator.model_dir):
values = evaluate(estimator, input_fn, eval_steps, eval_name)
yield values
global_step = values.get("global_step", 0)
if train_steps and global_step >= train_steps:
break
def continuous_train_and_eval(estimator,
train_input_fn,
eval_input_fn,
local_eval_frequency=None,
train_hooks=None,
train_steps=None,
eval_steps=None,
eval_name="val"):
"""Alternates training and evaluation.
Args:
estimator: Instance of tf.Estimator.
train_input_fn: Input function returning a tuple (features, labels).
eval_input_fn: Input function returning a tuple (features, labels).
local_eval_frequency: The number of training steps between evaluations. If
None, trains until train_input_fn raises an end-of-input exception.
train_hooks: List of SessionRunHook subclass instances. Used for callbacks
inside the training call.
train_steps: The total number of steps to train the model for.
eval_steps: The number of steps for which to evaluate the model. If None,
evaluates until eval_input_fn raises an end-of-input exception.
eval_name: Name of the evaluation set, e.g. "train" or "val".
Yields:
A dict of metric values from each evaluation. May be empty, e.g. if the
training job has not yet saved a checkpoint or the checkpoint is deleted by
the time the TPU worker initializes.
"""
while True:
# We run evaluation before training in this loop to prevent evaluation from
# being skipped if the process is interrupted.
values = evaluate(estimator, eval_input_fn, eval_steps, eval_name)
yield values
global_step = values.get("global_step", 0)
if train_steps and global_step >= train_steps:
break
# Decide how many steps before the next evaluation.
steps = local_eval_frequency
if train_steps:
remaining_steps = train_steps - global_step
steps = min(steps, remaining_steps) if steps else remaining_steps
tf.logging.info("Starting training at global step %d", global_step)
estimator.train(train_input_fn, hooks=train_hooks, steps=steps)
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "kepler_io",
srcs = ["kepler_io.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "kepler_io_test",
size = "small",
srcs = ["kepler_io_test.py"],
data = glob([
"test_data/0114/011442793/kplr*.fits",
]),
srcs_version = "PY2AND3",
deps = [":kepler_io"],
)
py_library(
name = "median_filter",
srcs = ["median_filter.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "median_filter_test",
size = "small",
srcs = ["median_filter_test.py"],
srcs_version = "PY2AND3",
deps = [":median_filter"],
)
py_library(
name = "periodic_event",
srcs = ["periodic_event.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "periodic_event_test",
size = "small",
srcs = ["periodic_event_test.py"],
srcs_version = "PY2AND3",
deps = [":periodic_event"],
)
py_library(
name = "util",
srcs = ["util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "util_test",
size = "small",
srcs = ["util_test.py"],
srcs_version = "PY2AND3",
deps = [
":periodic_event",
":util",
],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
cc_library(
name = "median",
hdrs = ["median.h"],
)
cc_test(
name = "median_test",
size = "small",
srcs = [
"median_test.cc",
],
deps = [
":median",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "median_filter",
srcs = ["median_filter.cc"],
hdrs = ["median_filter.h"],
deps = [
":median",
"@com_google_absl//absl/strings",
],
)
cc_test(
name = "median_filter_test",
size = "small",
srcs = [
"median_filter_test.cc",
],
deps = [
":median_filter",
":test_util",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "phase_fold",
srcs = ["phase_fold.cc"],
hdrs = ["phase_fold.h"],
deps = ["@com_google_absl//absl/strings"],
)
cc_test(
name = "phase_fold_test",
size = "small",
srcs = [
"phase_fold_test.cc",
],
deps = [
":phase_fold",
":test_util",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "normalize",
srcs = ["normalize.cc"],
hdrs = ["normalize.h"],
deps = [
":median",
"@com_google_absl//absl/strings",
],
)
cc_test(
name = "normalize_test",
size = "small",
srcs = [
"normalize_test.cc",
],
deps = [
":normalize",
":test_util",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "view_generator",
srcs = ["view_generator.cc"],
hdrs = ["view_generator.h"],
deps = [
":median_filter",
":normalize",
":phase_fold",
"@com_google_absl//absl/memory",
],
)
cc_test(
name = "view_generator_test",
size = "small",
srcs = [
"view_generator_test.cc",
],
deps = [
":test_util",
":view_generator",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "test_util",
hdrs = ["test_util.h"],
deps = [
"@com_google_googletest//:gtest",
],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment