Commit b3247557 authored by Dheera Venkatraman's avatar Dheera Venkatraman
Browse files

add flag for saving images to summary; strings moved to common.py'

parents 75c931fd 2041d5ca
...@@ -10,13 +10,14 @@ ...@@ -10,13 +10,14 @@
/research/compression/ @nmjohn /research/compression/ @nmjohn
/research/deeplab/ @aquariusjay @yknzhu @gpapan /research/deeplab/ @aquariusjay @yknzhu @gpapan
/research/delf/ @andrefaraujo /research/delf/ @andrefaraujo
/research/differential_privacy/ @panyx0718 /research/differential_privacy/ @panyx0718 @mironov
/research/domain_adaptation/ @bousmalis @dmrd /research/domain_adaptation/ @bousmalis @dmrd
/research/gan/ @joel-shor /research/gan/ @joel-shor
/research/im2txt/ @cshallue /research/im2txt/ @cshallue
/research/inception/ @shlens @vincentvanhoucke /research/inception/ @shlens @vincentvanhoucke
/research/learned_optimizer/ @olganw @nirum /research/learned_optimizer/ @olganw @nirum
/research/learning_to_remember_rare_events/ @lukaszkaiser @ofirnachum /research/learning_to_remember_rare_events/ @lukaszkaiser @ofirnachum
/research/learning_unsupervised_learning/ @lukemetz @nirum
/research/lexnet_nc/ @vered1986 @waterson /research/lexnet_nc/ @vered1986 @waterson
/research/lfads/ @jazcollins @susillo /research/lfads/ @jazcollins @susillo
/research/lm_1b/ @oriolvinyals @panyx0718 /research/lm_1b/ @oriolvinyals @panyx0718
......
...@@ -24,7 +24,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order ...@@ -24,7 +24,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.mnist import dataset from official.mnist import dataset
from official.utils.arg_parsers import parsers from official.utils.arg_parsers import parsers
from official.utils.logging import hooks_helper from official.utils.logs import hooks_helper
LEARNING_RATE = 1e-4 LEARNING_RATE = 1e-4
......
...@@ -31,8 +31,8 @@ import tensorflow as tf # pylint: disable=g-bad-import-order ...@@ -31,8 +31,8 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import resnet_model from official.resnet import resnet_model
from official.utils.arg_parsers import parsers from official.utils.arg_parsers import parsers
from official.utils.export import export from official.utils.export import export
from official.utils.logging import hooks_helper from official.utils.logs import hooks_helper
from official.utils.logging import logger from official.utils.logs import logger
################################################################################ ################################################################################
......
...@@ -132,7 +132,7 @@ class BaseParser(argparse.ArgumentParser): ...@@ -132,7 +132,7 @@ class BaseParser(argparse.ArgumentParser):
"Example: --hooks LoggingTensorHook ExamplesPerSecondHook. " "Example: --hooks LoggingTensorHook ExamplesPerSecondHook. "
"Allowed hook names (case-insensitive): LoggingTensorHook, " "Allowed hook names (case-insensitive): LoggingTensorHook, "
"ProfilerHook, ExamplesPerSecondHook, LoggingMetricHook." "ProfilerHook, ExamplesPerSecondHook, LoggingMetricHook."
"See official.utils.logging.hooks_helper for details.", "See official.utils.logs.hooks_helper for details.",
metavar="<HK>" metavar="<HK>"
) )
......
...@@ -31,10 +31,10 @@ import uuid ...@@ -31,10 +31,10 @@ import uuid
from google.cloud import bigquery from google.cloud import bigquery
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.arg_parsers import parsers from official.utils.arg_parsers import parsers
from official.utils.logging import logger from official.utils.logs import logger
class BigQueryUploader(object): class BigQueryUploader(object):
......
...@@ -26,8 +26,8 @@ from __future__ import print_function ...@@ -26,8 +26,8 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logging import hooks from official.utils.logs import hooks
from official.utils.logging import metric_hook from official.utils.logs import metric_hook
_TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate', _TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate',
'cross_entropy', 'cross_entropy',
......
...@@ -23,7 +23,7 @@ import unittest ...@@ -23,7 +23,7 @@ import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logging import hooks_helper from official.utils.logs import hooks_helper
class BaseTest(unittest.TestCase): class BaseTest(unittest.TestCase):
......
...@@ -24,7 +24,7 @@ import time ...@@ -24,7 +24,7 @@ import time
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.training import monitored_session # pylint: disable=g-bad-import-order from tensorflow.python.training import monitored_session # pylint: disable=g-bad-import-order
from official.utils.logging import hooks from official.utils.logs import hooks
tf.logging.set_verbosity(tf.logging.ERROR) tf.logging.set_verbosity(tf.logging.ERROR)
......
...@@ -26,7 +26,7 @@ import unittest ...@@ -26,7 +26,7 @@ import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logging import logger from official.utils.logs import logger
class BenchmarkLoggerTest(tf.test.TestCase): class BenchmarkLoggerTest(tf.test.TestCase):
......
...@@ -18,9 +18,9 @@ from __future__ import absolute_import ...@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logging import logger from official.utils.logs import logger
class LoggingMetricHook(tf.train.LoggingTensorHook): class LoggingMetricHook(tf.train.LoggingTensorHook):
......
...@@ -24,15 +24,17 @@ import time ...@@ -24,15 +24,17 @@ import time
import tensorflow as tf import tensorflow as tf
from tensorflow.python.training import monitored_session from tensorflow.python.training import monitored_session
from official.utils.logging import metric_hook from official.utils.logs import metric_hook # pylint: disable=g-bad-import-order
class LoggingMetricHookTest(tf.test.TestCase): class LoggingMetricHookTest(tf.test.TestCase):
"""Tests for LoggingMetricHook."""
def setUp(self): def setUp(self):
super(LoggingMetricHookTest, self).setUp() super(LoggingMetricHookTest, self).setUp()
class MockMetricLogger(object): class MockMetricLogger(object):
def __init__(self): def __init__(self):
self.logged_metric = [] self.logged_metric = []
......
...@@ -53,12 +53,14 @@ py_test() { ...@@ -53,12 +53,14 @@ py_test() {
py2_test() { py2_test() {
local PY_BINARY=$(which python2) local PY_BINARY=$(which python2)
return $(py_test "${PY_BINARY}") py_test "$PY_BINARY"
return $?
} }
py3_test() { py3_test() {
local PY_BINARY=$(which python3) local PY_BINARY=$(which python3)
return $(py_test "${PY_BINARY}") py_test "$PY_BINARY"
return $?
} }
test_result=0 test_result=0
......
...@@ -25,7 +25,7 @@ import sys ...@@ -25,7 +25,7 @@ import sys
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.arg_parsers import parsers from official.utils.arg_parsers import parsers
from official.utils.logging import hooks_helper from official.utils.logs import hooks_helper
_CSV_COLUMNS = [ _CSV_COLUMNS = [
'age', 'workclass', 'fnlwgt', 'education', 'education_num', 'age', 'workclass', 'fnlwgt', 'education', 'education_num',
......
...@@ -28,14 +28,16 @@ installation](https://www.tensorflow.org/install). ...@@ -28,14 +28,16 @@ installation](https://www.tensorflow.org/install).
pre-trained Residual GRU network. pre-trained Residual GRU network.
- [deeplab](deeplab): deep labelling for semantic image segmentation. - [deeplab](deeplab): deep labelling for semantic image segmentation.
- [delf](delf): deep local features for image matching and retrieval. - [delf](delf): deep local features for image matching and retrieval.
- [differential_privacy](differential_privacy): privacy-preserving student - [differential_privacy](differential_privacy): differential privacy for training
models from multiple teachers. data.
- [domain_adaptation](domain_adaptation): domain separation networks. - [domain_adaptation](domain_adaptation): domain separation networks.
- [gan](gan): generative adversarial networks. - [gan](gan): generative adversarial networks.
- [im2txt](im2txt): image-to-text neural network for image captioning. - [im2txt](im2txt): image-to-text neural network for image captioning.
- [inception](inception): deep convolutional networks for computer vision. - [inception](inception): deep convolutional networks for computer vision.
- [learning_to_remember_rare_events](learning_to_remember_rare_events): a - [learning_to_remember_rare_events](learning_to_remember_rare_events): a
large-scale life-long memory module for use in deep learning. large-scale life-long memory module for use in deep learning.
- [learning_unsupervised_learning](learning_unsupervised_learning): a
meta-learned unsupervised learning update rule.
- [lexnet_nc](lexnet_nc): a distributed model for noun compound relationship - [lexnet_nc](lexnet_nc): a distributed model for noun compound relationship
classification. classification.
- [lfads](lfads): sequential variational autoencoder for analyzing - [lfads](lfads): sequential variational autoencoder for analyzing
......
...@@ -73,7 +73,7 @@ def define(): ...@@ -73,7 +73,7 @@ def define():
flags.DEFINE_string('optimizer', 'momentum', flags.DEFINE_string('optimizer', 'momentum',
'the optimizer to use') 'the optimizer to use')
flags.DEFINE_string('momentum', 0.9, flags.DEFINE_float('momentum', 0.9,
'momentum value for the momentum optimizer if used') 'momentum value for the momentum optimizer if used')
flags.DEFINE_bool('use_augment_input', True, flags.DEFINE_bool('use_augment_input', True,
......
...@@ -71,7 +71,7 @@ def main(_): ...@@ -71,7 +71,7 @@ def main(_):
return return
contents = '' contents = ''
with tf.gfile.FastGFile(FLAGS.input_codes, 'r') as code_file: with tf.gfile.FastGFile(FLAGS.input_codes, 'rb') as code_file:
contents = code_file.read() contents = code_file.read()
loaded_codes = np.load(io.BytesIO(contents)) loaded_codes = np.load(io.BytesIO(contents))
assert ['codes', 'shape'] not in loaded_codes.files assert ['codes', 'shape'] not in loaded_codes.files
......
...@@ -59,7 +59,7 @@ def main(_): ...@@ -59,7 +59,7 @@ def main(_):
print('\n--iteration must be between 0 and 15 inclusive.\n') print('\n--iteration must be between 0 and 15 inclusive.\n')
return return
with tf.gfile.FastGFile(FLAGS.input_image) as input_image: with tf.gfile.FastGFile(FLAGS.input_image, 'rb') as input_image:
input_image_str = input_image.read() input_image_str = input_image.read()
with tf.Graph().as_default() as graph: with tf.Graph().as_default() as graph:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment