Commit 3b158095 authored by Ilya Mironov's avatar Ilya Mironov
Browse files

Merge branch 'master' of https://github.com/ilyamironov/models

parents a90db800 be659c2f
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
/research/lexnet_nc/ @vered1986 @waterson /research/lexnet_nc/ @vered1986 @waterson
/research/lfads/ @jazcollins @susillo /research/lfads/ @jazcollins @susillo
/research/lm_1b/ @oriolvinyals @panyx0718 /research/lm_1b/ @oriolvinyals @panyx0718
/research/marco/ @vincentvanhoucke
/research/maskgan/ @a-dai /research/maskgan/ @a-dai
/research/namignizer/ @knathanieltucker /research/namignizer/ @knathanieltucker
/research/neural_gpu/ @lukaszkaiser /research/neural_gpu/ @lukaszkaiser
......
...@@ -98,41 +98,41 @@ ...@@ -98,41 +98,41 @@
"type": "RECORD" "type": "RECORD"
}, },
{ {
"description": "The list of hyperparameters of the model.", "description": "The list of parameters run with the model. It could contain hyperparameters or others.",
"fields": [ "fields": [
{ {
"description": "The name of the hyperparameter.", "description": "The name of the parameter.",
"mode": "REQUIRED", "mode": "REQUIRED",
"name": "name", "name": "name",
"type": "STRING" "type": "STRING"
}, },
{ {
"description": "The string value of the hyperparameter.", "description": "The string value of the parameter.",
"mode": "NULLABLE", "mode": "NULLABLE",
"name": "string_value", "name": "string_value",
"type": "STRING" "type": "STRING"
}, },
{ {
"description": "The bool value of the hyperparameter.", "description": "The bool value of the parameter.",
"mode": "NULLABLE", "mode": "NULLABLE",
"name": "bool_value", "name": "bool_value",
"type": "STRING" "type": "STRING"
}, },
{ {
"description": "The int/long value of the hyperparameter.", "description": "The int/long value of the parameter.",
"mode": "NULLABLE", "mode": "NULLABLE",
"name": "long_value", "name": "long_value",
"type": "INTEGER" "type": "INTEGER"
}, },
{ {
"description": "The double/float value of hyperparameter.", "description": "The double/float value of parameter.",
"mode": "NULLABLE", "mode": "NULLABLE",
"name": "float_value", "name": "float_value",
"type": "FLOAT" "type": "FLOAT"
} }
], ],
"mode": "REPEATED", "mode": "REPEATED",
"name": "hyperparameter", "name": "run_parameters",
"type": "RECORD" "type": "RECORD"
}, },
{ {
......
...@@ -17,16 +17,16 @@ from __future__ import absolute_import ...@@ -17,16 +17,16 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse from absl import app as absl_app
import sys from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.mnist import dataset from official.mnist import dataset
from official.utils.arg_parsers import parsers from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
LEARNING_RATE = 1e-4 LEARNING_RATE = 1e-4
...@@ -62,7 +62,9 @@ def create_model(data_format): ...@@ -62,7 +62,9 @@ def create_model(data_format):
# (a subclass of tf.keras.Model) makes for a compact description. # (a subclass of tf.keras.Model) makes for a compact description.
return tf.keras.Sequential( return tf.keras.Sequential(
[ [
l.Reshape(input_shape), l.Reshape(
target_shape=input_shape,
input_shape=(28 * 28,)),
l.Conv2D( l.Conv2D(
32, 32,
5, 5,
...@@ -84,6 +86,16 @@ def create_model(data_format): ...@@ -84,6 +86,16 @@ def create_model(data_format):
]) ])
def define_mnist_flags():
flags_core.define_base(multi_gpu=True, num_gpu=False)
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
flags_core.set_defaults(data_dir='/tmp/mnist_data',
model_dir='/tmp/mnist_model',
batch_size=100,
train_epochs=40)
def model_fn(features, labels, mode, params): def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator.""" """The model_fn argument for creating an Estimator."""
model = create_model(params['data_format']) model = create_model(params['data_format'])
...@@ -170,14 +182,17 @@ def validate_batch_size_for_multi_gpu(batch_size): ...@@ -170,14 +182,17 @@ def validate_batch_size_for_multi_gpu(batch_size):
raise ValueError(err) raise ValueError(err)
def main(argv): def run_mnist(flags_obj):
parser = MNISTArgParser() """Run MNIST training and eval loop.
flags = parser.parse_args(args=argv[1:])
Args:
flags_obj: An object containing parsed flag values.
"""
model_function = model_fn model_function = model_fn
if flags.multi_gpu: if flags_obj.multi_gpu:
validate_batch_size_for_multi_gpu(flags.batch_size) validate_batch_size_for_multi_gpu(flags_obj.batch_size)
# There are two steps required if using multi-GPU: (1) wrap the model_fn, # There are two steps required if using multi-GPU: (1) wrap the model_fn,
# and (2) wrap the optimizer. The first happens here, and (2) happens # and (2) wrap the optimizer. The first happens here, and (2) happens
...@@ -185,16 +200,16 @@ def main(argv): ...@@ -185,16 +200,16 @@ def main(argv):
model_function = tf.contrib.estimator.replicate_model_fn( model_function = tf.contrib.estimator.replicate_model_fn(
model_fn, loss_reduction=tf.losses.Reduction.MEAN) model_fn, loss_reduction=tf.losses.Reduction.MEAN)
data_format = flags.data_format data_format = flags_obj.data_format
if data_format is None: if data_format is None:
data_format = ('channels_first' data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last') if tf.test.is_built_with_cuda() else 'channels_last')
mnist_classifier = tf.estimator.Estimator( mnist_classifier = tf.estimator.Estimator(
model_fn=model_function, model_fn=model_function,
model_dir=flags.model_dir, model_dir=flags_obj.model_dir,
params={ params={
'data_format': data_format, 'data_format': data_format,
'multi_gpu': flags.multi_gpu 'multi_gpu': flags_obj.multi_gpu
}) })
# Set up training and evaluation input functions. # Set up training and evaluation input functions.
...@@ -204,57 +219,46 @@ def main(argv): ...@@ -204,57 +219,46 @@ def main(argv):
# When choosing shuffle buffer sizes, larger sizes result in better # When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small # randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch. # enough dataset that we can easily shuffle the full epoch.
ds = dataset.train(flags.data_dir) ds = dataset.train(flags_obj.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size) ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)
# Iterate through the dataset a set number (`epochs_between_evals`) of times # Iterate through the dataset a set number (`epochs_between_evals`) of times
# during each training session. # during each training session.
ds = ds.repeat(flags.epochs_between_evals) ds = ds.repeat(flags_obj.epochs_between_evals)
return ds return ds
def eval_input_fn(): def eval_input_fn():
return dataset.test(flags.data_dir).batch( return dataset.test(flags_obj.data_dir).batch(
flags.batch_size).make_one_shot_iterator().get_next() flags_obj.batch_size).make_one_shot_iterator().get_next()
# Set up hook that outputs training logs every 100 steps. # Set up hook that outputs training logs every 100 steps.
train_hooks = hooks_helper.get_train_hooks( train_hooks = hooks_helper.get_train_hooks(
flags.hooks, batch_size=flags.batch_size) flags_obj.hooks, batch_size=flags_obj.batch_size)
# Train and evaluate model. # Train and evaluate model.
for _ in range(flags.train_epochs // flags.epochs_between_evals): for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks) mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results) print('\nEvaluation results:\n\t%s\n' % eval_results)
if model_helpers.past_stop_threshold(flags.stop_threshold, if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
eval_results['accuracy']): eval_results['accuracy']):
break break
# Export the model # Export the model
if flags.export_dir is not None: if flags_obj.export_dir is not None:
image = tf.placeholder(tf.float32, [None, 28, 28]) image = tf.placeholder(tf.float32, [None, 28, 28])
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'image': image, 'image': image,
}) })
mnist_classifier.export_savedmodel(flags.export_dir, input_fn) mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)
class MNISTArgParser(argparse.ArgumentParser):
"""Argument parser for running MNIST model."""
def __init__(self):
super(MNISTArgParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.ImageModelParser(),
])
self.set_defaults( def main(_):
data_dir='/tmp/mnist_data', run_mnist(flags.FLAGS)
model_dir='/tmp/mnist_model',
batch_size=100,
train_epochs=40)
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
main(argv=sys.argv) define_mnist_flags()
absl_app.run(main)
...@@ -26,17 +26,19 @@ from __future__ import absolute_import ...@@ -26,17 +26,19 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import os import os
import sys
import time import time
import tensorflow as tf # pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
import tensorflow.contrib.eager as tfe # pylint: disable=g-bad-import-order from absl import app as absl_app
from absl import flags
import tensorflow as tf
import tensorflow.contrib.eager as tfe
# pylint: enable=g-bad-import-order
from official.mnist import dataset as mnist_dataset from official.mnist import dataset as mnist_dataset
from official.mnist import mnist from official.mnist import mnist
from official.utils.arg_parsers import parsers from official.utils.flags import core as flags_core
def loss(logits, labels): def loss(logits, labels):
...@@ -63,7 +65,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None): ...@@ -63,7 +65,7 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
# Record the operations used to compute the loss given the input, # Record the operations used to compute the loss given the input,
# so that the gradient of the loss with respect to the variables # so that the gradient of the loss with respect to the variables
# can be computed. # can be computed.
with tfe.GradientTape() as tape: with tf.GradientTape() as tape:
logits = model(images, training=True) logits = model(images, training=True)
loss_value = loss(logits, labels) loss_value = loss(logits, labels)
tf.contrib.summary.scalar('loss', loss_value) tf.contrib.summary.scalar('loss', loss_value)
...@@ -95,38 +97,41 @@ def test(model, dataset): ...@@ -95,38 +97,41 @@ def test(model, dataset):
tf.contrib.summary.scalar('accuracy', accuracy.result()) tf.contrib.summary.scalar('accuracy', accuracy.result())
def main(argv): def run_mnist_eager(flags_obj):
parser = MNISTEagerArgParser() """Run MNIST training and eval loop in eager mode.
flags = parser.parse_args(args=argv[1:])
tfe.enable_eager_execution() Args:
flags_obj: An object containing parsed flag values.
"""
tf.enable_eager_execution()
# Automatically determine device and data_format # Automatically determine device and data_format
(device, data_format) = ('/gpu:0', 'channels_first') (device, data_format) = ('/gpu:0', 'channels_first')
if flags.no_gpu or tfe.num_gpus() <= 0: if flags_obj.no_gpu or tf.test.is_gpu_available():
(device, data_format) = ('/cpu:0', 'channels_last') (device, data_format) = ('/cpu:0', 'channels_last')
# If data_format is defined in FLAGS, overwrite automatically set value. # If data_format is defined in FLAGS, overwrite automatically set value.
if flags.data_format is not None: if flags_obj.data_format is not None:
data_format = flags.data_format data_format = flags_obj.data_format
print('Using device %s, and data format %s.' % (device, data_format)) print('Using device %s, and data format %s.' % (device, data_format))
# Load the datasets # Load the datasets
train_ds = mnist_dataset.train(flags.data_dir).shuffle(60000).batch( train_ds = mnist_dataset.train(flags_obj.data_dir).shuffle(60000).batch(
flags.batch_size) flags_obj.batch_size)
test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size) test_ds = mnist_dataset.test(flags_obj.data_dir).batch(
flags_obj.batch_size)
# Create the model and optimizer # Create the model and optimizer
model = mnist.create_model(data_format) model = mnist.create_model(data_format)
optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum) optimizer = tf.train.MomentumOptimizer(flags_obj.lr, flags_obj.momentum)
# Create file writers for writing TensorBoard summaries. # Create file writers for writing TensorBoard summaries.
if flags.output_dir: if flags_obj.output_dir:
# Create directories to which summaries will be written # Create directories to which summaries will be written
# tensorboard --logdir=<output_dir> # tensorboard --logdir=<output_dir>
# can then be used to see the recorded summaries. # can then be used to see the recorded summaries.
train_dir = os.path.join(flags.output_dir, 'train') train_dir = os.path.join(flags_obj.output_dir, 'train')
test_dir = os.path.join(flags.output_dir, 'eval') test_dir = os.path.join(flags_obj.output_dir, 'eval')
tf.gfile.MakeDirs(flags.output_dir) tf.gfile.MakeDirs(flags_obj.output_dir)
else: else:
train_dir = None train_dir = None
test_dir = None test_dir = None
...@@ -136,19 +141,20 @@ def main(argv): ...@@ -136,19 +141,20 @@ def main(argv):
test_dir, flush_millis=10000, name='test') test_dir, flush_millis=10000, name='test')
# Create and restore checkpoint (if one exists on the path) # Create and restore checkpoint (if one exists on the path)
checkpoint_prefix = os.path.join(flags.model_dir, 'ckpt') checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
step_counter = tf.train.get_or_create_global_step() step_counter = tf.train.get_or_create_global_step()
checkpoint = tfe.Checkpoint( checkpoint = tfe.Checkpoint(
model=model, optimizer=optimizer, step_counter=step_counter) model=model, optimizer=optimizer, step_counter=step_counter)
# Restore variables on creation if a checkpoint exists. # Restore variables on creation if a checkpoint exists.
checkpoint.restore(tf.train.latest_checkpoint(flags.model_dir)) checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))
# Train and evaluate for a set number of epochs. # Train and evaluate for a set number of epochs.
with tf.device(device): with tf.device(device):
for _ in range(flags.train_epochs): for _ in range(flags_obj.train_epochs):
start = time.time() start = time.time()
with summary_writer.as_default(): with summary_writer.as_default():
train(model, optimizer, train_ds, step_counter, flags.log_interval) train(model, optimizer, train_ds, step_counter,
flags_obj.log_interval)
end = time.time() end = time.time()
print('\nTrain time for epoch #%d (%d total steps): %f' % print('\nTrain time for epoch #%d (%d total steps): %f' %
(checkpoint.save_counter.numpy() + 1, (checkpoint.save_counter.numpy() + 1,
...@@ -159,50 +165,42 @@ def main(argv): ...@@ -159,50 +165,42 @@ def main(argv):
checkpoint.save(checkpoint_prefix) checkpoint.save(checkpoint_prefix)
class MNISTEagerArgParser(argparse.ArgumentParser): def define_mnist_eager_flags():
"""Argument parser for running MNIST model with eager training loop.""" """Defined flags and defaults for MNIST in eager mode."""
flags_core.define_base_eager()
def __init__(self): flags_core.define_image()
super(MNISTEagerArgParser, self).__init__(parents=[ flags.adopt_module_key_flags(flags_core)
parsers.EagerParser(),
parsers.ImageModelParser()]) flags.DEFINE_integer(
name='log_interval', short_name='li', default=10,
self.add_argument( help=flags_core.help_wrap('batches between logging training status'))
'--log_interval', '-li',
type=int, flags.DEFINE_string(
default=10, name='output_dir', short_name='od', default=None,
metavar='N', help=flags_core.help_wrap('Directory to write TensorBoard summaries'))
help='[default: %(default)s] batches between logging training status')
self.add_argument( flags.DEFINE_float(name='learning_rate', short_name='lr', default=0.01,
'--output_dir', '-od', help=flags_core.help_wrap('Learning rate.'))
type=str,
default=None, flags.DEFINE_float(name='momentum', short_name='m', default=0.5,
metavar='<OD>', help=flags_core.help_wrap('SGD momentum.'))
help='[default: %(default)s] Directory to write TensorBoard summaries')
self.add_argument( flags.DEFINE_bool(name='no_gpu', short_name='nogpu', default=False,
'--lr', '-lr', help=flags_core.help_wrap(
type=float, 'disables GPU usage even if a GPU is available'))
default=0.01,
metavar='<LR>', flags_core.set_defaults(
help='[default: %(default)s] learning rate') data_dir='/tmp/tensorflow/mnist/input_data',
self.add_argument( model_dir='/tmp/tensorflow/mnist/checkpoints/',
'--momentum', '-m', batch_size=100,
type=float, train_epochs=10,
default=0.5, )
metavar='<M>',
help='[default: %(default)s] SGD momentum')
self.add_argument( def main(_):
'--no_gpu', '-nogpu', run_mnist_eager(flags.FLAGS)
action='store_true',
default=False,
help='disables GPU usage even if a GPU is available')
self.set_defaults(
data_dir='/tmp/tensorflow/mnist/input_data',
model_dir='/tmp/tensorflow/mnist/checkpoints/',
batch_size=100,
train_epochs=10,
)
if __name__ == '__main__': if __name__ == '__main__':
main(argv=sys.argv) define_mnist_eager_flags()
absl_app.run(main=main)
...@@ -23,10 +23,16 @@ from __future__ import absolute_import ...@@ -23,10 +23,16 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import sys
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.mnist import dataset # For open source environment, add grandparent directory for import
from official.mnist import mnist sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(sys.path[0]))))
from official.mnist import dataset # pylint: disable=wrong-import-position
from official.mnist import mnist # pylint: disable=wrong-import-position
# Cloud TPU Cluster Resolver flags # Cloud TPU Cluster Resolver flags
tf.flags.DEFINE_string( tf.flags.DEFINE_string(
......
...@@ -59,3 +59,13 @@ Other versions and formats: ...@@ -59,3 +59,13 @@ Other versions and formats:
* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz) * [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz)
* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz) * [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz)
* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz) * [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz)
## Compute Devices
Training is accomplished using the DistributionStrategies API. (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md)
The appropriate distribution strategy is chosen based on the `--num_gpus` flag. By default this flag is one if TensorFlow is compiled with CUDA, and zero otherwise.
num_gpus:
+ 0: Use OneDeviceStrategy and train on CPU.
+ 1: Use OneDeviceStrategy and train on GPU.
+ 2+: Use MirroredStrategy (data parallelism) to distribute a batch between devices.
...@@ -19,10 +19,12 @@ from __future__ import division ...@@ -19,10 +19,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import sys
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.resnet import resnet_model from official.resnet import resnet_model
from official.resnet import resnet_run_loop from official.resnet import resnet_run_loop
...@@ -40,6 +42,8 @@ _NUM_IMAGES = { ...@@ -40,6 +42,8 @@ _NUM_IMAGES = {
'validation': 10000, 'validation': 10000,
} }
DATASET_NAME = 'CIFAR-10'
############################################################################### ###############################################################################
# Data processing # Data processing
...@@ -103,8 +107,7 @@ def preprocess_image(image, is_training): ...@@ -103,8 +107,7 @@ def preprocess_image(image, is_training):
return image return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1, def input_fn(is_training, data_dir, batch_size, num_epochs=1):
num_parallel_calls=1, multi_gpu=False):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset. """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args: Args:
...@@ -112,12 +115,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -112,12 +115,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -125,12 +122,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -125,12 +122,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
filenames = get_filenames(is_training, data_dir) filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES) dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
return resnet_run_loop.process_record_dataset( return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _NUM_IMAGES['train'], dataset, is_training, batch_size, _NUM_IMAGES['train'],
parse_record, num_epochs, num_parallel_calls, parse_record, num_epochs,
examples_per_epoch=num_images, multi_gpu=multi_gpu) )
def get_synth_input_fn(): def get_synth_input_fn():
...@@ -145,7 +140,7 @@ class Cifar10Model(resnet_model.Model): ...@@ -145,7 +140,7 @@ class Cifar10Model(resnet_model.Model):
"""Model class with appropriate defaults for CIFAR-10 data.""" """Model class with appropriate defaults for CIFAR-10 data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION, resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for CIFAR-10 data. """These are the parameters that work for CIFAR-10 data.
...@@ -155,8 +150,8 @@ class Cifar10Model(resnet_model.Model): ...@@ -155,8 +150,8 @@ class Cifar10Model(resnet_model.Model):
data format to use when setting up the model. data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets. enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use. resnet_version: Integer representing which version of the ResNet network
See README for details. Valid values: [1, 2] to use. See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations. dtype: The TensorFlow dtype to use for calculations.
Raises: Raises:
...@@ -176,12 +171,10 @@ class Cifar10Model(resnet_model.Model): ...@@ -176,12 +171,10 @@ class Cifar10Model(resnet_model.Model):
conv_stride=1, conv_stride=1,
first_pool_size=None, first_pool_size=None,
first_pool_stride=None, first_pool_stride=None,
second_pool_size=8,
second_pool_stride=1,
block_sizes=[num_blocks] * 3, block_sizes=[num_blocks] * 3,
block_strides=[1, 2, 2], block_strides=[1, 2, 2],
final_size=64, final_size=64,
version=version, resnet_version=resnet_version,
data_format=data_format, data_format=data_format,
dtype=dtype dtype=dtype
) )
...@@ -218,33 +211,43 @@ def cifar10_model_fn(features, labels, mode, params): ...@@ -218,33 +211,43 @@ def cifar10_model_fn(features, labels, mode, params):
learning_rate_fn=learning_rate_fn, learning_rate_fn=learning_rate_fn,
momentum=0.9, momentum=0.9,
data_format=params['data_format'], data_format=params['data_format'],
version=params['version'], resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'], loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn, loss_filter_fn=loss_filter_fn,
multi_gpu=params['multi_gpu'],
dtype=params['dtype'] dtype=params['dtype']
) )
def main(argv): def define_cifar_flags():
parser = resnet_run_loop.ResnetArgParser() resnet_run_loop.define_resnet_flags()
# Set defaults that are reasonable for this model. flags.adopt_module_key_flags(resnet_run_loop)
parser.set_defaults(data_dir='/tmp/cifar10_data', flags_core.set_defaults(data_dir='/tmp/cifar10_data',
model_dir='/tmp/cifar10_model', model_dir='/tmp/cifar10_model',
resnet_size=32, resnet_size='32',
train_epochs=250, train_epochs=250,
epochs_between_evals=10, epochs_between_evals=10,
batch_size=128) batch_size=128)
flags = parser.parse_args(args=argv[1:]) def run_cifar(flags_obj):
"""Run ResNet CIFAR-10 training and eval loop.
input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn Args:
flags_obj: An object containing parsed flag values.
"""
input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
or input_fn)
resnet_run_loop.resnet_main( resnet_run_loop.resnet_main(
flags, cifar10_model_fn, input_function, flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS]) shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
def main(_):
run_cifar(flags.FLAGS)
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
main(argv=sys.argv) define_cifar_flags()
absl_app.run(main)
...@@ -37,6 +37,11 @@ class BaseTest(tf.test.TestCase): ...@@ -37,6 +37,11 @@ class BaseTest(tf.test.TestCase):
"""Tests for the Cifar10 version of Resnet. """Tests for the Cifar10 version of Resnet.
""" """
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass()
cifar10_main.define_cifar_flags()
def tearDown(self): def tearDown(self):
super(BaseTest, self).tearDown() super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir()) tf.gfile.DeleteRecursively(self.get_temp_dir())
...@@ -71,115 +76,92 @@ class BaseTest(tf.test.TestCase): ...@@ -71,115 +76,92 @@ class BaseTest(tf.test.TestCase):
for pixel in row: for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3) self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def _cifar10_model_fn_helper(self, mode, version, dtype, multi_gpu=False): def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
with tf.Graph().as_default() as g: input_fn = cifar10_main.get_synth_input_fn()
input_fn = cifar10_main.get_synth_input_fn() dataset = input_fn(True, '', _BATCH_SIZE)
dataset = input_fn(True, '', _BATCH_SIZE) iterator = dataset.make_one_shot_iterator()
iterator = dataset.make_one_shot_iterator() features, labels = iterator.get_next()
features, labels = iterator.get_next() spec = cifar10_main.cifar10_model_fn(
spec = cifar10_main.cifar10_model_fn( features, labels, mode, {
features, labels, mode, { 'dtype': dtype,
'dtype': dtype, 'resnet_size': 32,
'resnet_size': 32, 'data_format': 'channels_last',
'data_format': 'channels_last', 'batch_size': _BATCH_SIZE,
'batch_size': _BATCH_SIZE, 'resnet_version': resnet_version,
'version': version, 'loss_scale': 128 if dtype == tf.float16 else 1,
'loss_scale': 128 if dtype == tf.float16 else 1, })
'multi_gpu': multi_gpu
}) predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
predictions = spec.predictions (_BATCH_SIZE, 10))
self.assertAllEqual(predictions['probabilities'].shape, self.assertEqual(predictions['probabilities'].dtype, tf.float32)
(_BATCH_SIZE, 10)) self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['probabilities'].dtype, tf.float32) self.assertEqual(predictions['classes'].dtype, tf.int64)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64) if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
if mode != tf.estimator.ModeKeys.PREDICT: self.assertAllEqual(loss.shape, ())
loss = spec.loss self.assertEqual(loss.dtype, tf.float32)
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32) if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
if mode == tf.estimator.ModeKeys.EVAL: self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
eval_metric_ops = spec.eval_metric_ops self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
for v in tf.trainable_variables():
self.assertEqual(v.dtype.base_dtype, tf.float32)
tensors_to_check = ('initial_conv:0', 'block_layer1:0', 'block_layer2:0',
'block_layer3:0', 'final_reduce_mean:0',
'final_dense:0')
for tensor_name in tensors_to_check:
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
self.assertEqual(tensor.dtype, dtype,
'Tensor {} has dtype {}, while dtype {} was '
'expected'.format(tensor, tensor.dtype,
dtype))
def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
multi_gpu=multi_gpu)
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
multi_gpu=multi_gpu)
def test_cifar10_model_fn_train_mode_v1(self): def test_cifar10_model_fn_train_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=1,
dtype=tf.float32)
def test_cifar10_model_fn_trainmode__v2(self): def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=2,
dtype=tf.float32)
def test_cifar10_model_fn_train_mode_multi_gpu_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True)
def test_cifar10_model_fn_train_mode_multi_gpu_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True)
def test_cifar10_model_fn_eval_mode_v1(self): def test_cifar10_model_fn_eval_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=1,
dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v2(self): def test_cifar10_model_fn_eval_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=2,
dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v1(self): def test_cifar10_model_fn_predict_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT,
resnet_version=1, dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v2(self): def test_cifar10_model_fn_predict_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT,
resnet_version=2, dtype=tf.float32)
def _test_cifar10model_shape(self, version): def _test_cifar10model_shape(self, resnet_version):
batch_size = 135 batch_size = 135
num_classes = 246 num_classes = 246
model = cifar10_main.Cifar10Model(32, data_format='channels_last', model = cifar10_main.Cifar10Model(32, data_format='channels_last',
num_classes=num_classes, version=version) num_classes=num_classes,
resnet_version=resnet_version)
fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS]) fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
output = model(fake_input, training=True) output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes)) self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_cifar10model_shape_v1(self): def test_cifar10model_shape_v1(self):
self._test_cifar10model_shape(version=1) self._test_cifar10model_shape(resnet_version=1)
def test_cifar10model_shape_v2(self): def test_cifar10model_shape_v2(self):
self._test_cifar10model_shape(version=2) self._test_cifar10model_shape(resnet_version=2)
def test_cifar10_end_to_end_synthetic_v1(self): def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
main=cifar10_main.main, tmp_root=self.get_temp_dir(), main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1'] extra_flags=['-resnet_version', '1']
) )
def test_cifar10_end_to_end_synthetic_v2(self): def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic( integration.run_synthetic(
main=cifar10_main.main, tmp_root=self.get_temp_dir(), main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2'] extra_flags=['-resnet_version', '2']
) )
......
...@@ -19,10 +19,12 @@ from __future__ import division ...@@ -19,10 +19,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import sys
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.resnet import imagenet_preprocessing from official.resnet import imagenet_preprocessing
from official.resnet import resnet_model from official.resnet import resnet_model
from official.resnet import resnet_run_loop from official.resnet import resnet_run_loop
...@@ -39,6 +41,7 @@ _NUM_IMAGES = { ...@@ -39,6 +41,7 @@ _NUM_IMAGES = {
_NUM_TRAIN_FILES = 1024 _NUM_TRAIN_FILES = 1024
_SHUFFLE_BUFFER = 1500 _SHUFFLE_BUFFER = 1500
DATASET_NAME = 'ImageNet'
############################################################################### ###############################################################################
# Data processing # Data processing
...@@ -154,8 +157,7 @@ def parse_record(raw_record, is_training): ...@@ -154,8 +157,7 @@ def parse_record(raw_record, is_training):
return image, label return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1, def input_fn(is_training, data_dir, batch_size, num_epochs=1):
num_parallel_calls=1, multi_gpu=False):
"""Input function which provides batches for train or eval. """Input function which provides batches for train or eval.
Args: Args:
...@@ -163,12 +165,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -163,12 +165,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -180,15 +176,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -180,15 +176,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Shuffle the input files # Shuffle the input files
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES) dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
# Convert to individual records # Convert to individual records
dataset = dataset.flat_map(tf.data.TFRecordDataset) dataset = dataset.flat_map(tf.data.TFRecordDataset)
return resnet_run_loop.process_record_dataset( return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record, dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
num_epochs, num_parallel_calls, examples_per_epoch=num_images, num_epochs
multi_gpu=multi_gpu) )
def get_synth_input_fn(): def get_synth_input_fn():
...@@ -203,7 +197,7 @@ class ImagenetModel(resnet_model.Model): ...@@ -203,7 +197,7 @@ class ImagenetModel(resnet_model.Model):
"""Model class with appropriate defaults for Imagenet data.""" """Model class with appropriate defaults for Imagenet data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION, resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for Imagenet data. """These are the parameters that work for Imagenet data.
...@@ -213,8 +207,8 @@ class ImagenetModel(resnet_model.Model): ...@@ -213,8 +207,8 @@ class ImagenetModel(resnet_model.Model):
data format to use when setting up the model. data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets. enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use. resnet_version: Integer representing which version of the ResNet network
See README for details. Valid values: [1, 2] to use. See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations. dtype: The TensorFlow dtype to use for calculations.
""" """
...@@ -235,12 +229,10 @@ class ImagenetModel(resnet_model.Model): ...@@ -235,12 +229,10 @@ class ImagenetModel(resnet_model.Model):
conv_stride=2, conv_stride=2,
first_pool_size=3, first_pool_size=3,
first_pool_stride=2, first_pool_stride=2,
second_pool_size=7,
second_pool_stride=1,
block_sizes=_get_block_sizes(resnet_size), block_sizes=_get_block_sizes(resnet_size),
block_strides=[1, 2, 2, 2], block_strides=[1, 2, 2, 2],
final_size=final_size, final_size=final_size,
version=version, resnet_version=resnet_version,
data_format=data_format, data_format=data_format,
dtype=dtype dtype=dtype
) )
...@@ -297,31 +289,39 @@ def imagenet_model_fn(features, labels, mode, params): ...@@ -297,31 +289,39 @@ def imagenet_model_fn(features, labels, mode, params):
learning_rate_fn=learning_rate_fn, learning_rate_fn=learning_rate_fn,
momentum=0.9, momentum=0.9,
data_format=params['data_format'], data_format=params['data_format'],
version=params['version'], resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'], loss_scale=params['loss_scale'],
loss_filter_fn=None, loss_filter_fn=None,
multi_gpu=params['multi_gpu'],
dtype=params['dtype'] dtype=params['dtype']
) )
def main(argv): def define_imagenet_flags():
parser = resnet_run_loop.ResnetArgParser( resnet_run_loop.define_resnet_flags(
resnet_size_choices=[18, 34, 50, 101, 152, 200]) resnet_size_choices=['18', '34', '50', '101', '152', '200'])
flags.adopt_module_key_flags(resnet_run_loop)
flags_core.set_defaults(train_epochs=100)
parser.set_defaults(
train_epochs=100
)
flags = parser.parse_args(args=argv[1:]) def run_imagenet(flags_obj):
"""Run ResNet ImageNet training and eval loop.
input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn Args:
flags_obj: An object containing parsed flag values.
"""
input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
or input_fn)
resnet_run_loop.resnet_main( resnet_run_loop.resnet_main(
flags, imagenet_model_fn, input_function, flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS]) shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
def main(_):
run_imagenet(flags.FLAGS)
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
main(argv=sys.argv) define_imagenet_flags()
absl_app.run(main)
...@@ -32,11 +32,16 @@ _LABEL_CLASSES = 1001 ...@@ -32,11 +32,16 @@ _LABEL_CLASSES = 1001
class BaseTest(tf.test.TestCase): class BaseTest(tf.test.TestCase):
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass()
imagenet_main.define_imagenet_flags()
def tearDown(self): def tearDown(self):
super(BaseTest, self).tearDown() super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir()) tf.gfile.DeleteRecursively(self.get_temp_dir())
def _tensor_shapes_helper(self, resnet_size, version, dtype, with_gpu): def _tensor_shapes_helper(self, resnet_size, resnet_version, dtype, with_gpu):
"""Checks the tensor shapes after each phase of the ResNet model.""" """Checks the tensor shapes after each phase of the ResNet model."""
def reshape(shape): def reshape(shape):
"""Returns the expected dimensions depending on if a GPU is being used.""" """Returns the expected dimensions depending on if a GPU is being used."""
...@@ -54,7 +59,7 @@ class BaseTest(tf.test.TestCase): ...@@ -54,7 +59,7 @@ class BaseTest(tf.test.TestCase):
model = imagenet_main.ImagenetModel( model = imagenet_main.ImagenetModel(
resnet_size=resnet_size, resnet_size=resnet_size,
data_format='channels_first' if with_gpu else 'channels_last', data_format='channels_first' if with_gpu else 'channels_last',
version=version, resnet_version=resnet_version,
dtype=dtype dtype=dtype
) )
inputs = tf.random_uniform([1, 224, 224, 3]) inputs = tf.random_uniform([1, 224, 224, 3])
...@@ -90,186 +95,166 @@ class BaseTest(tf.test.TestCase): ...@@ -90,186 +95,166 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES)) self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES))
self.assertAllEqual(output.shape, (1, _LABEL_CLASSES)) self.assertAllEqual(output.shape, (1, _LABEL_CLASSES))
def tensor_shapes_helper(self, resnet_size, version, with_gpu=False): def tensor_shapes_helper(self, resnet_size, resnet_version, with_gpu=False):
self._tensor_shapes_helper(resnet_size=resnet_size, version=version, self._tensor_shapes_helper(resnet_size=resnet_size,
resnet_version=resnet_version,
dtype=tf.float32, with_gpu=with_gpu) dtype=tf.float32, with_gpu=with_gpu)
self._tensor_shapes_helper(resnet_size=resnet_size, version=version, self._tensor_shapes_helper(resnet_size=resnet_size,
resnet_version=resnet_version,
dtype=tf.float16, with_gpu=with_gpu) dtype=tf.float16, with_gpu=with_gpu)
def test_tensor_shapes_resnet_18_v1(self): def test_tensor_shapes_resnet_18_v1(self):
self.tensor_shapes_helper(18, version=1) self.tensor_shapes_helper(18, resnet_version=1)
def test_tensor_shapes_resnet_18_v2(self): def test_tensor_shapes_resnet_18_v2(self):
self.tensor_shapes_helper(18, version=2) self.tensor_shapes_helper(18, resnet_version=2)
def test_tensor_shapes_resnet_34_v1(self): def test_tensor_shapes_resnet_34_v1(self):
self.tensor_shapes_helper(34, version=1) self.tensor_shapes_helper(34, resnet_version=1)
def test_tensor_shapes_resnet_34_v2(self): def test_tensor_shapes_resnet_34_v2(self):
self.tensor_shapes_helper(34, version=2) self.tensor_shapes_helper(34, resnet_version=2)
def test_tensor_shapes_resnet_50_v1(self): def test_tensor_shapes_resnet_50_v1(self):
self.tensor_shapes_helper(50, version=1) self.tensor_shapes_helper(50, resnet_version=1)
def test_tensor_shapes_resnet_50_v2(self): def test_tensor_shapes_resnet_50_v2(self):
self.tensor_shapes_helper(50, version=2) self.tensor_shapes_helper(50, resnet_version=2)
def test_tensor_shapes_resnet_101_v1(self): def test_tensor_shapes_resnet_101_v1(self):
self.tensor_shapes_helper(101, version=1) self.tensor_shapes_helper(101, resnet_version=1)
def test_tensor_shapes_resnet_101_v2(self): def test_tensor_shapes_resnet_101_v2(self):
self.tensor_shapes_helper(101, version=2) self.tensor_shapes_helper(101, resnet_version=2)
def test_tensor_shapes_resnet_152_v1(self): def test_tensor_shapes_resnet_152_v1(self):
self.tensor_shapes_helper(152, version=1) self.tensor_shapes_helper(152, resnet_version=1)
def test_tensor_shapes_resnet_152_v2(self): def test_tensor_shapes_resnet_152_v2(self):
self.tensor_shapes_helper(152, version=2) self.tensor_shapes_helper(152, resnet_version=2)
def test_tensor_shapes_resnet_200_v1(self): def test_tensor_shapes_resnet_200_v1(self):
self.tensor_shapes_helper(200, version=1) self.tensor_shapes_helper(200, resnet_version=1)
def test_tensor_shapes_resnet_200_v2(self): def test_tensor_shapes_resnet_200_v2(self):
self.tensor_shapes_helper(200, version=2) self.tensor_shapes_helper(200, resnet_version=2)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu_v1(self): def test_tensor_shapes_resnet_18_with_gpu_v1(self):
self.tensor_shapes_helper(18, version=1, with_gpu=True) self.tensor_shapes_helper(18, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_18_with_gpu_v2(self): def test_tensor_shapes_resnet_18_with_gpu_v2(self):
self.tensor_shapes_helper(18, version=2, with_gpu=True) self.tensor_shapes_helper(18, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu_v1(self): def test_tensor_shapes_resnet_34_with_gpu_v1(self):
self.tensor_shapes_helper(34, version=1, with_gpu=True) self.tensor_shapes_helper(34, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_34_with_gpu_v2(self): def test_tensor_shapes_resnet_34_with_gpu_v2(self):
self.tensor_shapes_helper(34, version=2, with_gpu=True) self.tensor_shapes_helper(34, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_50_with_gpu_v1(self): def test_tensor_shapes_resnet_50_with_gpu_v1(self):
self.tensor_shapes_helper(50, version=1, with_gpu=True) self.tensor_shapes_helper(50, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_50_with_gpu_v2(self): def test_tensor_shapes_resnet_50_with_gpu_v2(self):
self.tensor_shapes_helper(50, version=2, with_gpu=True) self.tensor_shapes_helper(50, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_101_with_gpu_v1(self): def test_tensor_shapes_resnet_101_with_gpu_v1(self):
self.tensor_shapes_helper(101, version=1, with_gpu=True) self.tensor_shapes_helper(101, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_101_with_gpu_v2(self): def test_tensor_shapes_resnet_101_with_gpu_v2(self):
self.tensor_shapes_helper(101, version=2, with_gpu=True) self.tensor_shapes_helper(101, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_152_with_gpu_v1(self): def test_tensor_shapes_resnet_152_with_gpu_v1(self):
self.tensor_shapes_helper(152, version=1, with_gpu=True) self.tensor_shapes_helper(152, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_152_with_gpu_v2(self): def test_tensor_shapes_resnet_152_with_gpu_v2(self):
self.tensor_shapes_helper(152, version=2, with_gpu=True) self.tensor_shapes_helper(152, resnet_version=2, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu_v1(self): def test_tensor_shapes_resnet_200_with_gpu_v1(self):
self.tensor_shapes_helper(200, version=1, with_gpu=True) self.tensor_shapes_helper(200, resnet_version=1, with_gpu=True)
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_tensor_shapes_resnet_200_with_gpu_v2(self): def test_tensor_shapes_resnet_200_with_gpu_v2(self):
self.tensor_shapes_helper(200, version=2, with_gpu=True) self.tensor_shapes_helper(200, resnet_version=2, with_gpu=True)
def _resnet_model_fn_helper(self, mode, version, dtype, multi_gpu): def resnet_model_fn_helper(self, mode, resnet_version, dtype):
"""Tests that the EstimatorSpec is given the appropriate arguments.""" """Tests that the EstimatorSpec is given the appropriate arguments."""
with tf.Graph().as_default() as g: tf.train.create_global_step()
tf.train.create_global_step()
input_fn = imagenet_main.get_synth_input_fn()
input_fn = imagenet_main.get_synth_input_fn() dataset = input_fn(True, '', _BATCH_SIZE)
dataset = input_fn(True, '', _BATCH_SIZE) iterator = dataset.make_one_shot_iterator()
iterator = dataset.make_one_shot_iterator() features, labels = iterator.get_next()
features, labels = iterator.get_next() spec = imagenet_main.imagenet_model_fn(
spec = imagenet_main.imagenet_model_fn( features, labels, mode, {
features, labels, mode, { 'dtype': dtype,
'dtype': dtype, 'resnet_size': 50,
'resnet_size': 50, 'data_format': 'channels_last',
'data_format': 'channels_last', 'batch_size': _BATCH_SIZE,
'batch_size': _BATCH_SIZE, 'resnet_version': resnet_version,
'version': version, 'loss_scale': 128 if dtype == tf.float16 else 1,
'loss_scale': 128 if dtype == tf.float16 else 1, })
'multi_gpu': multi_gpu,
}) predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
predictions = spec.predictions (_BATCH_SIZE, _LABEL_CLASSES))
self.assertAllEqual(predictions['probabilities'].shape, self.assertEqual(predictions['probabilities'].dtype, tf.float32)
(_BATCH_SIZE, _LABEL_CLASSES)) self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['probabilities'].dtype, tf.float32) self.assertEqual(predictions['classes'].dtype, tf.int64)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64) if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
if mode != tf.estimator.ModeKeys.PREDICT: self.assertAllEqual(loss.shape, ())
loss = spec.loss self.assertEqual(loss.dtype, tf.float32)
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32) if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
if mode == tf.estimator.ModeKeys.EVAL: self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
eval_metric_ops = spec.eval_metric_ops self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
tensors_to_check = ('initial_conv:0', 'initial_max_pool:0',
'block_layer1:0', 'block_layer2:0',
'block_layer3:0', 'block_layer4:0',
'final_reduce_mean:0', 'final_dense:0')
for tensor_name in tensors_to_check:
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
self.assertEqual(tensor.dtype, dtype,
'Tensor {} has dtype {}, while dtype {} was '
'expected'.format(tensor, tensor.dtype,
dtype))
def resnet_model_fn_helper(self, mode, version, multi_gpu=False):
self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
multi_gpu=multi_gpu)
self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
multi_gpu=multi_gpu)
def test_resnet_model_fn_train_mode_v1(self): def test_resnet_model_fn_train_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1) self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=1,
dtype=tf.float32)
def test_resnet_model_fn_train_mode_v2(self): def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2) self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=2,
dtype=tf.float32)
def test_resnet_model_fn_train_mode_multi_gpu_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True)
def test_resnet_model_fn_train_mode_multi_gpu_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True)
def test_resnet_model_fn_eval_mode_v1(self): def test_resnet_model_fn_eval_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1) self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=1,
dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v2(self): def test_resnet_model_fn_eval_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2) self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=2,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v1(self): def test_resnet_model_fn_predict_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1) self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, resnet_version=1,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v2(self): def test_resnet_model_fn_predict_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2) self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, resnet_version=2,
dtype=tf.float32)
def _test_imagenetmodel_shape(self, version): def _test_imagenetmodel_shape(self, resnet_version):
batch_size = 135 batch_size = 135
num_classes = 246 num_classes = 246
model = imagenet_main.ImagenetModel( model = imagenet_main.ImagenetModel(
50, data_format='channels_last', num_classes=num_classes, 50, data_format='channels_last', num_classes=num_classes,
version=version) resnet_version=resnet_version)
fake_input = tf.random_uniform([batch_size, 224, 224, 3]) fake_input = tf.random_uniform([batch_size, 224, 224, 3])
output = model(fake_input, training=True) output = model(fake_input, training=True)
...@@ -277,45 +262,45 @@ class BaseTest(tf.test.TestCase): ...@@ -277,45 +262,45 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes)) self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_imagenetmodel_shape_v1(self): def test_imagenetmodel_shape_v1(self):
self._test_imagenetmodel_shape(version=1) self._test_imagenetmodel_shape(resnet_version=1)
def test_imagenetmodel_shape_v2(self): def test_imagenetmodel_shape_v2(self):
self._test_imagenetmodel_shape(version=2) self._test_imagenetmodel_shape(resnet_version=2)
def test_imagenet_end_to_end_synthetic_v1(self): def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1'] extra_flags=['-v', '1']
) )
def test_imagenet_end_to_end_synthetic_v2(self): def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2'] extra_flags=['-v', '2']
) )
def test_imagenet_end_to_end_synthetic_v1_tiny(self): def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '18'] extra_flags=['-resnet_version', '1', '-resnet_size', '18']
) )
def test_imagenet_end_to_end_synthetic_v2_tiny(self): def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '18'] extra_flags=['-resnet_version', '2', '-resnet_size', '18']
) )
def test_imagenet_end_to_end_synthetic_v1_huge(self): def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '200'] extra_flags=['-resnet_version', '1', '-resnet_size', '200']
) )
def test_imagenet_end_to_end_synthetic_v2_huge(self): def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '200'] extra_flags=['-resnet_version', '2', '-resnet_size', '200']
) )
......
...@@ -41,14 +41,22 @@ from official.utils.testing import reference_data ...@@ -41,14 +41,22 @@ from official.utils.testing import reference_data
DATA_FORMAT = "channels_last" # CPU instructions often preclude channels_first DATA_FORMAT = "channels_last" # CPU instructions often preclude channels_first
BATCH_SIZE = 32 BATCH_SIZE = 32
BLOCK_TESTS = [ BLOCK_TESTS = [
dict(bottleneck=True, projection=True, version=1, width=8, channels=4), dict(bottleneck=True, projection=True, resnet_version=1, width=8,
dict(bottleneck=True, projection=True, version=2, width=8, channels=4), channels=4),
dict(bottleneck=True, projection=False, version=1, width=8, channels=4), dict(bottleneck=True, projection=True, resnet_version=2, width=8,
dict(bottleneck=True, projection=False, version=2, width=8, channels=4), channels=4),
dict(bottleneck=False, projection=True, version=1, width=8, channels=4), dict(bottleneck=True, projection=False, resnet_version=1, width=8,
dict(bottleneck=False, projection=True, version=2, width=8, channels=4), channels=4),
dict(bottleneck=False, projection=False, version=1, width=8, channels=4), dict(bottleneck=True, projection=False, resnet_version=2, width=8,
dict(bottleneck=False, projection=False, version=2, width=8, channels=4), channels=4),
dict(bottleneck=False, projection=True, resnet_version=1, width=8,
channels=4),
dict(bottleneck=False, projection=True, resnet_version=2, width=8,
channels=4),
dict(bottleneck=False, projection=False, resnet_version=1, width=8,
channels=4),
dict(bottleneck=False, projection=False, resnet_version=2, width=8,
channels=4),
] ]
...@@ -95,7 +103,7 @@ class BaseTest(reference_data.BaseTest): ...@@ -95,7 +103,7 @@ class BaseTest(reference_data.BaseTest):
return projection_shortcut return projection_shortcut
def _resnet_block_ops(self, test, batch_size, bottleneck, projection, def _resnet_block_ops(self, test, batch_size, bottleneck, projection,
version, width, channels): resnet_version, width, channels):
"""Test whether resnet block construction has changed. """Test whether resnet block construction has changed.
Args: Args:
...@@ -104,7 +112,7 @@ class BaseTest(reference_data.BaseTest): ...@@ -104,7 +112,7 @@ class BaseTest(reference_data.BaseTest):
batch normalization. batch normalization.
bottleneck: Whether or not to use bottleneck layers. bottleneck: Whether or not to use bottleneck layers.
projection: Whether or not to project the input. projection: Whether or not to project the input.
version: Which version of ResNet to test. resnet_version: Which version of ResNet to test.
width: The width of the fake image. width: The width of the fake image.
channels: The number of channels in the fake image. channels: The number of channels in the fake image.
""" """
...@@ -113,12 +121,12 @@ class BaseTest(reference_data.BaseTest): ...@@ -113,12 +121,12 @@ class BaseTest(reference_data.BaseTest):
batch_size, batch_size,
"bottleneck" if bottleneck else "building", "bottleneck" if bottleneck else "building",
"_projection" if projection else "", "_projection" if projection else "",
version, resnet_version,
width, width,
channels channels
) )
if version == 1: if resnet_version == 1:
block_fn = resnet_model._building_block_v1 block_fn = resnet_model._building_block_v1
if bottleneck: if bottleneck:
block_fn = resnet_model._bottleneck_block_v1 block_fn = resnet_model._bottleneck_block_v1
......
...@@ -353,8 +353,8 @@ class Model(object): ...@@ -353,8 +353,8 @@ class Model(object):
def __init__(self, resnet_size, bottleneck, num_classes, num_filters, def __init__(self, resnet_size, bottleneck, num_classes, num_filters,
kernel_size, kernel_size,
conv_stride, first_pool_size, first_pool_stride, conv_stride, first_pool_size, first_pool_stride,
second_pool_size, second_pool_stride, block_sizes, block_strides, block_sizes, block_strides,
final_size, version=DEFAULT_VERSION, data_format=None, final_size, resnet_version=DEFAULT_VERSION, data_format=None,
dtype=DEFAULT_DTYPE): dtype=DEFAULT_DTYPE):
"""Creates a model for classifying an image. """Creates a model for classifying an image.
...@@ -371,16 +371,14 @@ class Model(object): ...@@ -371,16 +371,14 @@ class Model(object):
If none, the first pooling layer is skipped. If none, the first pooling layer is skipped.
first_pool_stride: stride size for the first pooling layer. Not used first_pool_stride: stride size for the first pooling layer. Not used
if first_pool_size is None. if first_pool_size is None.
second_pool_size: Pool size to be used for the second pooling layer.
second_pool_stride: stride size for the final pooling layer
block_sizes: A list containing n values, where n is the number of sets of block_sizes: A list containing n values, where n is the number of sets of
block layers desired. Each value should be the number of blocks in the block layers desired. Each value should be the number of blocks in the
i-th set. i-th set.
block_strides: List of integers representing the desired stride size for block_strides: List of integers representing the desired stride size for
each of the sets of block layers. Should be same length as block_sizes. each of the sets of block layers. Should be same length as block_sizes.
final_size: The expected size of the model after the second pooling. final_size: The expected size of the model after the second pooling.
version: Integer representing which version of the ResNet network to use. resnet_version: Integer representing which version of the ResNet network
See README for details. Valid values: [1, 2] to use. See README for details. Valid values: [1, 2]
data_format: Input format ('channels_last', 'channels_first', or None). data_format: Input format ('channels_last', 'channels_first', or None).
If set to None, the format is dependent on whether a GPU is available. If set to None, the format is dependent on whether a GPU is available.
dtype: The TensorFlow dtype to use for calculations. If not specified dtype: The TensorFlow dtype to use for calculations. If not specified
...@@ -395,19 +393,19 @@ class Model(object): ...@@ -395,19 +393,19 @@ class Model(object):
data_format = ( data_format = (
'channels_first' if tf.test.is_built_with_cuda() else 'channels_last') 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last')
self.resnet_version = version self.resnet_version = resnet_version
if version not in (1, 2): if resnet_version not in (1, 2):
raise ValueError( raise ValueError(
'Resnet version should be 1 or 2. See README for citations.') 'Resnet version should be 1 or 2. See README for citations.')
self.bottleneck = bottleneck self.bottleneck = bottleneck
if bottleneck: if bottleneck:
if version == 1: if resnet_version == 1:
self.block_fn = _bottleneck_block_v1 self.block_fn = _bottleneck_block_v1
else: else:
self.block_fn = _bottleneck_block_v2 self.block_fn = _bottleneck_block_v2
else: else:
if version == 1: if resnet_version == 1:
self.block_fn = _building_block_v1 self.block_fn = _building_block_v1
else: else:
self.block_fn = _building_block_v2 self.block_fn = _building_block_v2
...@@ -422,8 +420,6 @@ class Model(object): ...@@ -422,8 +420,6 @@ class Model(object):
self.conv_stride = conv_stride self.conv_stride = conv_stride
self.first_pool_size = first_pool_size self.first_pool_size = first_pool_size
self.first_pool_stride = first_pool_stride self.first_pool_stride = first_pool_stride
self.second_pool_size = second_pool_size
self.second_pool_stride = second_pool_stride
self.block_sizes = block_sizes self.block_sizes = block_sizes
self.block_strides = block_strides self.block_strides = block_strides
self.final_size = final_size self.final_size = final_size
......
This diff is collapsed.
# Transformer Translation Model
This is an implementation of the Transformer translation model as described in the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper. Based on the code provided by the authors: [Transformer code](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py) from [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor).
Transformer is a neural network architecture that solves sequence to sequence problems using attention mechanisms. Unlike traditional neural seq2seq models, Transformer does not involve recurrent connections. The attention mechanism learns dependencies between tokens in two sequences. Since attention weights apply to all tokens in the sequences, the Transformer model is able to easily capture long-distance dependencies.
Transformer's overall structure follows the standard encoder-decoder pattern. The encoder uses self-attention to compute a representation of the input sequence. The decoder generates the output sequence one token at a time, taking the encoder output and previous decoder-outputted tokens as inputs.
The model also applies embeddings on the input and output tokens, and adds a constant positional encoding. The positional encoding adds information about the position of each token.
## Contents
* [Contents](#contents)
* [Walkthrough](#walkthrough)
* [Benchmarks](#benchmarks)
* [Training times](#training-times)
* [Evaluation results](#evaluation-results)
* [Detailed instructions](#detailed-instructions)
* [Export variables (optional)](#export-variables-optional)
* [Download and preprocess datasets](#download-and-preprocess-datasets)
* [Model training and evaluation](#model-training-and-evaluation)
* [Translate using the model](#translate-using-the-model)
* [Compute official BLEU score](#compute-official-bleu-score)
* [Implementation overview](#implementation-overview)
* [Model Definition](#model-definition)
* [Model Estimator](#model-estimator)
* [Other scripts](#other-scripts)
* [Test dataset](#test-dataset)
* [Term definitions](#term-definitions)
## Walkthrough
Below are the commands for running the Transformer model. See the [Detailed instructions](#detailed-instructions) for more details on running the model.
```
PARAMS=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAMS
# Download training/evaluation datasets
python data_download.py --data_dir=$DATA_DIR
# Train the model for 10 epochs, and evaluate after every epoch.
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de
# Run during training in a separate process to get continuous updates,
# or after training is complete.
tensorboard --logdir=$MODEL_DIR
# Translate some text using the trained model
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --text="hello world"
# Compute model's BLEU score using the newstest2014 dataset.
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en
python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
```
## Benchmarks
### Training times
Currently, both big and base params run on a single GPU. The measurements below
are reported from running the model on a P100 GPU.
Params | batches/sec | batches per epoch | time per epoch
--- | --- | --- | ---
base | 4.8 | 83244 | 4 hr
big | 1.1 | 41365 | 10 hr
### Evaluation results
Below are the case-insensitive BLEU scores after 10 epochs.
Params | Score
--- | --- |
base | 27.7
big | 28.9
## Detailed instructions
0. ### Export variables (optional)
Export the following variables, or modify the values in each of the snippets below:
```
PARAMS=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAMS
```
1. ### Download and preprocess datasets
[data_download.py](data_download.py) downloads and preprocesses the training and evaluation WMT datasets. After the data is downloaded and extracted, the training data is used to generate a vocabulary of subtokens. The evaluation and training strings are tokenized, and the resulting data is sharded, shuffled, and saved as TFRecords.
1.75GB of compressed data will be downloaded. In total, the raw files (compressed, extracted, and combined files) take up 8.4GB of disk space. The resulting TFRecord and vocabulary files are 722MB. The script takes around 40 minutes to run, with the bulk of the time spent downloading and ~15 minutes spent on preprocessing.
Command to run:
```
python data_download.py --data_dir=$DATA_DIR
```
Arguments:
* `--data_dir`: Path where the preprocessed TFRecord data, and vocab file will be saved.
* Use the `--help` or `-h` flag to get a full list of possible arguments.
2. ### Model training and evaluation
[transformer_main.py](transformer_main.py) creates a Transformer model, and trains it using Tensorflow Estimator.
Command to run:
```
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS
```
Arguments:
* `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument.
* `--model_dir`: Directory to save Transformer model training checkpoints.
* `--params`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
* Use the `--help` or `-h` flag to get a full list of possible arguments.
#### Customizing training schedule
By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags:
* Training with epochs (default):
* `--train_epochs`: The total number of complete passes to make through the dataset
* `--epochs_between_eval`: The number of epochs to train between evaluations.
* Training with steps:
* `--train_steps`: sets the total number of training steps to run.
* `--steps_between_eval`: Number of training steps to run between evaluations.
Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_eval=1000`.
Note: At the beginning of each training session, the training dataset is reloaded and shuffled. Stopping the training before completing an epoch may result in worse model quality, due to the chance that some examples may be seen more than others. Therefore, it is recommended to use epochs when the model quality is important.
#### Compute BLEU score during model evaluation
Use these flags to compute the BLEU when the model evaluates:
* `--bleu_source`: Path to file containing text to translate.
* `--bleu_ref`: Path to file containing the reference translation.
* `--bleu_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags.
The test source and reference files located in the `test_data` directory are extracted from the preprocessed dataset from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data).
When running `transformer_main.py`, use the flags: `--bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de`
#### Tensorboard
Training and evaluation metrics (loss, accuracy, approximate BLEU score, etc.) are logged, and can be displayed in the browser using Tensorboard.
```
tensorboard --logdir=$MODEL_DIR
```
The values are displayed at [localhost:6006](localhost:6006).
3. ### Translate using the model
[translate.py](translate.py) contains the script to use the trained model to translate input text or file. Each line in the file is translated separately.
Command to run:
```
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS --text="hello world"
```
Arguments for initializing the Subtokenizer and trained model:
* `--data_dir`: Used to locate the vocabulary file to create a Subtokenizer, which encodes the input and decodes the model output.
* `--model_dir` and `--params`: These parameters are used to rebuild the trained model
Arguments for specifying what to translate:
* `--text`: Text to translate
* `--file`: Path to file containing text to translate
* `--file_out`: If `--file` is set, then this file will store the input file's translations.
To translate the newstest2014 data, run:
```
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en
```
Translating the file takes around 15 minutes on a GTX1080, or 5 minutes on a P100.
4. ### Compute official BLEU score
Use [compute_bleu.py](compute_bleu.py) to compute the BLEU by comparing generated translations to the reference translation.
Command to run:
```
python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
```
Arguments:
* `--translation`: Path to file containing generated translations.
* `--reference`: Path to file containing reference translations.
* Use the `--help` or `-h` flag to get a full list of possible arguments.
## Implementation overview
A brief look at each component in the code:
### Model Definition
The [model](model) subdirectory contains the implementation of the Transformer model. The following files define the Transformer model and its layers:
* [transformer.py](model/transformer.py): Defines the transformer model and its encoder/decoder layer stacks.
* [embedding_layer.py](model/embedding_layer.py): Contains the layer that calculates the embeddings. The embedding weights are also used to calculate the pre-softmax probabilities from the decoder output.
* [attention_layer.py](model/attention_layer.py): Defines the multi-headed and self attention layers that are used in the encoder/decoder stacks.
* [ffn_layer.py](model/ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.
Other files:
* [beam_search.py](model/beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.
* [model_params.py](model/model_params.py) contains the parameters used for the big and base models.
* [model_utils.py](model/model_utils.py) defines some helper functions used in the model (calculating padding, bias, etc.).
### Model Estimator
[transformer_main.py](transformer_main.py) creates an `Estimator` to train and evaluate the model.
Helper functions:
* [utils/dataset.py](utils/dataset.py): contains functions for creating a `dataset` that is passed to the `Estimator`.
* [utils/metrics.py](utils/metrics.py): defines metrics functions used by the `Estimator` to evaluate the model.
### Other scripts
Aside from the main file to train the Transformer model, we provide other scripts for using the model or downloading the data:
#### Data download and preprocessing
[data_download.py](data_download.py) downloads and extracts data, then uses `Subtokenizer` to tokenize strings into arrays of int IDs. The int arrays are converted to `tf.Examples` and saved in the `tf.RecordDataset` format.
The data is downloaded from the Workshop on Machine Translation (WMT) [news translation task](http://www.statmt.org/wmt17/translation-task.html). The following datasets are used:
* Europarl v7
* Common Crawl corpus
* News Commentary v12
See the [download section](http://www.statmt.org/wmt17/translation-task.html#download) to explore the raw datasets. The parameters in this model are tuned to fit the English-German translation data, so the EN-DE texts are extracted from the downloaded compressed files.
The text is transformed into arrays of integer IDs using the `Subtokenizer` defined in [`utils/tokenizer.py`](utils/tokenizer.py). During initialization of the `Subtokenizer`, the raw training data is used to generate a vocabulary list containing common subtokens.
The target vocabulary size of the WMT dataset is 32,768. The set of subtokens is found through binary search on the minimum number of times a subtoken appears in the data. The actual vocabulary size is 33,708, and is stored in a 324kB file.
#### Translation
Translation is defined in [translate.py](translate.py). First, `Subtokenizer` tokenizes the input. The vocabulary file is the same used to tokenize the training/eval files. Next, beam search is used to find the combination of tokens that maximizes the probability outputted by the model decoder. The tokens are then converted back to strings with `Subtokenizer`.
#### BLEU computation
[compute_bleu.py](compute_bleu.py): Implementation from [https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py).
### Test dataset
The [newstest2014 files](test_data) are extracted from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data). The raw text files are converted from the SGM format of the [WMT 2016](http://www.statmt.org/wmt16/translation-task.html) test sets.
## Term definitions
**Steps / Epochs**:
* Step: unit for processing a single batch of data
* Epoch: a complete run through the dataset
Example: Consider training on a dataset with 100 examples that is divided into 20 batches with 5 examples per batch. A single training step trains the model on one batch. After 20 training steps, the model will have trained on every batch in the dataset, or one epoch.
**Subtoken**: Words are referred to as tokens, and parts of words are referred to as 'subtokens'. For example, the word 'inclined' may be split into `['incline', 'd_']`. The '\_' indicates the end of the token. The subtoken vocabulary list is guaranteed to contain the alphabet (including numbers and special characters), so all words can be tokenized.
\ No newline at end of file
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to compute official BLEU score.
Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import re
import sys
import unicodedata
# pylint: disable=g-bad-import-order
import six
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.utils import metrics
class UnicodeRegex(object):
  """Ad-hoc hack to recognize all punctuation and symbols."""

  def __init__(self):
    punctuation = self.property_chars("P")
    # Punctuation with a non-digit on the left/right. Keeping digits out of
    # the "other" group means thousand/decimal separators inside numbers are
    # left untouched by the tokenizer.
    self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
    self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
    # Any character in the Unicode "Symbol" categories.
    self.symbol_re = re.compile("([" + self.property_chars("S") + "])")

  def property_chars(self, prefix):
    # Concatenate every code point whose Unicode category starts with
    # `prefix` (e.g. "P" = punctuation, "S" = symbols). This scans the whole
    # code-point range, so it is slow — intended to run once at import time.
    # `six.unichr` keeps this working on both Python 2 and Python 3.
    return "".join(six.unichr(x) for x in range(sys.maxunicode)
                   if unicodedata.category(six.unichr(x)).startswith(prefix))
uregex = UnicodeRegex()
def bleu_tokenize(string):
  r"""Tokenize a string following the official BLEU implementation.

  See https://github.com/moses-smt/mosesdecoder/
  blob/master/scripts/generic/mteval-v14.pl#L954-L983

  The input here is expected to be a single line with HTML entities already
  de-escaped, so we only need to split on punctuation and symbols — except
  when a punctuation character is both preceded and followed by a digit
  (e.g. a comma/dot used as a thousand/decimal separator).

  Note that a number (e.g. a year) followed by a dot at the end of a
  sentence is NOT tokenized, i.e. the dot stays with the number, because
  `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a
  space after each sentence). However, this error is already present in the
  original mteval-v14.pl and we want to be consistent with it.

  Args:
    string: the input string

  Returns:
    a list of tokens
  """
  # Order matters: first isolate punctuation next to non-digits on either
  # side, then isolate symbol characters, then split on whitespace.
  tokenized = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
  tokenized = uregex.punct_nondigit_re.sub(r" \1 \2", tokenized)
  tokenized = uregex.symbol_re.sub(r" \1 ", tokenized)
  return tokenized.split()
def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
  """Compute BLEU for two files (reference and hypothesis translation).

  Args:
    ref_filename: path to the file holding reference translations.
    hyp_filename: path to the file holding hypothesis translations.
    case_sensitive: if False, both sides are lowercased before scoring.

  Returns:
    BLEU score scaled to the 0-100 range.

  Raises:
    ValueError: if the two files do not have the same number of lines.
  """
  def _read_lines(filename):
    # Read the whole file, drop surrounding whitespace, split into lines.
    return tf.gfile.Open(filename).read().strip().splitlines()

  ref_lines = _read_lines(ref_filename)
  hyp_lines = _read_lines(hyp_filename)
  if len(ref_lines) != len(hyp_lines):
    raise ValueError("Reference and translation files have different number of "
                     "lines.")
  if not case_sensitive:
    ref_lines = [line.lower() for line in ref_lines]
    hyp_lines = [line.lower() for line in hyp_lines]
  ref_tokens = [bleu_tokenize(line) for line in ref_lines]
  hyp_tokens = [bleu_tokenize(line) for line in hyp_lines]
  return metrics.compute_bleu(ref_tokens, hyp_tokens) * 100
def main(unused_argv):
  """Compute and print the BLEU variant(s) selected via FLAGS."""
  variants = FLAGS.bleu_variant
  # When no variant is specified, report both uncased and cased scores.
  if variants is None or "uncased" in variants:
    score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
    print("Case-insensitive results:", score)
  if variants is None or "cased" in variants:
    score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
    print("Case-sensitive results:", score)
if __name__ == "__main__":
  # Command-line entry point: parse flags into the module-level FLAGS used
  # by main(), then run the BLEU computation.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--translation", "-t", type=str, default=None, required=True,
      help="[default: %(default)s] File containing translated text.",
      metavar="<T>")
  parser.add_argument(
      "--reference", "-r", type=str, default=None, required=True,
      help="[default: %(default)s] File containing reference translation",
      metavar="<R>")
  parser.add_argument(
      "--bleu_variant", "-bv", type=str, choices=["uncased", "cased"],
      nargs="*", default=None,
      # Fixed: the help string previously left this parenthesis unclosed.
      help="Specify one or more BLEU variants to calculate (both are "
           "calculated by default). Variants: \"cased\" or \"uncased\".",
      metavar="<BV>")
  FLAGS, unparsed = parser.parse_known_args()
  main(sys.argv)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test functions in compute_blue.py."""
import tempfile
import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.transformer import compute_bleu
class ComputeBleuTest(unittest.TestCase):
  """Unit tests for the BLEU helpers in compute_bleu."""

  def _create_temp_file(self, text):
    """Write `text` to a fresh temporary file and return its path."""
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    with tf.gfile.Open(temp_file.name, 'w') as w:
      w.write(text)
    return temp_file.name

  def _both_scores(self, ref_text, hyp_text):
    """Return (uncased, cased) BLEU scores for the given text pair."""
    ref = self._create_temp_file(ref_text)
    hyp = self._create_temp_file(hyp_text)
    return (compute_bleu.bleu_wrapper(ref, hyp, False),
            compute_bleu.bleu_wrapper(ref, hyp, True))

  def test_bleu_same(self):
    # Identical texts must score a perfect 100 in both variants.
    uncased, cased = self._both_scores(
        "test 1 two 3\nmore tests!", "test 1 two 3\nmore tests!")
    self.assertEqual(100, uncased)
    self.assertEqual(100, cased)

  def test_bleu_same_different_case(self):
    # Only the case-sensitive variant is penalized by casing differences.
    uncased, cased = self._both_scores(
        "Test 1 two 3\nmore tests!", "test 1 two 3\nMore tests!")
    self.assertEqual(100, uncased)
    self.assertLess(cased, 100)

  def test_bleu_different(self):
    # Completely different texts score below 100 in both variants.
    uncased, cased = self._both_scores("Testing\nmore tests!", "Dog\nCat")
    self.assertLess(uncased, 100)
    self.assertLess(cased, 100)

  def test_bleu_tokenize(self):
    # Commas between digits stay separators only when digit-adjacent on both
    # sides; here each comma follows a non-digit boundary and is split off.
    tokenized = compute_bleu.bleu_tokenize("Test0, 1 two, 3")
    self.assertEqual(["Test0", ",", "1", "two", ",", "3"], tokenized)
# Allow running this test file directly: `python compute_bleu_test.py`.
if __name__ == "__main__":
  unittest.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and preprocess WMT17 ende training and evaluation datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import random
import sys
import tarfile
import urllib
# pylint: disable=g-bad-import-order
import six
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.transformer.utils import tokenizer
# Data sources for training/evaluating the transformer translation model.
# If any of the training sources are changed, then either:
#   1) use the flag `--search` to find the best min count or
#   2) update the _TRAIN_DATA_MIN_COUNT constant.
# min_count is the minimum number of times a token must appear in the data
# before it is added to the vocabulary. "Best min count" refers to the value
# that generates a vocabulary set that is closest in size to
# _TARGET_VOCAB_SIZE.
_TRAIN_DATA_SOURCES = [
    {
        "url": "http://data.statmt.org/wmt17/translation-task/"
               "training-parallel-nc-v12.tgz",
        "input": "news-commentary-v12.de-en.en",
        "target": "news-commentary-v12.de-en.de",
    },
    {
        "url": "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
        "input": "commoncrawl.de-en.en",
        "target": "commoncrawl.de-en.de",
    },
    {
        "url": "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
        "input": "europarl-v7.de-en.en",
        "target": "europarl-v7.de-en.de",
    },
]
# Use pre-defined minimum count to generate subtoken vocabulary.
_TRAIN_DATA_MIN_COUNT = 6

_EVAL_DATA_SOURCES = [
    {
        "url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
        "input": "newstest2013.en",
        "target": "newstest2013.de",
    }
]

# Vocabulary constants
_TARGET_VOCAB_SIZE = 32768  # Number of subtokens in the vocabulary list.
_TARGET_THRESHOLD = 327  # Accept vocabulary if size is within this threshold
VOCAB_FILE = "vocab.ende.%d" % _TARGET_VOCAB_SIZE

# Strings to include in the generated files.
_PREFIX = "wmt32k"
_TRAIN_TAG = "train"
_EVAL_TAG = "dev"  # Following WMT and Tensor2Tensor conventions, in which the
                   # evaluation datasets are tagged as "dev" for development.

# Number of files to split train and evaluation data
_TRAIN_SHARDS = 100
_EVAL_SHARDS = 1
def find_file(path, filename, max_depth=5):
  """Return the full path of `filename` under `path`, or None if absent.

  Walks `path` and its subdirectories, pruning the traversal once a
  directory is more than `max_depth` levels below `path`.
  """
  for root, subdirs, filenames in os.walk(path):
    if filename in filenames:
      return os.path.join(root, filename)
    # Relative depth of `root` below `path`; stop descending past max_depth.
    if root[len(path) + 1:].count(os.sep) > max_depth:
      subdirs[:] = []
  return None
###############################################################################
# Download and extraction functions
###############################################################################
def get_raw_files(raw_dir, data_source):
  """Return raw files from source. Downloads/extracts if needed.

  Args:
    raw_dir: string directory to store raw files
    data_source: list of dicts, each with keys
        {"url": url of compressed dataset containing input and target files
         "input": file with data in input language
         "target": file with data in target language}

  Returns:
    dictionary with
      {"inputs": list of files containing data in input language
       "targets": list of files containing corresponding data in target
         language}
  """
  inputs = []
  targets = []
  for source in data_source:
    input_file, target_file = download_and_extract(
        raw_dir, source["url"], source["input"], source["target"])
    inputs.append(input_file)
    targets.append(target_file)
  return {"inputs": inputs, "targets": targets}
def download_report_hook(count, block_size, total_size):
  """Report hook for download progress.

  Args:
    count: current block number
    block_size: block size
    total_size: total size of the download, or a non-positive value when the
      server did not report a Content-Length.
  """
  # Guard against ZeroDivisionError: urlretrieve passes a non-positive
  # total_size when the server does not report the download size.
  if total_size <= 0:
    return
  percent = int(count * block_size * 100 / total_size)
  print("\r%d%%" % percent + " completed", end="\r")
def download_from_url(path, url):
  """Download content from a url.

  Args:
    path: string directory where file will be downloaded
    url: string url

  Returns:
    Full path to downloaded file
  """
  filename = url.split("/")[-1]
  found_file = find_file(path, filename, max_depth=0)
  if found_file is None:
    filename = os.path.join(path, filename)
    tf.logging.info("Downloading from %s to %s." % (url, filename))
    inprogress_filepath = filename + ".incomplete"
    # `urllib.urlretrieve` exists only on Python 2; the six compatibility
    # shim resolves to the correct function on both Python 2 and 3.
    inprogress_filepath, _ = six.moves.urllib.request.urlretrieve(
        url, inprogress_filepath, reporthook=download_report_hook)
    # Print newline to clear the carriage return from the download progress.
    print()
    # Rename only after a complete download so partial files are detectable.
    tf.gfile.Rename(inprogress_filepath, filename)
    return filename
  else:
    tf.logging.info("Already downloaded: %s (at %s)." % (url, found_file))
    return found_file
def download_and_extract(path, url, input_filename, target_filename):
  """Extract files from downloaded compressed archive file.

  Args:
    path: string directory where the files will be downloaded
    url: url containing the compressed input and target files
    input_filename: name of file containing data in source language
    target_filename: name of file containing data in target language

  Returns:
    Full paths to extracted input and target files.

  Raises:
    OSError: if the download/extraction fails.
  """
  def _locate_pair():
    """Look up both requested files under `path`; either may be None."""
    return find_file(path, input_filename), find_file(path, target_filename)

  # Skip the download entirely when both files already exist in path.
  input_file, target_file = _locate_pair()
  if input_file and target_file:
    tf.logging.info("Already downloaded and extracted %s." % url)
    return input_file, target_file

  # Download archive file if it doesn't already exist.
  compressed_file = download_from_url(path, url)

  tf.logging.info("Extracting %s." % compressed_file)
  # NOTE(review): extractall() trusts member paths inside the archive; a
  # malicious tarball could write outside `path`. Acceptable here only
  # because the URLs are fixed WMT mirrors — confirm before reusing.
  with tarfile.open(compressed_file, "r:gz") as corpus_tar:
    corpus_tar.extractall(path)

  # Return filepaths of the requested files.
  input_file, target_file = _locate_pair()
  if input_file and target_file:
    return input_file, target_file
  raise OSError("Download/extraction failed for url %s to path %s" %
                (url, path))
def txt_line_iterator(path):
  """Generator over the whitespace-stripped lines of the file at `path`."""
  with tf.gfile.Open(path) as f:
    for raw_line in f:
      yield raw_line.strip()
def compile_files(raw_dir, raw_files, tag):
  """Compile raw files into a single file for each language.

  Args:
    raw_dir: Directory containing downloaded raw files.
    raw_files: Dict containing filenames of input and target data.
        {"inputs": list of files containing data in input language
         "targets": list of files containing corresponding data in target
           language}
    tag: String to append to the compiled filename.

  Returns:
    Full path of compiled input and target files.
  """
  tf.logging.info("Compiling files with tag %s." % tag)
  filename = "%s-%s" % (_PREFIX, tag)
  input_compiled_file = os.path.join(raw_dir, filename + ".lang1")
  target_compiled_file = os.path.join(raw_dir, filename + ".lang2")

  with tf.gfile.Open(input_compiled_file, mode="w") as input_writer:
    with tf.gfile.Open(target_compiled_file, mode="w") as target_writer:
      # The input/target lists are parallel corpora: iterate them in
      # lockstep (idiomatic zip, replacing index-based iteration).
      for input_file, target_file in zip(
          raw_files["inputs"], raw_files["targets"]):
        tf.logging.info("Reading files %s and %s." % (input_file, target_file))
        write_file(input_writer, input_file)
        write_file(target_writer, target_file)
  return input_compiled_file, target_compiled_file
def write_file(writer, filename):
  """Copy every stripped line of `filename` to `writer`, newline-terminated."""
  for text_line in txt_line_iterator(filename):
    writer.write("%s\n" % text_line)
###############################################################################
# Data preprocessing
###############################################################################
def encode_and_save_files(
    subtokenizer, data_dir, raw_files, tag, total_shards):
  """Save data from files as encoded Examples in TFrecord format.

  Args:
    subtokenizer: Subtokenizer object that will be used to encode the strings.
    data_dir: The directory in which to write the examples
    raw_files: A tuple of (input, target) data files. Each line in the input and
      the corresponding line in target file will be saved in a tf.Example.
    tag: String that will be added onto the file names.
    total_shards: Number of files to divide the data into.

  Returns:
    List of all files produced.
  """
  # Create a file for each shard.
  filepaths = [shard_filename(data_dir, tag, n + 1, total_shards)
               for n in range(total_shards)]

  if all_exist(filepaths):
    tf.logging.info("Files with tag %s already exist." % tag)
    return filepaths

  tf.logging.info("Saving files with tag %s." % tag)
  input_file = raw_files[0]
  target_file = raw_files[1]

  # Write examples to each shard in round robin order. Write to temporary
  # ".incomplete" names first so partially-written shards are never mistaken
  # for finished ones.
  tmp_filepaths = [fname + ".incomplete" for fname in filepaths]
  writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
  counter, shard = 0, 0
  try:
    for input_line, target_line in zip(
        txt_line_iterator(input_file), txt_line_iterator(target_file)):
      if counter > 0 and counter % 100000 == 0:
        tf.logging.info("\tSaving case %d." % counter)
      example = dict_to_example(
          {"inputs": subtokenizer.encode(input_line, add_eos=True),
           "targets": subtokenizer.encode(target_line, add_eos=True)})
      writers[shard].write(example.SerializeToString())
      shard = (shard + 1) % total_shards
      counter += 1
  finally:
    # Close writers even if encoding fails, so no file handles leak.
    for writer in writers:
      writer.close()

  for tmp_name, final_name in zip(tmp_filepaths, filepaths):
    tf.gfile.Rename(tmp_name, final_name)

  # BUG FIX: the previous code logged the last enumerate() index, which
  # undercounted by one (and reported 0 for both 0 and 1 examples). `counter`
  # now holds the exact number of examples written.
  tf.logging.info("Saved %d Examples", counter)
  return filepaths
def shard_filename(path, tag, shard_num, total_shards):
  """Return the full path for one data shard file."""
  basename = "%s-%s-%.5d-of-%.5d" % (_PREFIX, tag, shard_num, total_shards)
  return os.path.join(path, basename)
def shuffle_records(fname):
  """Shuffle the TFRecord entries of `fname` in place (via a temp rename)."""
  tf.logging.info("Shuffling records in file %s" % fname)

  # Move the file aside so the shuffled output can reuse the original name.
  tmp_fname = fname + ".unshuffled"
  tf.gfile.Rename(fname, tmp_fname)

  # Load every record into memory, logging progress every 100k records.
  records = []
  for total_read, record in enumerate(
      tf.python_io.tf_record_iterator(tmp_fname), start=1):
    records.append(record)
    if total_read % 100000 == 0:
      tf.logging.info("\tRead: %d", total_read)

  random.shuffle(records)

  # Write shuffled records back under the original file name.
  with tf.python_io.TFRecordWriter(fname) as record_writer:
    for index, record in enumerate(records):
      record_writer.write(record)
      if index > 0 and index % 100000 == 0:
        tf.logging.info("\tWriting record: %d" % index)

  tf.gfile.Remove(tmp_fname)
def dict_to_example(dictionary):
  """Converts a dictionary of string->int to a tf.Example."""
  feature_map = {
      key: tf.train.Feature(int64_list=tf.train.Int64List(value=ints))
      for key, ints in six.iteritems(dictionary)
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_map))
def all_exist(filepaths):
  """Returns true if all files in the list exist."""
  # all() short-circuits on the first missing file, like the original loop.
  return all(tf.gfile.Exists(fname) for fname in filepaths)
def make_dir(path):
  """Create directory `path` (and parents) unless it already exists."""
  if tf.gfile.Exists(path):
    return
  tf.logging.info("Creating directory %s" % path)
  tf.gfile.MakeDirs(path)
def main(unused_argv):
  """Obtain training and evaluation data for the Transformer model."""
  tf.logging.set_verbosity(tf.logging.INFO)

  make_dir(FLAGS.raw_dir)
  make_dir(FLAGS.data_dir)

  # Download (or locate previously extracted) raw train/eval file sets.
  tf.logging.info("Step 1/4: Downloading data from source")
  raw_train = get_raw_files(FLAGS.raw_dir, _TRAIN_DATA_SOURCES)
  raw_eval = get_raw_files(FLAGS.raw_dir, _EVAL_DATA_SOURCES)

  # Build the subword vocabulary from all training text (inputs + targets).
  tf.logging.info("Step 2/4: Creating subtokenizer and building vocabulary")
  all_train_paths = raw_train["inputs"] + raw_train["targets"]
  vocab_path = os.path.join(FLAGS.data_dir, VOCAB_FILE)
  subtokenizer = tokenizer.Subtokenizer.init_from_files(
      vocab_path, all_train_paths, _TARGET_VOCAB_SIZE, _TARGET_THRESHOLD,
      min_count=None if FLAGS.search else _TRAIN_DATA_MIN_COUNT)

  # Merge the per-source raw files into one file per language.
  tf.logging.info("Step 3/4: Compiling training and evaluation data")
  train_compiled = compile_files(FLAGS.raw_dir, raw_train, _TRAIN_TAG)
  eval_compiled = compile_files(FLAGS.raw_dir, raw_eval, _EVAL_TAG)

  # Encode each line pair as a tf.Example and shard into TFRecord files.
  tf.logging.info("Step 4/4: Preprocessing and saving data")
  train_records = encode_and_save_files(
      subtokenizer, FLAGS.data_dir, train_compiled, _TRAIN_TAG,
      _TRAIN_SHARDS)
  encode_and_save_files(
      subtokenizer, FLAGS.data_dir, eval_compiled, _EVAL_TAG,
      _EVAL_SHARDS)

  # Randomize training example order within each shard.
  for record_path in train_records:
    shuffle_records(record_path)
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--data_dir", "-dd", type=str, default="/tmp/translate_ende",
      help="[default: %(default)s] Directory for where the "
           "translate_ende_wmt32k dataset is saved.",
      metavar="<DD>")
  parser.add_argument(
      "--raw_dir", "-rd", type=str, default="/tmp/translate_ende_raw",
      help="[default: %(default)s] Path where the raw data will be downloaded "
           "and extracted.",
      metavar="<RD>")
  parser.add_argument(
      "--search", action="store_true",
      # BUG FIX: the two implicitly-concatenated help fragments previously
      # rendered as "with sizeclosest to"; the trailing space restores the
      # intended text.
      help="If set, use binary search to find the vocabulary set with size "
           "closest to the target size (%d)." % _TARGET_VOCAB_SIZE)

  FLAGS, unparsed = parser.parse_known_args()
  main(sys.argv)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of multiheaded attention and self-attention layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class Attention(tf.layers.Layer):
  """Multi-headed attention layer."""

  def __init__(self, hidden_size, num_heads, attention_dropout, train):
    """Initialize the attention layer.

    Args:
      hidden_size: int, output dimension of the layer. Must be evenly
        divisible by num_heads, since each head works on a
        hidden_size/num_heads slice.
      num_heads: int, number of attention heads.
      attention_dropout: float, dropout rate applied to the attention weights
        (only active when train is True).
      train: bool, whether the layer runs in training mode (enables dropout).

    Raises:
      ValueError: if hidden_size is not evenly divisible by num_heads.
    """
    if hidden_size % num_heads != 0:
      raise ValueError("Hidden size must be evenly divisible by the number of "
                       "heads.")

    super(Attention, self).__init__()
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout
    self.train = train

    # Layers for linearly projecting the queries, keys, and values.
    # use_bias=False: these are pure linear projections before head splitting.
    self.q_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="q")
    self.k_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="k")
    self.v_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="v")
    self.output_dense_layer = tf.layers.Dense(hidden_size, use_bias=False,
                                              name="output_transform")

  def split_heads(self, x):
    """Split x into different heads, and transpose the resulting value.

    The tensor is transposed to insure the inner dimensions hold the correct
    values during the matrix multiplication.

    Args:
      x: A tensor with shape [batch_size, length, hidden_size]

    Returns:
      A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
    """
    with tf.name_scope("split_heads"):
      # batch_size and length are read dynamically so the layer works with
      # variable-length inputs.
      batch_size = tf.shape(x)[0]
      length = tf.shape(x)[1]

      # Calculate depth of last dimension after it has been split.
      depth = (self.hidden_size // self.num_heads)

      # Split the last dimension
      x = tf.reshape(x, [batch_size, length, self.num_heads, depth])

      # Transpose the result
      return tf.transpose(x, [0, 2, 1, 3])

  def combine_heads(self, x):
    """Combine tensor that has been split.

    Inverse of split_heads.

    Args:
      x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]

    Returns:
      A tensor with shape [batch_size, length, hidden_size]
    """
    with tf.name_scope("combine_heads"):
      batch_size = tf.shape(x)[0]
      length = tf.shape(x)[2]
      x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
      return tf.reshape(x, [batch_size, length, self.hidden_size])

  def call(self, x, y, bias, cache=None):
    """Apply attention mechanism to x and y.

    Args:
      x: a tensor with shape [batch_size, length_x, hidden_size]
      y: a tensor with shape [batch_size, length_y, hidden_size]
      bias: attention bias that will be added to the result of the dot product.
      cache: (Used during prediction) dictionary with tensors containing results
        of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, key_channels],
             "v": tensor with shape [batch_size, i, value_channels]}
        where i is the current decoded length. The cache is updated in place
        with the keys/values for the new position.

    Returns:
      Attention layer output with shape [batch_size, length_x, hidden_size]
    """
    # Linearly project the query (q), key (k) and value (v) using different
    # learned projections. This is in preparation of splitting them into
    # multiple heads. Multi-head attention uses multiple queries, keys, and
    # values rather than regular attention (which uses a single q, k, v).
    q = self.q_dense_layer(x)
    k = self.k_dense_layer(y)
    v = self.v_dense_layer(y)

    if cache is not None:
      # Combine cached keys and values with new keys and values.
      k = tf.concat([cache["k"], k], axis=1)
      v = tf.concat([cache["v"], v], axis=1)

      # Update cache
      cache["k"] = k
      cache["v"] = v

    # Split q, k, v into heads.
    q = self.split_heads(q)
    k = self.split_heads(k)
    v = self.split_heads(v)

    # Scale q to prevent the dot product between q and k from growing too large.
    depth = (self.hidden_size // self.num_heads)
    q *= depth ** -0.5

    # Calculate dot product attention
    logits = tf.matmul(q, k, transpose_b=True)
    logits += bias
    weights = tf.nn.softmax(logits, name="attention_weights")
    if self.train:
      # NOTE: tf.nn.dropout here takes keep_prob (TF1 API), hence 1 - rate.
      weights = tf.nn.dropout(weights, 1.0 - self.attention_dropout)
    attention_output = tf.matmul(weights, v)

    # Recombine heads --> [batch_size, length, hidden_size]
    attention_output = self.combine_heads(attention_output)

    # Run the combined outputs through another linear projection layer.
    attention_output = self.output_dense_layer(attention_output)
    return attention_output
class SelfAttention(Attention):
  """Multiheaded self-attention layer."""

  def call(self, x, bias, cache=None):
    # Self-attention is ordinary attention where queries, keys, and values
    # all come from the same tensor; delegate with y = x.
    return super(SelfAttention, self).call(x, x, bias, cache)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment