Commit d5fc3ef0 authored by pkulzc

Merge remote-tracking branch 'upstream/master'

parents 6b72b5cd 57b99319
@@ -10,7 +10,7 @@
 /research/compression/ @nmjohn
 /research/deeplab/ @aquariusjay @yknzhu @gpapan
 /research/delf/ @andrefaraujo
-/research/differential_privacy/ @panyx0718
+/research/differential_privacy/ @panyx0718 @mironov
 /research/domain_adaptation/ @bousmalis @dmrd
 /research/gan/ @joel-shor
 /research/im2txt/ @cshallue
...
@@ -24,7 +24,7 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.mnist import dataset
 from official.utils.arg_parsers import parsers
-from official.utils.logging import hooks_helper
+from official.utils.logs import hooks_helper
 LEARNING_RATE = 1e-4
...
@@ -107,7 +107,7 @@ def main(argv):
     (device, data_format) = ('/cpu:0', 'channels_last')
   # If data_format is defined in FLAGS, overwrite automatically set value.
   if flags.data_format is not None:
-    data_format = data_format
+    data_format = flags.data_format
   print('Using device %s, and data format %s.' % (device, data_format))
   # Load the datasets
...
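The second half of this hunk fixes a silent no-op: `data_format = data_format` reassigned the variable to itself, so a user-supplied `--data_format` flag was ignored and the automatically chosen layout always won. A minimal sketch of the intended override pattern, using a hypothetical helper (the function name and `gpu_available` parameter are illustrative, not code from the repository):

```python
def pick_device_and_data_format(flag_data_format=None, gpu_available=False):
  """Choose a device and data layout, letting a flag override the default."""
  if gpu_available:
    device, data_format = '/gpu:0', 'channels_first'
  else:
    device, data_format = '/cpu:0', 'channels_last'
  # Before the fix this read `data_format = data_format`, a self-assignment
  # that left the automatic choice in place. Reading the flag applies it.
  if flag_data_format is not None:
    data_format = flag_data_format
  return device, data_format


print(pick_device_and_data_format(flag_data_format='channels_first'))
# -> ('/cpu:0', 'channels_first')
```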
@@ -31,8 +31,8 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.resnet import resnet_model
 from official.utils.arg_parsers import parsers
 from official.utils.export import export
-from official.utils.logging import hooks_helper
-from official.utils.logging import logger
+from official.utils.logs import hooks_helper
+from official.utils.logs import logger
 ################################################################################
...
...@@ -132,7 +132,7 @@ class BaseParser(argparse.ArgumentParser): ...@@ -132,7 +132,7 @@ class BaseParser(argparse.ArgumentParser):
"Example: --hooks LoggingTensorHook ExamplesPerSecondHook. " "Example: --hooks LoggingTensorHook ExamplesPerSecondHook. "
"Allowed hook names (case-insensitive): LoggingTensorHook, " "Allowed hook names (case-insensitive): LoggingTensorHook, "
"ProfilerHook, ExamplesPerSecondHook, LoggingMetricHook." "ProfilerHook, ExamplesPerSecondHook, LoggingMetricHook."
"See official.utils.logging.hooks_helper for details.", "See official.utils.logs.hooks_helper for details.",
metavar="<HK>" metavar="<HK>"
) )
......
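The help text above documents a `--hooks` argument on an argparse-based parser. Below is a hedged sketch of how such an argument is commonly declared; the `nargs="+"` setting, the short option, and the default list are assumptions for illustration, not the `official.utils.arg_parsers` implementation:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hooks", "-hk", nargs="+",
    default=["LoggingTensorHook"],  # assumed default, for illustration only
    help="A list of strings naming training hooks. "
         "Example: --hooks LoggingTensorHook ExamplesPerSecondHook. "
         "See official.utils.logs.hooks_helper for details.",
    metavar="<HK>")

args = parser.parse_args(["--hooks", "LoggingTensorHook", "ExamplesPerSecondHook"])
print(args.hooks)  # ['LoggingTensorHook', 'ExamplesPerSecondHook']
```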
...@@ -31,10 +31,10 @@ import uuid ...@@ -31,10 +31,10 @@ import uuid
from google.cloud import bigquery from google.cloud import bigquery
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.arg_parsers import parsers from official.utils.arg_parsers import parsers
from official.utils.logging import logger from official.utils.logs import logger
class BigQueryUploader(object): class BigQueryUploader(object):
......
@@ -26,8 +26,8 @@ from __future__ import print_function
 import tensorflow as tf  # pylint: disable=g-bad-import-order
-from official.utils.logging import hooks
-from official.utils.logging import metric_hook
+from official.utils.logs import hooks
+from official.utils.logs import metric_hook
 _TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate',
                                         'cross_entropy',
...
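`_TENSORS_TO_LOG` maps display names to graph tensor names. Helpers like this typically hand such a mapping to `tf.train.LoggingTensorHook`; the sketch below shows that usage under TF 1.x, with `every_n_iter=100` as an illustrative value rather than the repository's setting:

```python
import tensorflow as tf

# Display name -> graph tensor name, mirroring the mapping in the hunk above.
_TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate', 'cross_entropy'])

# The hook resolves tensor names when the session starts, so it can be
# constructed before the graph is built. every_n_iter is assumed here.
logging_hook = tf.train.LoggingTensorHook(tensors=_TENSORS_TO_LOG,
                                          every_n_iter=100)
```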
@@ -23,7 +23,7 @@ import unittest
 import tensorflow as tf  # pylint: disable=g-bad-import-order
-from official.utils.logging import hooks_helper
+from official.utils.logs import hooks_helper
 class BaseTest(unittest.TestCase):
...
@@ -24,7 +24,7 @@ import time
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 from tensorflow.python.training import monitored_session  # pylint: disable=g-bad-import-order
-from official.utils.logging import hooks
+from official.utils.logs import hooks
 tf.logging.set_verbosity(tf.logging.ERROR)
...
@@ -26,7 +26,7 @@ import unittest
 import tensorflow as tf  # pylint: disable=g-bad-import-order
-from official.utils.logging import logger
+from official.utils.logs import logger
 class BenchmarkLoggerTest(tf.test.TestCase):
...
@@ -18,9 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-import tensorflow as tf
-from official.utils.logging import logger
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+from official.utils.logs import logger
 class LoggingMetricHook(tf.train.LoggingTensorHook):
...
@@ -24,15 +24,17 @@ import time
 import tensorflow as tf
 from tensorflow.python.training import monitored_session
-from official.utils.logging import metric_hook
+from official.utils.logs import metric_hook  # pylint: disable=g-bad-import-order
 class LoggingMetricHookTest(tf.test.TestCase):
+  """Tests for LoggingMetricHook."""
   def setUp(self):
     super(LoggingMetricHookTest, self).setUp()
 class MockMetricLogger(object):
   def __init__(self):
     self.logged_metric = []
...
@@ -25,7 +25,7 @@ import sys
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.utils.arg_parsers import parsers
-from official.utils.logging import hooks_helper
+from official.utils.logs import hooks_helper
 _CSV_COLUMNS = [
     'age', 'workclass', 'fnlwgt', 'education', 'education_num',
...
@@ -28,8 +28,8 @@ installation](https://www.tensorflow.org/install).
   pre-trained Residual GRU network.
 - [deeplab](deeplab): deep labelling for semantic image segmentation.
 - [delf](delf): deep local features for image matching and retrieval.
-- [differential_privacy](differential_privacy): privacy-preserving student
-  models from multiple teachers.
+- [differential_privacy](differential_privacy): differential privacy for training
+  data.
 - [domain_adaptation](domain_adaptation): domain separation networks.
 - [gan](gan): generative adversarial networks.
 - [im2txt](im2txt): image-to-text neural network for image captioning.
...
...@@ -73,7 +73,7 @@ def define(): ...@@ -73,7 +73,7 @@ def define():
flags.DEFINE_string('optimizer', 'momentum', flags.DEFINE_string('optimizer', 'momentum',
'the optimizer to use') 'the optimizer to use')
flags.DEFINE_string('momentum', 0.9, flags.DEFINE_float('momentum', 0.9,
'momentum value for the momentum optimizer if used') 'momentum value for the momentum optimizer if used')
flags.DEFINE_bool('use_augment_input', True, flags.DEFINE_bool('use_augment_input', True,
......
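The flag above was declared with `DEFINE_string` even though momentum is numeric; switching to `DEFINE_float` means `FLAGS.momentum` arrives as a Python float. A hedged sketch of how such flags feed an optimizer under TF 1.x `tf.app.flags`; the `build_optimizer` helper and its default learning rate are illustrative assumptions:

```python
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('optimizer', 'momentum', 'the optimizer to use')
flags.DEFINE_float('momentum', 0.9,
                   'momentum value for the momentum optimizer if used')
FLAGS = flags.FLAGS


def build_optimizer(learning_rate=0.01):
  """Return an optimizer; FLAGS.momentum is already a float, not a string."""
  if FLAGS.optimizer == 'momentum':
    return tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
  return tf.train.GradientDescentOptimizer(learning_rate)
```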
@@ -95,7 +95,6 @@ ORIGINAL_IMAGE = 'original_image'
 # Test set name.
 TEST_SET = 'test'
 class ModelOptions(
     collections.namedtuple('ModelOptions', [
         'outputs_to_num_classes',
...
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Tests for xception.py.""" """Tests for xception.py."""
import six
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -309,7 +310,7 @@ class XceptionNetworkTest(tf.test.TestCase): ...@@ -309,7 +310,7 @@ class XceptionNetworkTest(tf.test.TestCase):
'xception/middle_flow/block1': [2, 14, 14, 4], 'xception/middle_flow/block1': [2, 14, 14, 4],
'xception/exit_flow/block1': [2, 7, 7, 8], 'xception/exit_flow/block1': [2, 7, 7, 8],
'xception/exit_flow/block2': [2, 7, 7, 16]} 'xception/exit_flow/block2': [2, 7, 7, 16]}
for endpoint, shape in endpoint_to_shape.iteritems(): for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape) self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testFullyConvolutionalEndpointShapes(self): def testFullyConvolutionalEndpointShapes(self):
...@@ -330,7 +331,7 @@ class XceptionNetworkTest(tf.test.TestCase): ...@@ -330,7 +331,7 @@ class XceptionNetworkTest(tf.test.TestCase):
'xception/middle_flow/block1': [2, 21, 21, 4], 'xception/middle_flow/block1': [2, 21, 21, 4],
'xception/exit_flow/block1': [2, 11, 11, 8], 'xception/exit_flow/block1': [2, 11, 11, 8],
'xception/exit_flow/block2': [2, 11, 11, 16]} 'xception/exit_flow/block2': [2, 11, 11, 16]}
for endpoint, shape in endpoint_to_shape.iteritems(): for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape) self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalEndpointShapes(self): def testAtrousFullyConvolutionalEndpointShapes(self):
...@@ -352,7 +353,7 @@ class XceptionNetworkTest(tf.test.TestCase): ...@@ -352,7 +353,7 @@ class XceptionNetworkTest(tf.test.TestCase):
'xception/middle_flow/block1': [2, 41, 41, 4], 'xception/middle_flow/block1': [2, 41, 41, 4],
'xception/exit_flow/block1': [2, 41, 41, 8], 'xception/exit_flow/block1': [2, 41, 41, 8],
'xception/exit_flow/block2': [2, 41, 41, 16]} 'xception/exit_flow/block2': [2, 41, 41, 16]}
for endpoint, shape in endpoint_to_shape.iteritems(): for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape) self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalValues(self): def testAtrousFullyConvolutionalValues(self):
......
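The recurring change in this test file replaces `dict.iteritems()`, which only exists under Python 2, with `six.iteritems()`, which works under both Python 2 and 3 (hence the new `import six`). A minimal standalone illustration, reusing two endpoint shapes from the hunks above:

```python
import six

endpoint_to_shape = {
    'xception/exit_flow/block1': [2, 7, 7, 8],
    'xception/exit_flow/block2': [2, 7, 7, 16],
}

# dict.iteritems() raises AttributeError on Python 3; six.iteritems()
# dispatches to .iteritems() on py2 and .items() on py3.
for endpoint, shape in six.iteritems(endpoint_to_shape):
  print(endpoint, shape)
```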