"vscode:/vscode.git/clone" did not exist on "91c72cdf520dc9d2b4614771a6fdc98f29f250e2"
Commit d5fc3ef0 authored by pkulzc's avatar pkulzc
Browse files

Merge remote-tracking branch 'upstream/master'

parents 6b72b5cd 57b99319
......@@ -10,7 +10,7 @@
/research/compression/ @nmjohn
/research/deeplab/ @aquariusjay @yknzhu @gpapan
/research/delf/ @andrefaraujo
/research/differential_privacy/ @panyx0718
/research/differential_privacy/ @panyx0718 @mironov
/research/domain_adaptation/ @bousmalis @dmrd
/research/gan/ @joel-shor
/research/im2txt/ @cshallue
......
......@@ -24,7 +24,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.mnist import dataset
from official.utils.arg_parsers import parsers
from official.utils.logging import hooks_helper
from official.utils.logs import hooks_helper
LEARNING_RATE = 1e-4
......
......@@ -107,7 +107,7 @@ def main(argv):
(device, data_format) = ('/cpu:0', 'channels_last')
# If data_format is defined in FLAGS, overwrite automatically set value.
if flags.data_format is not None:
data_format = data_format
data_format = flags.data_format
print('Using device %s, and data format %s.' % (device, data_format))
# Load the datasets
......
......@@ -31,8 +31,8 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import resnet_model
from official.utils.arg_parsers import parsers
from official.utils.export import export
from official.utils.logging import hooks_helper
from official.utils.logging import logger
from official.utils.logs import hooks_helper
from official.utils.logs import logger
################################################################################
......
......@@ -132,7 +132,7 @@ class BaseParser(argparse.ArgumentParser):
"Example: --hooks LoggingTensorHook ExamplesPerSecondHook. "
"Allowed hook names (case-insensitive): LoggingTensorHook, "
"ProfilerHook, ExamplesPerSecondHook, LoggingMetricHook."
"See official.utils.logging.hooks_helper for details.",
"See official.utils.logs.hooks_helper for details.",
metavar="<HK>"
)
......
......@@ -34,7 +34,7 @@ from google.cloud import bigquery
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.arg_parsers import parsers
from official.utils.logging import logger
from official.utils.logs import logger
class BigQueryUploader(object):
......
......@@ -26,8 +26,8 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logging import hooks
from official.utils.logging import metric_hook
from official.utils.logs import hooks
from official.utils.logs import metric_hook
_TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate',
'cross_entropy',
......
......@@ -23,7 +23,7 @@ import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logging import hooks_helper
from official.utils.logs import hooks_helper
class BaseTest(unittest.TestCase):
......
......@@ -24,7 +24,7 @@ import time
import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.training import monitored_session # pylint: disable=g-bad-import-order
from official.utils.logging import hooks
from official.utils.logs import hooks
tf.logging.set_verbosity(tf.logging.ERROR)
......
......@@ -26,7 +26,7 @@ import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logging import logger
from official.utils.logs import logger
class BenchmarkLoggerTest(tf.test.TestCase):
......
......@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.logging import logger
from official.utils.logs import logger
class LoggingMetricHook(tf.train.LoggingTensorHook):
......
......@@ -24,15 +24,17 @@ import time
import tensorflow as tf
from tensorflow.python.training import monitored_session
from official.utils.logging import metric_hook
from official.utils.logs import metric_hook # pylint: disable=g-bad-import-order
class LoggingMetricHookTest(tf.test.TestCase):
"""Tests for LoggingMetricHook."""
def setUp(self):
super(LoggingMetricHookTest, self).setUp()
class MockMetricLogger(object):
def __init__(self):
self.logged_metric = []
......
......@@ -25,7 +25,7 @@ import sys
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.arg_parsers import parsers
from official.utils.logging import hooks_helper
from official.utils.logs import hooks_helper
_CSV_COLUMNS = [
'age', 'workclass', 'fnlwgt', 'education', 'education_num',
......
......@@ -28,8 +28,8 @@ installation](https://www.tensorflow.org/install).
pre-trained Residual GRU network.
- [deeplab](deeplab): deep labelling for semantic image segmentation.
- [delf](delf): deep local features for image matching and retrieval.
- [differential_privacy](differential_privacy): privacy-preserving student
models from multiple teachers.
- [differential_privacy](differential_privacy): differential privacy for training
data.
- [domain_adaptation](domain_adaptation): domain separation networks.
- [gan](gan): generative adversarial networks.
- [im2txt](im2txt): image-to-text neural network for image captioning.
......
......@@ -73,7 +73,7 @@ def define():
flags.DEFINE_string('optimizer', 'momentum',
'the optimizer to use')
flags.DEFINE_string('momentum', 0.9,
flags.DEFINE_float('momentum', 0.9,
'momentum value for the momentum optimizer if used')
flags.DEFINE_bool('use_augment_input', True,
......
......@@ -95,7 +95,6 @@ ORIGINAL_IMAGE = 'original_image'
# Test set name.
TEST_SET = 'test'
class ModelOptions(
collections.namedtuple('ModelOptions', [
'outputs_to_num_classes',
......
......@@ -14,6 +14,7 @@
# ==============================================================================
"""Tests for xception.py."""
import six
import numpy as np
import tensorflow as tf
......@@ -309,7 +310,7 @@ class XceptionNetworkTest(tf.test.TestCase):
'xception/middle_flow/block1': [2, 14, 14, 4],
'xception/exit_flow/block1': [2, 7, 7, 8],
'xception/exit_flow/block2': [2, 7, 7, 16]}
for endpoint, shape in endpoint_to_shape.iteritems():
for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testFullyConvolutionalEndpointShapes(self):
......@@ -330,7 +331,7 @@ class XceptionNetworkTest(tf.test.TestCase):
'xception/middle_flow/block1': [2, 21, 21, 4],
'xception/exit_flow/block1': [2, 11, 11, 8],
'xception/exit_flow/block2': [2, 11, 11, 16]}
for endpoint, shape in endpoint_to_shape.iteritems():
for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalEndpointShapes(self):
......@@ -352,7 +353,7 @@ class XceptionNetworkTest(tf.test.TestCase):
'xception/middle_flow/block1': [2, 41, 41, 4],
'xception/exit_flow/block1': [2, 41, 41, 8],
'xception/exit_flow/block2': [2, 41, 41, 16]}
for endpoint, shape in endpoint_to_shape.iteritems():
for endpoint, shape in six.iteritems(endpoint_to_shape):
self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
def testAtrousFullyConvolutionalValues(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment