"docs/source/en/api/models/autoencoderkl_cosmos.md" did not exist on "2dad462d9bf9890df09bfb088bf0a446c6074bec"
Unverified commit 7cfb6bbd authored by Karmel Allison, committed by GitHub

Glint everything (#3654)

* Glint everything

* Adding rcfile and pylinting

* Extra newline

* Few last lints
parent adfd5a3a
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
\ No newline at end of file
@@ -17,9 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import gzip
 import os
 import shutil
-import gzip

 import numpy as np
 from six.moves import urllib
@@ -36,7 +36,7 @@ def check_image_file_header(filename):
   """Validate that filename corresponds to images for the MNIST dataset."""
   with tf.gfile.Open(filename, 'rb') as f:
     magic = read32(f)
-    num_images = read32(f)
+    read32(f)  # num_images, unused
     rows = read32(f)
     cols = read32(f)
     if magic != 2051:
@@ -52,7 +52,7 @@ def check_labels_file_header(filename):
   """Validate that filename corresponds to labels for the MNIST dataset."""
   with tf.gfile.Open(filename, 'rb') as f:
     magic = read32(f)
-    num_items = read32(f)
+    read32(f)  # num_items, unused
     if magic != 2049:
       raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                      f.name))
@@ -77,6 +77,8 @@ def download(directory, filename):

 def dataset(directory, images_file, labels_file):
+  """Download and parse MNIST dataset."""
+
   images_file = download(directory, images_file)
   labels_file = download(directory, labels_file)
...
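For context, the header-check hunks above rely on a read32 helper defined earlier in dataset.py. A minimal sketch of the pattern, assuming the IDX header layout (the module's real helper may be implemented with numpy rather than struct, and the file path below is illustrative):

    import struct

    def read32(bytestream):
      """Read 4 bytes as a big-endian unsigned 32-bit integer."""
      return struct.unpack('>I', bytestream.read(4))[0]

    with open('train-images-idx3-ubyte', 'rb') as f:
      magic = read32(f)   # 2051 for MNIST image files, 2049 for label files
      read32(f)           # num_images; read to advance the stream, value unused
      rows = read32(f)    # 28 for MNIST
      cols = read32(f)    # 28 for MNIST
      if magic != 2051:
        raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, f.name))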
@@ -20,7 +20,7 @@ from __future__ import print_function
 import argparse
 import sys

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.mnist import dataset
 from official.utils.arg_parsers import parsers
@@ -28,6 +28,7 @@ from official.utils.logging import hooks_helper
 LEARNING_RATE = 1e-4

+
 class Model(tf.keras.Model):
   """Model to recognize digits in the MNIST dataset.
@@ -145,14 +146,19 @@ def model_fn(features, labels, mode, params):

 def validate_batch_size_for_multi_gpu(batch_size):
-  """For multi-gpu, batch-size must be a multiple of the number of
-  available GPUs.
+  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

   Note that this should eventually be handled by replicate_model_fn
   directly. Multi-GPU support is currently experimental, however,
   so doing the work here until that feature is in place.
+
+  Args:
+    batch_size: the number of examples processed in each training batch.
+
+  Raises:
+    ValueError: if no GPUs are found, or selected batch_size is invalid.
   """
-  from tensorflow.python.client import device_lib
+  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
   local_device_protos = device_lib.list_local_devices()
   num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
@@ -169,7 +175,7 @@ def validate_batch_size_for_multi_gpu(batch_size):
     raise ValueError(err)

-def main(unused_argv):
+def main(_):
   model_function = model_fn

   if FLAGS.multi_gpu:
@@ -195,6 +201,8 @@ def main(unused_argv):
   # Set up training and evaluation input functions.
   def train_input_fn():
+    """Prepare data for training."""
+
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.
@@ -215,7 +223,7 @@ def main(unused_argv):
       FLAGS.hooks, batch_size=FLAGS.batch_size)

   # Train and evaluate model.
-  for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+  for _ in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
     mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
     eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
     print('\nEvaluation results:\n\t%s\n' % eval_results)
@@ -231,6 +239,7 @@ def main(unused_argv):

 class MNISTArgParser(argparse.ArgumentParser):
   """Argument parser for running MNIST model."""
+
   def __init__(self):
     super(MNISTArgParser, self).__init__(parents=[
         parsers.BaseParser(),
...
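The elided body of validate_batch_size_for_multi_gpu follows directly from the new docstring: count local GPUs, then require divisibility. A sketch of that check (the first three body lines are visible in the hunk; the error messages are paraphrased, not verbatim):

    def validate_batch_size_for_multi_gpu(batch_size):
      """Raise ValueError when batch_size does not divide across local GPUs."""
      from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
      local_device_protos = device_lib.list_local_devices()
      num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
      if not num_gpus:
        raise ValueError('Multi-GPU mode was specified, but no GPUs were found.')
      if batch_size % num_gpus:
        raise ValueError('Batch size (%d) must be a multiple of the number of '
                         'available GPUs (%d).' % (batch_size, num_gpus))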
@@ -31,11 +31,11 @@ import os
 import sys
 import time

-import tensorflow as tf
-import tensorflow.contrib.eager as tfe
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow.contrib.eager as tfe  # pylint: disable=g-bad-import-order

+from official.mnist import dataset as mnist_dataset
 from official.mnist import mnist
-from official.mnist import dataset
 from official.utils.arg_parsers import parsers

 FLAGS = None
@@ -110,9 +110,9 @@ def main(_):
   print('Using device %s, and data format %s.' % (device, data_format))

   # Load the datasets
-  train_ds = dataset.train(FLAGS.data_dir).shuffle(60000).batch(
+  train_ds = mnist_dataset.train(FLAGS.data_dir).shuffle(60000).batch(
       FLAGS.batch_size)
-  test_ds = dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)
+  test_ds = mnist_dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)

   # Create the model and optimizer
   model = mnist.Model(data_format)
@@ -159,11 +159,12 @@ def main(_):

 class MNISTEagerArgParser(argparse.ArgumentParser):
-  """Argument parser for running MNIST model with eager trainng loop."""
+  """Argument parser for running MNIST model with eager training loop."""

   def __init__(self):
     super(MNISTEagerArgParser, self).__init__(parents=[
-        parsers.BaseParser(epochs_between_evals=False, multi_gpu=False,
-                           hooks=False),
+        parsers.BaseParser(
+            epochs_between_evals=False, multi_gpu=False, hooks=False),
         parsers.ImageModelParser()])

     self.add_argument(
...
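The dataset module is re-imported under the alias mnist_dataset so the bare name dataset cannot be shadowed by, or confused with, local tf.data.Dataset objects. Usage after this change (the data directory below is illustrative; the script supplies it via FLAGS.data_dir):

    from official.mnist import dataset as mnist_dataset

    train_ds = mnist_dataset.train('/tmp/mnist_data').shuffle(60000).batch(100)
    test_ds = mnist_dataset.test('/tmp/mnist_data').batch(100)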
@@ -17,8 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import tensorflow as tf
-import tensorflow.contrib.eager as tfe
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow.contrib.eager as tfe  # pylint: disable=g-bad-import-order

 from official.mnist import mnist
 from official.mnist import mnist_eager
@@ -60,6 +60,7 @@ def evaluate(defun=False):

 class MNISTTest(tf.test.TestCase):
+  """Run tests for MNIST eager loop."""

   def test_train(self):
     train(defun=False)
...
@@ -17,9 +17,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import tensorflow as tf
 import time

+import tensorflow as tf  # pylint: disable=g-bad-import-order
+
 from official.mnist import mnist

 BATCH_SIZE = 100
@@ -42,6 +43,7 @@ def make_estimator():

 class Tests(tf.test.TestCase):
+  """Run tests for MNIST model."""

   def test_mnist(self):
     classifier = make_estimator()
@@ -57,7 +59,7 @@ class Tests(tf.test.TestCase):
     input_fn = lambda: tf.random_uniform([3, 784])
     predictions_generator = classifier.predict(input_fn)

-    for i in range(3):
+    for _ in range(3):
       predictions = next(predictions_generator)
       self.assertEqual(predictions['probabilities'].shape, (10,))
       self.assertEqual(predictions['classes'].shape, ())
@@ -103,6 +105,7 @@ class Tests(tf.test.TestCase):

 class Benchmarks(tf.test.Benchmark):
+  """Simple speed benchmarking for MNIST."""

   def benchmark_train_step_time(self):
     classifier = make_estimator()
...
@@ -23,7 +23,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+
 from official.mnist import dataset
 from official.mnist import mnist
...
@@ -36,7 +36,7 @@ parser.add_argument(
     help='Directory to download data and extract the tarball')

-def main(unused_argv):
+def main(_):
   """Download and extract the tarball from Alex's website."""
   if not os.path.exists(FLAGS.data_dir):
     os.makedirs(FLAGS.data_dir)
...
@@ -21,7 +21,7 @@ from __future__ import print_function
 import os
 import sys

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import resnet_model
 from official.resnet import resnet_run_loop
@@ -127,19 +127,22 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
   num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']

-  return resnet_run_loop.process_record_dataset(dataset, is_training, batch_size,
-      _NUM_IMAGES['train'], parse_record, num_epochs, num_parallel_calls,
+  return resnet_run_loop.process_record_dataset(
+      dataset, is_training, batch_size, _NUM_IMAGES['train'],
+      parse_record, num_epochs, num_parallel_calls,
       examples_per_epoch=num_images, multi_gpu=multi_gpu)

 def get_synth_input_fn():
-  return resnet_run_loop.get_synth_input_fn(_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)
+  return resnet_run_loop.get_synth_input_fn(
+      _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)

 ###############################################################################
 # Running the model
 ###############################################################################
 class Cifar10Model(resnet_model.Model):
+  """Model class with appropriate defaults for CIFAR-10 data."""

   def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
                version=resnet_model.DEFAULT_VERSION):
@@ -153,6 +156,9 @@ class Cifar10Model(resnet_model.Model):
       enables users to extend the same model to their own datasets.
       version: Integer representing which version of the ResNet network to use.
         See README for details. Valid values: [1, 2]
+
+    Raises:
+      ValueError: if invalid resnet_size is chosen
     """
     if resnet_size % 6 != 2:
       raise ValueError('resnet_size must be 6n + 2:', resnet_size)
@@ -195,7 +201,7 @@ def cifar10_model_fn(features, labels, mode, params):
   # for the CIFAR-10 dataset, perhaps because the regularization prevents
   # overfitting on the small data set. We therefore include all vars when
   # regularizing and computing loss during training.
-  def loss_filter_fn(name):
+  def loss_filter_fn(_):
     return True

   return resnet_run_loop.resnet_model_fn(features, labels, mode, Cifar10Model,
...
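loss_filter_fn is consulted once per trainable-variable name when resnet_run_loop.resnet_model_fn builds the weight-decay term; CIFAR-10's filter accepts everything, so every variable is regularized. Roughly how the filter is applied (l2_weight_decay is a hypothetical helper name; this paraphrases the run-loop code, it is not the verbatim implementation):

    import tensorflow as tf

    def l2_weight_decay(cross_entropy, weight_decay, loss_filter_fn):
      """Add an L2 penalty over the variables selected by loss_filter_fn."""
      return cross_entropy + weight_decay * tf.add_n(
          [tf.nn.l2_loss(v) for v in tf.trainable_variables()
           if loss_filter_fn(v.name)])

    def loss_filter_fn(_):  # CIFAR-10: regularize every variable
      return True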
@@ -20,7 +20,7 @@ from __future__ import print_function
 from tempfile import mkstemp

 import numpy as np
-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import cifar10_main
 from official.utils.testing import integration
@@ -34,6 +34,8 @@ _NUM_CHANNELS = 3

 class BaseTest(tf.test.TestCase):
+  """Tests for the Cifar10 version of Resnet.
+  """

   def tearDown(self):
     super(BaseTest, self).tearDown()
@@ -52,7 +54,7 @@ class BaseTest(tf.test.TestCase):
     data_file.close()

     fake_dataset = tf.data.FixedLengthRecordDataset(
-        filename, cifar10_main._RECORD_BYTES)
+        filename, cifar10_main._RECORD_BYTES)  # pylint: disable=protected-access
     fake_dataset = fake_dataset.map(
         lambda val: cifar10_main.parse_record(val, False))
     image, label = fake_dataset.make_one_shot_iterator().get_next()
@@ -133,9 +135,11 @@ class BaseTest(tf.test.TestCase):
     num_classes = 246

     for version in (1, 2):
-      model = cifar10_main.Cifar10Model(32, data_format='channels_last',
-                                        num_classes=num_classes, version=version)
-      fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
+      model = cifar10_main.Cifar10Model(
+          32, data_format='channels_last', num_classes=num_classes,
+          version=version)
+      fake_input = tf.random_uniform(
+          [batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
       output = model(fake_input, training=True)
       self.assertAllEqual(output.shape, (batch_size, num_classes))
...
@@ -21,7 +21,7 @@ from __future__ import print_function
 import os
 import sys

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import imagenet_preprocessing
 from official.resnet import resnet_model
@@ -157,6 +157,7 @@ def parse_record(raw_record, is_training):
 def input_fn(is_training, data_dir, batch_size, num_epochs=1,
              num_parallel_calls=1, multi_gpu=False):
   """Input function which provides batches for train or eval.
+
   Args:
     is_training: A boolean denoting whether the input is for training.
     data_dir: The directory containing the input data.
@@ -199,6 +200,7 @@ def get_synth_input_fn():
 # Running the model
 ###############################################################################
 class ImagenetModel(resnet_model.Model):
+  """Model class with appropriate defaults for Imagenet data."""

   def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
                version=resnet_model.DEFAULT_VERSION):
@@ -241,9 +243,20 @@ class ImagenetModel(resnet_model.Model):

 def _get_block_sizes(resnet_size):
-  """The number of block layers used for the Resnet model varies according
+  """Retrieve the size of each block_layer in the ResNet model.
+
+  The number of block layers used for the Resnet model varies according
   to the size of the model. This helper grabs the layer set we want, throwing
   an error if a non-standard size has been selected.
+
+  Args:
+    resnet_size: The number of convolutional layers needed in the model.
+
+  Returns:
+    A list of block sizes to use in building the model.
+
+  Raises:
+    KeyError: if invalid resnet_size is received.
   """
   choices = {
       18: [2, 2, 2, 2],
...
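The new docstring documents a plain dict lookup: a KeyError simply propagates for non-standard sizes. Only the 18-layer entry is visible in the hunk; a sketch of the full table, assuming the remaining entries follow the standard published ResNet configurations:

    def _get_block_sizes(resnet_size):
      """Retrieve the size of each block_layer in the ResNet model."""
      choices = {
          18: [2, 2, 2, 2],
          34: [3, 4, 6, 3],
          50: [3, 4, 6, 3],
          101: [3, 4, 23, 3],
          152: [3, 8, 36, 3],
          200: [3, 24, 36, 3],
      }
      return choices[resnet_size]  # raises KeyError for non-standard sizes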
@@ -204,8 +204,10 @@ def _aspect_preserving_resize(image, resize_min):

 def _resize_image(image, height, width):
-  """Simple wrapper around tf.resize_images to make sure we use the same
-  `method` and other details each time.
+  """Simple wrapper around tf.resize_images.
+
+  This is primarily to make sure we use the same `ResizeMethod` and other
+  details each time.

   Args:
     image: A 3-D image `Tensor`.
@@ -220,6 +222,7 @@ def _resize_image(image, height, width):
       image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
       align_corners=False)

+
 def preprocess_image(image_buffer, bbox, output_height, output_width,
                      num_channels, is_training=False):
   """Preprocesses the given image.
...
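Pinning the resize settings inside one wrapper keeps training and eval preprocessing consistent. A sketch of the wrapper (whose call is visible in the hunk) plus an aspect-preserving caller, assuming TF 1.x APIs; the module's exact _aspect_preserving_resize math may differ:

    import tensorflow as tf

    def _resize_image(image, height, width):
      """Resize with one fixed ResizeMethod so every call site agrees."""
      return tf.image.resize_images(
          image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
          align_corners=False)

    def _aspect_preserving_resize(image, resize_min):
      """Scale so the smaller edge equals resize_min, keeping the aspect ratio."""
      shape = tf.shape(image)
      height, width = tf.to_float(shape[0]), tf.to_float(shape[1])
      scale = tf.to_float(resize_min) / tf.minimum(height, width)
      new_height = tf.to_int32(height * scale)
      new_width = tf.to_int32(width * scale)
      return _resize_image(image, new_height, new_width)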
@@ -19,7 +19,7 @@ from __future__ import print_function
 import unittest

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import imagenet_main
 from official.utils.testing import integration
@@ -39,9 +39,8 @@ class BaseTest(tf.test.TestCase):
   def tensor_shapes_helper(self, resnet_size, version, with_gpu=False):
     """Checks the tensor shapes after each phase of the ResNet model."""
     def reshape(shape):
-      """Returns the expected dimensions depending on if a
-      GPU is being used.
-      """
+      """Returns the expected dimensions depending on if a GPU is being used."""
       # If a GPU is used for the test, the shape is returned (already in NCHW
       # form). When GPU is not used, the shape is converted to NHWC.
       if with_gpu:
@@ -240,8 +239,9 @@ class BaseTest(tf.test.TestCase):
     num_classes = 246

     for version in (1, 2):
-      model = imagenet_main.ImagenetModel(50, data_format='channels_last',
-                                          num_classes=num_classes, version=version)
+      model = imagenet_main.ImagenetModel(
+          50, data_format='channels_last', num_classes=num_classes,
+          version=version)
       fake_input = tf.random_uniform([batch_size, 224, 224, 3])
       output = model(fake_input, training=True)
@@ -285,4 +285,3 @@ class BaseTest(tf.test.TestCase):

 if __name__ == '__main__':
   tf.test.main()
-
@@ -99,7 +99,8 @@ def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
 ################################################################################
 def _building_block_v1(inputs, filters, training, projection_shortcut, strides,
                        data_format):
-  """
+  """A single block for ResNet v1, without a bottleneck.
+
   Convolution then batch normalization then ReLU as described by:
     Deep Residual Learning for Image Recognition
     https://arxiv.org/pdf/1512.03385.pdf
@@ -118,7 +119,7 @@ def _building_block_v1(inputs, filters, training, projection_shortcut, strides,
     data_format: The input format ('channels_last' or 'channels_first').

   Returns:
-    The output tensor of the block.
+    The output tensor of the block; shape should match inputs.
   """
   shortcut = inputs
@@ -145,7 +146,8 @@ def _building_block_v1(inputs, filters, training, projection_shortcut, strides,
 def _building_block_v2(inputs, filters, training, projection_shortcut, strides,
                        data_format):
-  """
+  """A single block for ResNet v2, without a bottleneck.
+
   Batch normalization then ReLu then convolution as described by:
     Identity Mappings in Deep Residual Networks
     https://arxiv.org/pdf/1603.05027.pdf
@@ -164,7 +166,7 @@ def _building_block_v2(inputs, filters, training, projection_shortcut, strides,
     data_format: The input format ('channels_last' or 'channels_first').

   Returns:
-    The output tensor of the block.
+    The output tensor of the block; shape should match inputs.
   """
   shortcut = inputs
   inputs = batch_norm(inputs, training, data_format)
@@ -190,13 +192,29 @@ def _building_block_v2(inputs, filters, training, projection_shortcut, strides,
 def _bottleneck_block_v1(inputs, filters, training, projection_shortcut,
                          strides, data_format):
-  """
+  """A single block for ResNet v1, with a bottleneck.
+
   Similar to _building_block_v1(), except using the "bottleneck" blocks
   described in:
     Convolution then batch normalization then ReLU as described by:
       Deep Residual Learning for Image Recognition
       https://arxiv.org/pdf/1512.03385.pdf
       by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
+
+  Args:
+    inputs: A tensor of size [batch, channels, height_in, width_in] or
+      [batch, height_in, width_in, channels] depending on data_format.
+    filters: The number of filters for the convolutions.
+    training: A Boolean for whether the model is in training or inference
+      mode. Needed for batch normalization.
+    projection_shortcut: The function to use for projection shortcuts
+      (typically a 1x1 convolution when downsampling the input).
+    strides: The block's stride. If greater than 1, this block will ultimately
+      downsample the input.
+    data_format: The input format ('channels_last' or 'channels_first').
+
+  Returns:
+    The output tensor of the block; shape should match inputs.
   """
   shortcut = inputs
@@ -229,7 +247,8 @@ def _bottleneck_block_v1(inputs, filters, training, projection_shortcut,
 def _bottleneck_block_v2(inputs, filters, training, projection_shortcut,
                          strides, data_format):
-  """
+  """A single block for ResNet v2, with a bottleneck.
+
   Similar to _building_block_v2(), except using the "bottleneck" blocks
   described in:
     Convolution then batch normalization then ReLU as described by:
@@ -237,11 +256,26 @@ def _bottleneck_block_v2(inputs, filters, training, projection_shortcut,
       https://arxiv.org/pdf/1512.03385.pdf
       by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.

-  adapted to the ordering conventions of:
+  Adapted to the ordering conventions of:
     Batch normalization then ReLu then convolution as described by:
       Identity Mappings in Deep Residual Networks
       https://arxiv.org/pdf/1603.05027.pdf
       by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
+
+  Args:
+    inputs: A tensor of size [batch, channels, height_in, width_in] or
+      [batch, height_in, width_in, channels] depending on data_format.
+    filters: The number of filters for the convolutions.
+    training: A Boolean for whether the model is in training or inference
+      mode. Needed for batch normalization.
+    projection_shortcut: The function to use for projection shortcuts
+      (typically a 1x1 convolution when downsampling the input).
+    strides: The block's stride. If greater than 1, this block will ultimately
+      downsample the input.
+    data_format: The input format ('channels_last' or 'channels_first').
+
+  Returns:
+    The output tensor of the block; shape should match inputs.
   """
   shortcut = inputs
   inputs = batch_norm(inputs, training, data_format)
@@ -313,8 +347,7 @@ def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides,
 class Model(object):
-  """Base class for building the Resnet Model.
-  """
+  """Base class for building the Resnet Model."""

   def __init__(self, resnet_size, bottleneck, num_classes, num_filters,
                kernel_size,
@@ -348,6 +381,9 @@ class Model(object):
         See README for details. Valid values: [1, 2]
       data_format: Input format ('channels_last', 'channels_first', or None).
         If set to None, the format is dependent on whether a GPU is available.
+
+    Raises:
+      ValueError: if invalid version is selected.
     """
     self.resnet_size = resnet_size
@@ -358,7 +394,7 @@ class Model(object):
     self.resnet_version = version
     if version not in (1, 2):
       raise ValueError(
-          "Resnet version should be 1 or 2. See README for citations.")
+          'Resnet version should be 1 or 2. See README for citations.')

     self.bottleneck = bottleneck
     if bottleneck:
@@ -435,4 +471,3 @@ class Model(object):
     inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
     inputs = tf.identity(inputs, 'final_dense')
     return inputs
-
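The new one-line summaries and the "shape should match inputs" note describe the standard residual contract. As a concrete reference, a sketch of the v2 (pre-activation) building block consistent with these docstrings, reusing the module's batch_norm and conv2d_fixed_padding helpers (simplified, not the verbatim function body):

    import tensorflow as tf

    def _building_block_v2_sketch(inputs, filters, training, projection_shortcut,
                                  strides, data_format):
      """BN -> ReLU -> conv, twice, then add the shortcut back in."""
      shortcut = inputs
      inputs = batch_norm(inputs, training, data_format)
      inputs = tf.nn.relu(inputs)

      # The projection shortcut comes after the first BN+ReLU so it consumes
      # the same pre-activations as the main path.
      if projection_shortcut is not None:
        shortcut = projection_shortcut(inputs)

      inputs = conv2d_fixed_padding(inputs, filters, 3, strides, data_format)
      inputs = batch_norm(inputs, training, data_format)
      inputs = tf.nn.relu(inputs)
      inputs = conv2d_fixed_padding(inputs, filters, 3, 1, data_format)

      # Strided blocks downsample, so "shape should match inputs" holds only
      # when strides == 1 and no projection shortcut is applied.
      return inputs + shortcut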
@@ -26,11 +26,11 @@ from __future__ import print_function
 import argparse
 import os

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

+from official.utils.arg_parsers import parsers  # pylint: disable=g-bad-import-order
+from official.utils.logging import hooks_helper
 from official.resnet import resnet_model
-from official.utils.arg_parsers import parsers
-from official.utils.logging import hooks_helper

 ################################################################################
@@ -39,8 +39,7 @@ from official.resnet import resnet_model
 def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
                            parse_record_fn, num_epochs=1, num_parallel_calls=1,
                            examples_per_epoch=0, multi_gpu=False):
-  """Given a Dataset with raw records, parse each record into images and labels,
-  and return an iterator over the records.
+  """Given a Dataset with raw records, return an iterator over the records.

   Args:
     dataset: A Dataset representing raw records
@@ -121,7 +120,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
     An input_fn that can be used in place of a real one to return a dataset
     that can be used for iteration.
   """
-  def input_fn(is_training, data_dir, batch_size, *args):
+  def input_fn(is_training, data_dir, batch_size, *args):  # pylint: disable=unused-argument
     images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
     labels = tf.zeros((batch_size, num_classes), tf.int32)
     return tf.data.Dataset.from_tensors((images, labels)).repeat()
@@ -231,9 +230,9 @@ def resnet_model_fn(features, labels, mode, model_class,
   # If no loss_filter_fn is passed, assume we want the default behavior,
   # which is that batch_normalization variables are excluded from loss.
-  if not loss_filter_fn:
-    def loss_filter_fn(name):
-      return 'batch_normalization' not in name
+  def exclude_batch_norm(name):
+    return 'batch_normalization' not in name
+  loss_filter_fn = loss_filter_fn or exclude_batch_norm

   # Add weight decay to the loss.
   loss = cross_entropy + weight_decay * tf.add_n(
@@ -279,14 +278,19 @@ def resnet_model_fn(features, labels, mode, model_class,

 def validate_batch_size_for_multi_gpu(batch_size):
-  """For multi-gpu, batch-size must be a multiple of the number of
-  available GPUs.
+  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

   Note that this should eventually be handled by replicate_model_fn
   directly. Multi-GPU support is currently experimental, however,
   so doing the work here until that feature is in place.
+
+  Args:
+    batch_size: the number of examples processed in each training batch.
+
+  Raises:
+    ValueError: if no GPUs are found, or selected batch_size is invalid.
   """
-  from tensorflow.python.client import device_lib
+  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
   local_device_protos = device_lib.list_local_devices()
   num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
@@ -304,6 +308,8 @@ def validate_batch_size_for_multi_gpu(batch_size):

 def resnet_main(flags, model_function, input_function):
+  """Shared main loop for ResNet Models."""
+
   # Using the Winograd non-fused algorithms provides a small performance boost.
   os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
@@ -340,8 +346,8 @@ def resnet_main(flags, model_function, input_function):
     })

   for _ in range(flags.train_epochs // flags.epochs_between_evals):
-    train_hooks = hooks_helper.get_train_hooks(flags.hooks,
-                                               batch_size=flags.batch_size)
+    train_hooks = hooks_helper.get_train_hooks(
+        flags.hooks, batch_size=flags.batch_size)

     print('Starting a training cycle.')
@@ -384,7 +390,7 @@ class ResnetArgParser(argparse.ArgumentParser):
     self.add_argument(
         '--version', '-v', type=int, choices=[1, 2],
         default=resnet_model.DEFAULT_VERSION,
-        help="Version of ResNet. (1 or 2) See README.md for details."
+        help='Version of ResNet. (1 or 2) See README.md for details.'
     )

     self.add_argument(
...
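The loss_filter_fn hunk above replaces conditional redefinition of an argument name (which pylint flags as function-redefined) with the `x = x or default` idiom. A self-contained illustration (resolve_loss_filter is a hypothetical wrapper name, not part of the run loop):

    def resolve_loss_filter(loss_filter_fn=None):
      """Default to excluding batch-normalization variables from weight decay."""
      def exclude_batch_norm(name):
        return 'batch_normalization' not in name
      return loss_filter_fn or exclude_batch_norm

    filter_fn = resolve_loss_filter()
    assert filter_fn('conv2d/kernel')                   # regularized
    assert not filter_fn('batch_normalization/gamma')   # excluded by default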
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
\ No newline at end of file
@@ -46,11 +46,11 @@ Notes about add_argument():
 The metavar variable determines how the flag will appear in help text. If
 not specified, the convention is to use name.upper(). Thus rather than:

-  --application_specific_arg APPLICATION_SPECIFIC_ARG, -asa APPLICATION_SPECIFIC_ARG
+  --app_specific_arg APP_SPECIFIC_ARG, -asa APP_SPECIFIC_ARG

 if metavar="<ASA>" is set, the user sees:

-  --application_specific_arg <ASA>, -asa <ASA>
+  --app_specific_arg <ASA>, -asa <ASA>
 """
@@ -216,7 +216,7 @@ class ImageModelParser(argparse.ArgumentParser):
     self.add_argument(
         "--data_format", "-df",
         default=None,
-        choices=['channels_first', 'channels_last'],
+        choices=["channels_first", "channels_last"],
         help="A flag to override the data format used in the model. "
              "channels_first provides a performance boost on GPU but is not "
              "always compatible with CPU. If left unspecified, the data "
...
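The metavar behavior described in the docstring can be reproduced with the standard library alone; metavar only changes how the flag is rendered in help text, not how it is parsed:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--app_specific_arg", "-asa", metavar="<ASA>",
        help="An application-specific argument.")
    parser.print_help()
    # Without metavar:  --app_specific_arg APP_SPECIFIC_ARG, -asa APP_SPECIFIC_ARG
    # With metavar:     --app_specific_arg <ASA>, -asa <ASA>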
@@ -24,7 +24,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.utils.logging import hooks
@@ -40,7 +40,7 @@ def get_train_hooks(name_list, **kwargs):
     name_list: a list of strings to name desired hook classes. Allowed:
       LoggingTensorHook, ProfilerHook, ExamplesPerSecondHook, which are defined
       as keys in HOOKS
-    kwargs: a dictionary of arguments to the hooks.
+    **kwargs: a dictionary of arguments to the hooks.

   Returns:
     list of instantiated hooks, ready to be used in a classifier.train call.
@@ -71,7 +71,7 @@ def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs):  # pylint: disable=unused-argument
       steps taken on the current worker.
     tensors_to_log: List of tensor names or dictionary mapping labels to tensor
       names. If not set, log _TENSORS_TO_LOG by default.
-    kwargs: a dictionary of arguments to LoggingTensorHook.
+    **kwargs: a dictionary of arguments to LoggingTensorHook.

   Returns:
     Returns a LoggingTensorHook with a standard set of tensors that will be
@@ -90,7 +90,7 @@ def get_profiler_hook(save_steps=1000, **kwargs):  # pylint: disable=unused-argument
   Args:
     save_steps: `int`, print profile traces every N steps.
-    kwargs: a dictionary of arguments to ProfilerHook.
+    **kwargs: a dictionary of arguments to ProfilerHook.

   Returns:
     Returns a ProfilerHook that writes out timelines that can be loaded into
@@ -111,7 +111,7 @@ def get_examples_per_second_hook(every_n_steps=100,
     batch_size: `int`, total batch size used to calculate examples/second from
       global time.
     warm_steps: skip this number of steps before logging and running average.
-    kwargs: a dictionary of arguments to ExamplesPerSecondHook.
+    **kwargs: a dictionary of arguments to ExamplesPerSecondHook.

   Returns:
     Returns a ProfilerHook that writes out timelines that can be loaded into
@@ -128,4 +128,3 @@ HOOKS = {
     'profilerhook': get_profiler_hook,
     'examplespersecondhook': get_examples_per_second_hook,
 }
-
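The **kwargs fixes above follow Google docstring style for variadic keyword arguments. Usage as in the training scripts, with an illustrative batch size; hook-name lookup appears to be case-insensitive, since HOOKS is keyed by the lowercased names:

    from official.utils.logging import hooks_helper

    # Extra keyword arguments (here batch_size) are forwarded to each hook
    # factory in HOOKS; factories absorb any they do not use via **kwargs.
    train_hooks = hooks_helper.get_train_hooks(
        ['LoggingTensorHook', 'ExamplesPerSecondHook'], batch_size=128)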
@@ -21,7 +21,7 @@ from __future__ import print_function
 import unittest

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.utils.logging import hooks_helper
...
@@ -21,9 +21,9 @@ from __future__ import print_function
 import time

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+from tensorflow.python.training import monitored_session  # pylint: disable=g-bad-import-order

-from tensorflow.python.training import monitored_session
 from official.utils.logging import hooks
@@ -31,6 +31,7 @@ tf.logging.set_verbosity(tf.logging.ERROR)

 class ExamplesPerSecondHookTest(tf.test.TestCase):
+  """Tests for the ExamplesPerSecondHook."""

   def setUp(self):
     """Mock out logging calls to verify if correct info is being monitored."""
@@ -71,7 +72,7 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
         every_n_steps=every_n_steps,
         warm_steps=warm_steps)
     hook.begin()
-    mon_sess = monitored_session._HookedSession(sess, [hook])
+    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
     sess.run(tf.global_variables_initializer())
     self.logged_message = ''
@@ -120,7 +121,7 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
         every_n_steps=None,
         every_n_secs=every_n_secs)
     hook.begin()
-    mon_sess = monitored_session._HookedSession(sess, [hook])
+    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
     sess.run(tf.global_variables_initializer())
     self.logged_message = ''
...