Commit 30aeec75 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Merge pull request #2 from tensorflow/master

Sync to tensorflow-master
parents 68a18b70 78007443
# Copyright 2017 Google, Inc. All Rights Reserved.
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for training adversarial text models."""
from __future__ import absolute_import
from __future__ import division
......@@ -20,6 +19,8 @@ from __future__ import print_function
import time
# Dependency imports
import numpy as np
import tensorflow as tf
......
......@@ -28,7 +28,7 @@ Pull requests:
virtualenv --system-site-packages ~/.tensorflow
source ~/.tensorflow/bin/activate
pip install --upgrade pip
pip install --upgrade tensorflow_gpu
pip install --upgrade tensorflow-gpu
```
2. At least 158GB of free disk space to download the FSNS dataset:
......@@ -65,7 +65,7 @@ To train a model using pre-trained Inception weights as initialization:
```
wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
tar xf inception_v3_2016_08_28.tar.gz
python train.py --checkpoint_inception=inception_v3.ckpt
python train.py --checkpoint_inception=./inception_v3.ckpt
```
To fine tune the Attention OCR model using a checkpoint:
......@@ -142,6 +142,9 @@ python train.py --dataset_name=newtextdataset
Please note that eval.py will also require the same flag.
To learn how to store a data in the FSNS
format please refer to the https://stackoverflow.com/a/44461910/743658.
2. Define a new dataset format. The model needs the following data to train:
- images: input images, shape [batch_size x H x W x 3];
......@@ -176,4 +179,4 @@ The main difference between this version and the version used in the paper - for
the paper we used a distributed training with 50 GPU (K80) workers (asynchronous
updates), the provided checkpoint was created using this code after ~6 days of
training on a single GPU (Titan X) (it reached 81% after 24 hours of training),
the coordinate encoding is missing TODO(alexgorban@).
the coordinate encoding is disabled by default.
......@@ -55,6 +55,10 @@ SequenceLossParams = collections.namedtuple('SequenceLossParams', [
'label_smoothing', 'ignore_nulls', 'average_across_timesteps'
])
EncodeCoordinatesParams = collections.namedtuple('EncodeCoordinatesParams', [
'enabled'
])
def _dict_to_array(id_to_char, default_character):
num_char_classes = max(id_to_char.keys()) + 1
......@@ -162,7 +166,8 @@ class Model(object):
SequenceLossParams(
label_smoothing=0.1,
ignore_nulls=True,
average_across_timesteps=False)
average_across_timesteps=False),
'encode_coordinates_fn': EncodeCoordinatesParams(enabled=False)
}
def set_mparam(self, function, **kwargs):
......@@ -293,6 +298,30 @@ class Model(object):
scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length))
return ids, log_prob, scores
def encode_coordinates_fn(self, net):
"""Adds one-hot encoding of coordinates to different views in the networks.
For each "pixel" of a feature map it adds a onehot encoded x and y
coordinates.
Args:
net: a tensor of shape=[batch_size, height, width, num_features]
Returns:
a tensor with the same height and width, but altered feature_size.
"""
mparams = self._mparams['encode_coordinates_fn']
if mparams.enabled:
batch_size, h, w, _ = net.shape.as_list()
x, y = tf.meshgrid(tf.range(w), tf.range(h))
w_loc = slim.one_hot_encoding(x, num_classes=w)
h_loc = slim.one_hot_encoding(y, num_classes=h)
loc = tf.concat([h_loc, w_loc], 2)
loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1])
return tf.concat([net, loc], 3)
else:
return net
def create_base(self,
images,
labels_one_hot,
......@@ -324,6 +353,9 @@ class Model(object):
]
logging.debug('Conv tower: %s', nets[0])
nets = [self.encode_coordinates_fn(net) for net in nets]
logging.debug('Conv tower w/ encoded coordinates: %s', nets[0])
net = self.pool_views_fn(nets)
logging.debug('Pooled views: %s', net)
......
......@@ -62,8 +62,9 @@ class ModelTest(tf.test.TestCase):
self.rng.randint(low=0, high=255,
size=self.images_shape).astype('float32'),
name='input_node')
self.fake_conv_tower_np = tf.constant(
self.rng.randn(*self.conv_tower_shape).astype('float32'))
self.fake_conv_tower_np = self.rng.randn(
*self.conv_tower_shape).astype('float32')
self.fake_conv_tower = tf.constant(self.fake_conv_tower_np)
self.fake_logits = tf.constant(
self.rng.randn(*self.chars_logit_shape).astype('float32'))
self.fake_labels = tf.constant(
......@@ -162,6 +163,87 @@ class ModelTest(tf.test.TestCase):
# This test checks that the loss function is 'runnable'.
self.assertEqual(loss_np.shape, tuple())
def encode_coordinates_alt(self, net):
"""An alternative implemenation for the encoding coordinates.
Args:
net: a tensor of shape=[batch_size, height, width, num_features]
Returns:
a list of tensors with encoded image coordinates in them.
"""
batch_size, h, w, _ = net.shape.as_list()
h_loc = [
tf.tile(
tf.reshape(
tf.contrib.layers.one_hot_encoding(
tf.constant([i]), num_classes=h), [h, 1]), [1, w])
for i in xrange(h)
]
h_loc = tf.concat([tf.expand_dims(t, 2) for t in h_loc], 2)
w_loc = [
tf.tile(
tf.contrib.layers.one_hot_encoding(tf.constant([i]), num_classes=w),
[h, 1]) for i in xrange(w)
]
w_loc = tf.concat([tf.expand_dims(t, 2) for t in w_loc], 2)
loc = tf.concat([h_loc, w_loc], 2)
loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1])
return tf.concat([net, loc], 3)
def test_encoded_coordinates_have_correct_shape(self):
model = self.create_model()
model.set_mparam('encode_coordinates_fn', enabled=True)
conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
with self.test_session() as sess:
conv_w_coords = sess.run(conv_w_coords_tf)
batch_size, height, width, feature_size = self.conv_tower_shape
self.assertEqual(conv_w_coords.shape, (batch_size, height, width,
feature_size + height + width))
def test_disabled_coordinate_encoding_returns_features_unchanged(self):
model = self.create_model()
model.set_mparam('encode_coordinates_fn', enabled=False)
conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
with self.test_session() as sess:
conv_w_coords = sess.run(conv_w_coords_tf)
self.assertAllEqual(conv_w_coords, self.fake_conv_tower_np)
def test_coordinate_encoding_is_correct_for_simple_example(self):
shape = (1, 2, 3, 4) # batch_size, height, width, feature_size
fake_conv_tower = tf.constant(2 * np.ones(shape), dtype=tf.float32)
model = self.create_model()
model.set_mparam('encode_coordinates_fn', enabled=True)
conv_w_coords_tf = model.encode_coordinates_fn(fake_conv_tower)
with self.test_session() as sess:
conv_w_coords = sess.run(conv_w_coords_tf)
# Original features
self.assertAllEqual(conv_w_coords[0, :, :, :4],
[[[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]],
[[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]])
# Encoded coordinates
self.assertAllEqual(conv_w_coords[0, :, :, 4:],
[[[1, 0, 1, 0, 0], [1, 0, 0, 1, 0], [1, 0, 0, 0, 1]],
[[0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 1, 0, 0, 1]]])
def test_alt_implementation_of_coordinate_encoding_returns_same_values(self):
model = self.create_model()
model.set_mparam('encode_coordinates_fn', enabled=True)
conv_w_coords_tf = model.encode_coordinates_fn(self.fake_conv_tower)
conv_w_coords_alt_tf = self.encode_coordinates_alt(self.fake_conv_tower)
with self.test_session() as sess:
conv_w_coords_tf, conv_w_coords_alt_tf = sess.run(
[conv_w_coords_tf, conv_w_coords_alt_tf])
self.assertAllEqual(conv_w_coords_tf, conv_w_coords_alt_tf)
class CharsetMapperTest(tf.test.TestCase):
def test_text_corresponds_to_ids(self):
......
......@@ -145,7 +145,7 @@ def visit_count_fc(visit_count, last_visit, embed_neurons, wt_decay, fc_dropout)
on_value=10., off_value=0.)
last_visit = tf.one_hot(last_visit, depth=16, axis=1, dtype=tf.float32,
on_value=10., off_value=0.)
f = tf.concat_v2([visit_count, last_visit], 1)
f = tf.concat([visit_count, last_visit], 1)
x, _ = tf_utils.fc_network(
f, neurons=embed_neurons, wt_decay=wt_decay, name='visit_count_embed',
offset=0, batch_norm_param=None, dropout_ratio=fc_dropout,
......@@ -201,7 +201,7 @@ def combine_setup(name, combine_type, embed_img, embed_goal, num_img_neuorons=No
def preprocess_egomotion(locs, thetas):
with tf.name_scope('pre_ego'):
pre_ego = tf.concat_v2([locs, tf.sin(thetas), tf.cos(thetas)], 2)
pre_ego = tf.concat([locs, tf.sin(thetas), tf.cos(thetas)], 2)
sh = pre_ego.get_shape().as_list()
pre_ego = tf.reshape(pre_ego, [-1, sh[-1]])
return pre_ego
......
......@@ -8,7 +8,8 @@ code for the following papers:
## Organization
[Image Encoder](image_encoder/): Encoding and decoding images into their binary representation.
[Entropy Coder](entropy_coder/): Lossless compression of the binary representation.
## Contact Info
Model repository maintained by Nick Johnston ([nickj-google](https://github.com/nickj-google)).
Model repository maintained by Nick Johnston ([nmjohn](https://github.com/nmjohn)).
......@@ -102,4 +102,4 @@ pixel boundaries.
## Contact Info
Model repository maintained by Nick Johnston ([nickj-google](https://github.com/nickj-google)).
Model repository maintained by Nick Johnston ([nmjohn](https://github.com/nmjohn)).
# Domain Separation Networks
## Introduction
This is the code used for two domain adaptation papers.
The `domain_separation` directory contains code for the "Domain Separation
Networks" paper by Bousmalis K., Trigeorgis G., et al. which was presented at
NIPS 2016. The paper can be found here: https://arxiv.org/abs/1608.06019.
## Introduction
This code is the code used for the "Domain Separation Networks" paper
by Bousmalis K., Trigeorgis G., et al. which was presented at NIPS 2016. The
paper can be found here: https://arxiv.org/abs/1608.06019.
The `pixel_domain_adaptation` directory contains the code used for the
"Unsupervised Pixel-Level Domain Adaptation with Generative Adversarial
Networks" paper by Bousmalis K., et al. (presented at CVPR 2017). The paper can
be found here: https://arxiv.org/abs/1612.05424. PixelDA aims to perform domain
adaptation by transfering the visual style of the target domain (which has few
or no labels) to a source domain (which has many labels). This is accomplished
using a Generative Adversarial Network (GAN).
## Contact
This code was open-sourced by [Konstantinos Bousmalis](https://github.com/bousmalis) (konstantinos@google.com).
The domain separation code was open-sourced
by [Konstantinos Bousmalis](https://github.com/bousmalis)
(konstantinos@google.com), while the pixel level domain adaptation code was
open-sourced by [David Dohan](https://github.com/dmrd) (ddohan@google.com).
## Installation
You will need to have the following installed on your machine before trying out the DSN code.
......@@ -16,26 +26,70 @@ You will need to have the following installed on your machine before trying out
* Bazel: https://bazel.build/
## Important Note
Although we are making the code available, you are only able to use the MNIST
provider for now. We will soon provide a script to download and convert MNIST-M
as well. Check back here in a few weeks or wait for a relevant announcement from
[@bousmalis](https://twitter.com/bousmalis).
We are working to open source the pose estimation dataset. For now, the MNIST to
MNIST-M dataset is available. Check back here in a few weeks or wait for a
relevant announcement from [@bousmalis](https://twitter.com/bousmalis).
## Running the code for adapting MNIST to MNIST-M
In order to run the MNIST to MNIST-M experiments with DANNs and/or DANNs with
domain separation (DSNs) you will need to set the directory you used to download
MNIST and MNIST-M:
## Initial setup
In order to run the MNIST to MNIST-M experiments, you will need to set the
data directory:
```
$ export DSN_DATA_DIR=/your/dir
```
Add models and models/slim to your `$PYTHONPATH`:
Add models and models/slim to your `$PYTHONPATH` (assumes $PWD is /models):
```
$ export PYTHONPATH=$PYTHONPATH:$PWD:$PWD/slim
```
## Getting the datasets
You can fetch the MNIST data by running
```
$ bazel run slim:download_and_convert_data -- --dataset_dir $DSN_DATA_DIR --dataset_name=mnist
```
The MNIST-M dataset is available online [here](http://bit.ly/2nrlUAJ). Once it is downloaded and extracted into your data directory, create TFRecord files by running:
```
$ bazel run domain_adaptation/datasets:download_and_convert_mnist_m -- --dataset_dir $DSN_DATA_DIR
```
# Running PixelDA from MNIST to MNIST-M
You can run PixelDA as follows (using Tensorboard to examine the results):
```
$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_train -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m
```
And evaluation as:
```
$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_eval -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m --target_split_name test
```
The MNIST-M results in the paper were run with the following hparams flag:
```
--hparams arch=resnet,domain_loss_weight=0.135603587834,num_training_examples=16000000,style_transfer_loss_weight=0.0113173311334,task_loss_in_g_weight=0.0100959947002,task_tower=mnist,task_tower_in_g_step=true
```
### A note on terminology/language of the code:
The components of the network can be grouped into two parts
which correspond to elements which are jointly optimized: The generator
component and the discriminator component.
The generator component takes either an image or noise vector and produces an
output image.
The discriminator component takes the generated images and the target images
and attempts to discriminate between them.
## Running DSN code for adapting MNIST to MNIST-M
Then you need to build the binaries with Bazel:
```
......
......@@ -26,10 +26,20 @@ py_library(
],
)
py_binary(
name = "download_and_convert_mnist_m",
srcs = ["download_and_convert_mnist_m.py"],
deps = [
"//slim:dataset_utils",
],
)
py_binary(
name = "mnist_m",
srcs = ["mnist_m.py"],
deps = [
"//slim:dataset_utils",
],
)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A factory-pattern class which returns image/label pairs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from slim.datasets import mnist
......
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Downloads and converts MNIST-M data to TFRecords of TF-Example protos.
This module downloads the MNIST-M data, uncompresses it, reads the files
that make up the MNIST-M data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contain a single image and label.
The script should take about a minute to run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
# Dependency imports
import numpy as np
from six.moves import urllib
import tensorflow as tf
from slim.datasets import dataset_utils
tf.app.flags.DEFINE_string(
'dataset_dir', None,
'The directory where the output TFRecords and temporary files are saved.')
FLAGS = tf.app.flags.FLAGS
_IMAGE_SIZE = 32
_NUM_CHANNELS = 3
# The number of images in the training set.
_NUM_TRAIN_SAMPLES = 59001
# The number of images to be kept from the training set for the validation set.
_NUM_VALIDATION = 1000
# The number of images in the test set.
_NUM_TEST_SAMPLES = 9001
# Seed for repeatability.
_RANDOM_SEED = 0
# The names of the classes.
_CLASS_NAMES = [
'zero',
'one',
'two',
'three',
'four',
'five',
'size',
'seven',
'eight',
'nine',
]
class ImageReader(object):
"""Helper class that provides TensorFlow image coding utilities."""
def __init__(self):
# Initializes function that decodes RGB PNG data.
self._decode_png_data = tf.placeholder(dtype=tf.string)
self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)
def read_image_dims(self, sess, image_data):
image = self.decode_png(sess, image_data)
return image.shape[0], image.shape[1]
def decode_png(self, sess, image_data):
image = sess.run(
self._decode_png, feed_dict={self._decode_png_data: image_data})
assert len(image.shape) == 3
assert image.shape[2] == 3
return image
def _convert_dataset(split_name, filenames, filename_to_class_id, dataset_dir):
"""Converts the given filenames to a TFRecord dataset.
Args:
split_name: The name of the dataset, either 'train' or 'valid'.
filenames: A list of absolute paths to png images.
filename_to_class_id: A dictionary from filenames (strings) to class ids
(integers).
dataset_dir: The directory where the converted datasets are stored.
"""
print('Converting the {} split.'.format(split_name))
# Train and validation splits are both in the train directory.
if split_name in ['train', 'valid']:
png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train')
elif split_name == 'test':
png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test')
with tf.Graph().as_default():
image_reader = ImageReader()
with tf.Session('') as sess:
output_filename = _get_output_filename(dataset_dir, split_name)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
for filename in filenames:
# Read the filename:
image_data = tf.gfile.FastGFile(
os.path.join(png_directory, filename), 'r').read()
height, width = image_reader.read_image_dims(sess, image_data)
class_id = filename_to_class_id[filename]
example = dataset_utils.image_to_tfexample(image_data, 'png', height,
width, class_id)
tfrecord_writer.write(example.SerializeToString())
sys.stdout.write('\n')
sys.stdout.flush()
def _extract_labels(label_filename):
"""Extract the labels into a dict of filenames to int labels.
Args:
labels_filename: The filename of the MNIST-M labels.
Returns:
A dictionary of filenames to int labels.
"""
print('Extracting labels from: ', label_filename)
label_file = tf.gfile.FastGFile(label_filename, 'r').readlines()
label_lines = [line.rstrip('\n').split() for line in label_file]
labels = {}
for line in label_lines:
assert len(line) == 2
labels[line[0]] = int(line[1])
return labels
def _get_output_filename(dataset_dir, split_name):
"""Creates the output filename.
Args:
dataset_dir: The directory where the temporary files are stored.
split_name: The name of the train/test split.
Returns:
An absolute file path.
"""
return '%s/mnist_m_%s.tfrecord' % (dataset_dir, split_name)
def _get_filenames(dataset_dir):
"""Returns a list of filenames and inferred class names.
Args:
dataset_dir: A directory containing a set PNG encoded MNIST-M images.
Returns:
A list of image file paths, relative to `dataset_dir`.
"""
photo_filenames = []
for filename in os.listdir(dataset_dir):
photo_filenames.append(filename)
return photo_filenames
def run(dataset_dir):
"""Runs the download and conversion operation.
Args:
dataset_dir: The dataset directory where the dataset is stored.
"""
if not tf.gfile.Exists(dataset_dir):
tf.gfile.MakeDirs(dataset_dir)
train_filename = _get_output_filename(dataset_dir, 'train')
testing_filename = _get_output_filename(dataset_dir, 'test')
if tf.gfile.Exists(train_filename) and tf.gfile.Exists(testing_filename):
print('Dataset files already exist. Exiting without re-creating them.')
return
# TODO(konstantinos): Add download and cleanup functionality
train_validation_filenames = _get_filenames(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train'))
test_filenames = _get_filenames(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test'))
# Divide into train and validation:
random.seed(_RANDOM_SEED)
random.shuffle(train_validation_filenames)
train_filenames = train_validation_filenames[_NUM_VALIDATION:]
validation_filenames = train_validation_filenames[:_NUM_VALIDATION]
train_validation_filenames_to_class_ids = _extract_labels(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train_labels.txt'))
test_filenames_to_class_ids = _extract_labels(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test_labels.txt'))
# Convert the train, validation, and test sets.
_convert_dataset('train', train_filenames,
train_validation_filenames_to_class_ids, dataset_dir)
_convert_dataset('valid', validation_filenames,
train_validation_filenames_to_class_ids, dataset_dir)
_convert_dataset('test', test_filenames, test_filenames_to_class_ids,
dataset_dir)
# Finally, write the labels file:
labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
print('\nFinished converting the MNIST-M dataset!')
def main(_):
assert FLAGS.dataset_dir
run(FLAGS.dataset_dir)
if __name__ == '__main__':
tf.app.run()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the MNIST-M dataset.
The dataset scripts used to create the dataset can be found at:
tensorflow_models/domain_adaptation_/datasets/download_and_convert_mnist_m_dataset.py
"""
from __future__ import absolute_import
......@@ -20,6 +23,7 @@ from __future__ import division
from __future__ import print_function
import os
# Dependency imports
import tensorflow as tf
from slim.datasets import dataset_utils
......
# Description:
# Contains code for domain-adaptation style transfer.
package(
default_visibility = [
":internal",
],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package_group(
name = "internal",
packages = [
"//domain_adaptation/...",
],
)
py_library(
name = "pixelda_preprocess",
srcs = ["pixelda_preprocess.py"],
deps = [
],
)
py_test(
name = "pixelda_preprocess_test",
srcs = ["pixelda_preprocess_test.py"],
deps = [
":pixelda_preprocess",
],
)
py_library(
name = "pixelda_model",
srcs = [
"pixelda_model.py",
"pixelda_task_towers.py",
"hparams.py",
],
deps = [
],
)
py_library(
name = "pixelda_utils",
srcs = ["pixelda_utils.py"],
deps = [
],
)
py_library(
name = "pixelda_losses",
srcs = ["pixelda_losses.py"],
deps = [
],
)
py_binary(
name = "pixelda_train",
srcs = ["pixelda_train.py"],
deps = [
":pixelda_losses",
":pixelda_model",
":pixelda_preprocess",
":pixelda_utils",
"//domain_adaptation/datasets:dataset_factory",
],
)
py_binary(
name = "pixelda_eval",
srcs = ["pixelda_eval.py"],
deps = [
":pixelda_losses",
":pixelda_model",
":pixelda_preprocess",
":pixelda_utils",
"//domain_adaptation/datasets:dataset_factory",
],
)
licenses(["notice"]) # Apache 2.0
py_binary(
name = "baseline_train",
srcs = ["baseline_train.py"],
deps = [
"//domain_adaptation/datasets:dataset_factory",
"//domain_adaptation/pixel_domain_adaptation:pixelda_model",
"//domain_adaptation/pixel_domain_adaptation:pixelda_preprocess",
],
)
py_binary(
name = "baseline_eval",
srcs = ["baseline_eval.py"],
deps = [
"//domain_adaptation/datasets:dataset_factory",
"//domain_adaptation/pixel_domain_adaptation:pixelda_model",
"//domain_adaptation/pixel_domain_adaptation:pixelda_preprocess",
],
)
The best baselines are obtainable via the following configuration:
## MNIST => MNIST_M
Accuracy:
MNIST-Train: 99.9
MNIST_M-Train: 63.9
MNIST_M-Valid: 63.9
MNIST_M-Test: 63.6
Learning Rate = 0.0001
Weight Decay = 0.0
Number of Steps: 105,000
## MNIST => USPS
Accuracy:
MNIST-Train: 100.0
USPS-Train: 82.8
USPS-Valid: 82.8
USPS-Test: 78.9
Learning Rate = 0.0001
Weight Decay = 0.0
Number of Steps: 22,000
## MNIST_M => MNIST
Accuracy:
MNIST_M-Train: 100
MNIST-Train: 98.5
MNIST-Valid: 98.5
MNIST-Test: 98.1
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 604,400
## MNIST_M => MNIST_M
Accuracy:
MNIST_M-Train: 100.0
MNIST_M-Valid: 96.6
MNIST_M-Test: 96.4
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 139,400
## USPS => USPS
Accuracy:
USPS-Train: 100.0
USPS-Valid: 100.0
USPS-Test: 96.5
Learning Rate = 0.001
Weight Decay = 0.0
Number of Steps: 67,000
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Evals the classification/pose baselines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import math
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
flags = tf.app.flags
FLAGS = flags.FLAGS
slim = tf.contrib.slim
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_string(
'checkpoint_dir', None, 'The location of the checkpoint files.')
flags.DEFINE_string(
'eval_dir', None, 'The directory where evaluation logs are written.')
flags.DEFINE_integer('batch_size', 32, 'The number of samples per batch.')
flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
flags.DEFINE_string('dataset_dir', None,
'The directory where the data is stored.')
flags.DEFINE_string('split_name', None, 'The name of the train/test split.')
flags.DEFINE_integer('eval_interval_secs', 60 * 5,
'How often (in seconds) to run evaluation.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = tf.contrib.training.HParams()
hparams.weight_decay_task_classifier = 0.0
if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
hparams.task_tower = 'mnist'
else:
raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
if not tf.gfile.Exists(FLAGS.eval_dir):
tf.gfile.MakeDirs(FLAGS.eval_dir)
with tf.Graph().as_default():
dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.split_name,
FLAGS.dataset_dir)
num_classes = dataset.num_classes
num_samples = dataset.num_samples
preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
is_training=False)
images, labels = dataset_factory.provide_batch(
FLAGS.dataset_name,
FLAGS.split_name,
dataset_dir=FLAGS.dataset_dir,
num_readers=FLAGS.num_readers,
batch_size=FLAGS.batch_size,
num_preprocessing_threads=FLAGS.num_readers)
# Define the model
logits, _ = pixelda_task_towers.add_task_specific_model(
images, hparams, num_classes=num_classes, is_training=True)
#####################
# Define the losses #
#####################
if 'classes' in labels:
one_hot_labels = labels['classes']
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels, logits=logits)
tf.summary.scalar('losses/Classification_Loss', loss)
else:
raise ValueError('Only support classification for now.')
total_loss = tf.losses.get_total_loss()
predictions = tf.reshape(tf.argmax(logits, 1), shape=[-1])
class_labels = tf.argmax(labels['classes'], 1)
metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
'Mean_Loss':
tf.contrib.metrics.streaming_mean(total_loss),
'Accuracy':
tf.contrib.metrics.streaming_accuracy(predictions,
tf.reshape(
class_labels,
shape=[-1])),
'Recall_at_5':
tf.contrib.metrics.streaming_recall_at_k(logits, class_labels, 5),
})
tf.summary.histogram('outputs/Predictions', predictions)
tf.summary.histogram('outputs/Ground_Truth', class_labels)
for name, value in metrics_to_values.iteritems():
tf.summary.scalar(name, value)
num_batches = int(math.ceil(num_samples / float(FLAGS.batch_size)))
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=FLAGS.checkpoint_dir,
logdir=FLAGS.eval_dir,
num_evals=num_batches,
eval_op=metrics_to_updates.values(),
eval_interval_secs=FLAGS.eval_interval_secs)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Trains the classification/pose baselines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_task_towers
flags = tf.app.flags
FLAGS = flags.FLAGS
slim = tf.contrib.slim
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_integer('task', 0, 'The task ID.')
flags.DEFINE_integer('num_ps_tasks', 0,
'The number of parameter servers. If the value is 0, then '
'the parameters are handled locally by the worker.')
flags.DEFINE_integer('batch_size', 32, 'The number of samples per batch.')
flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
flags.DEFINE_string('dataset_dir', None,
'The directory where the data is stored.')
flags.DEFINE_string('split_name', None, 'The name of the train/test split.')
flags.DEFINE_float('learning_rate', 0.001, 'The initial learning rate.')
flags.DEFINE_integer(
'learning_rate_decay_steps', 20000,
'The frequency, in steps, at which the learning rate is decayed.')
flags.DEFINE_float('learning_rate_decay_factor',
0.95,
'The factor with which the learning rate is decayed.')
flags.DEFINE_float('adam_beta1', 0.5, 'The beta1 value for the AdamOptimizer')
flags.DEFINE_float('weight_decay', 1e-5,
'The L2 coefficient on the model weights.')
flags.DEFINE_string(
'logdir', None, 'The location of the logs and checkpoints.')
flags.DEFINE_integer('save_interval_secs', 600,
'How often, in seconds, we save the model to disk.')
flags.DEFINE_integer('save_summaries_secs', 600,
'How often, in seconds, we compute the summaries.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_float(
'moving_average_decay', 0.9999,
'The amount of decay to use for moving averages.')
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = tf.contrib.training.HParams()
hparams.weight_decay_task_classifier = FLAGS.weight_decay
if FLAGS.dataset_name in ['mnist', 'mnist_m', 'usps']:
hparams.task_tower = 'mnist'
else:
raise ValueError('Unknown dataset %s' % FLAGS.dataset_name)
with tf.Graph().as_default():
with tf.device(
tf.train.replica_device_setter(FLAGS.num_ps_tasks, merge_devices=True)):
dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
FLAGS.split_name, FLAGS.dataset_dir)
num_classes = dataset.num_classes
preprocess_fn = partial(pixelda_preprocess.preprocess_classification,
is_training=True)
images, labels = dataset_factory.provide_batch(
FLAGS.dataset_name,
FLAGS.split_name,
dataset_dir=FLAGS.dataset_dir,
num_readers=FLAGS.num_readers,
batch_size=FLAGS.batch_size,
num_preprocessing_threads=FLAGS.num_readers)
# preprocess_fn=preprocess_fn)
# Define the model
logits, _ = pixelda_task_towers.add_task_specific_model(
images, hparams, num_classes=num_classes, is_training=True)
# Define the losses
if 'classes' in labels:
one_hot_labels = labels['classes']
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels, logits=logits)
tf.summary.scalar('losses/Classification_Loss', loss)
else:
raise ValueError('Only support classification for now.')
total_loss = tf.losses.get_total_loss()
tf.summary.scalar('losses/Total_Loss', total_loss)
# Setup the moving averages
moving_average_variables = slim.get_model_variables()
variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, slim.get_or_create_global_step())
tf.add_to_collection(
tf.GraphKeys.UPDATE_OPS,
variable_averages.apply(moving_average_variables))
# Specify the optimization scheme:
learning_rate = tf.train.exponential_decay(
FLAGS.learning_rate,
slim.get_or_create_global_step(),
FLAGS.learning_rate_decay_steps,
FLAGS.learning_rate_decay_factor,
staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.adam_beta1)
train_op = slim.learning.create_train_op(total_loss, optimizer)
slim.learning.train(
train_op,
FLAGS.logdir,
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define model HParams."""
import tensorflow as tf
def create_hparams(hparam_string=None):
"""Create model hyperparameters. Parse nondefault from given string."""
hparams = tf.contrib.training.HParams(
# The name of the architecture to use.
arch='resnet',
lrelu_leakiness=0.2,
batch_norm_decay=0.9,
weight_decay=1e-5,
normal_init_std=0.02,
generator_kernel_size=3,
discriminator_kernel_size=3,
# Stop training after this many examples are processed
# If none, train indefinitely
num_training_examples=0,
# Apply data augmentation to datasets
# Applies only in training job
augment_source_images=False,
augment_target_images=False,
# Discriminator
# Number of filters in first layer of discriminator
num_discriminator_filters=64,
discriminator_conv_block_size=1, # How many convs to have at each size
discriminator_filter_factor=2.0, # Multiply # filters by this each layer
# Add gaussian noise with this stddev to every hidden layer of D
discriminator_noise_stddev=0.2, # lmetz: Start seeing results at >= 0.1
# If true, add this gaussian noise to input images to D as well
discriminator_image_noise=False,
discriminator_first_stride=1, # Stride in first conv of discriminator
discriminator_do_pooling=False, # If true, replace stride 2 with avg pool
discriminator_dropout_keep_prob=0.9, # keep probability for dropout
# DCGAN Generator
# Number of filters in generator decoder last layer (repeatedly halved
# from 1st layer)
num_decoder_filters=64,
# Number of filters in generator encoder 1st layer (repeatedly doubled
# after 1st layer)
num_encoder_filters=64,
# This is the shape to which the noise vector is projected (if we're
# transferring from noise).
# Write this way instead of [4, 4, 64] for hparam search flexibility
projection_shape_size=4,
projection_shape_channels=64,
# Indicates the method by which we enlarge the spatial representation
# of an image. Possible values include:
# - resize_conv: Performs a nearest neighbor resize followed by a conv.
# - conv2d_transpose: Performs a conv2d_transpose.
upsample_method='resize_conv',
# Visualization
summary_steps=500, # Output image summary every N steps
###################################
# Task Classifier Hyperparameters #
###################################
# Which task-specific prediction tower to use. Possible choices are:
# none: No task tower.
# doubling_pose_estimator: classifier + quaternion regressor.
# [conv + pool]* + FC
# Classifiers used in DSN paper:
# gtsrb: Classifier used for GTSRB
# svhn: Classifier used for SVHN
# mnist: Classifier used for MNIST
# pose_mini: Classifier + regressor used for pose_mini
task_tower='doubling_pose_estimator',
weight_decay_task_classifier=1e-5,
source_task_loss_weight=1.0,
transferred_task_loss_weight=1.0,
# Number of private layers in doubling_pose_estimator task tower
num_private_layers=2,
# The weight for the log quaternion loss we use for source and transferred
# samples of the cropped_linemod dataset.
# In the DSN work, 1/8 of the classifier weight worked well for our log
# quaternion loss
source_pose_weight=0.125 * 2.0,
transferred_pose_weight=0.125 * 1.0,
# If set to True, the style transfer network also attempts to change its
# weights to maximize the performance of the task tower. If set to False,
# then the style transfer network only attempts to change its weights to
# make the transferred images more likely according to the domain
# classifier.
task_tower_in_g_step=True,
task_loss_in_g_weight=1.0, # Weight of task loss in G
#########################################
# 'simple` generator arch model hparams #
#########################################
simple_num_conv_layers=1,
simple_conv_filters=8,
#########################
# Resnet Hyperparameters#
#########################
resnet_blocks=6, # Number of resnet blocks
resnet_filters=64, # Number of filters per conv in resnet blocks
# If true, add original input back to result of convolutions inside the
# resnet arch. If false, it turns into a simple stack of conv/relu/BN
# layers.
resnet_residuals=True,
#######################################
# The residual / interpretable model. #
#######################################
res_int_blocks=2, # The number of residual blocks.
res_int_convs=2, # The number of conv calls inside each block.
res_int_filters=64, # The number of filters used by each convolution.
####################
# Latent variables #
####################
# if true, then generate random noise and project to input for generator
noise_channel=True,
# The number of dimensions in the input noise vector.
noise_dims=10,
# If true, then one hot encode source image class and project as an
# additional channel for the input to generator. This gives the generator
# access to the class, which may help generation performance.
condition_on_source_class=False,
########################
# Loss Hyperparameters #
########################
domain_loss_weight=1.0,
style_transfer_loss_weight=1.0,
########################################################################
# Encourages the transferred images to be similar to the source images #
# using a configurable metric. #
########################################################################
# The weight of the loss function encouraging the source and transferred
# images to be similar. If set to 0, then the loss function is not used.
transferred_similarity_loss_weight=0.0,
# The type of loss used to encourage transferred and source image
# similarity. Valid values include:
# mpse: Mean Pairwise Squared Error
# mse: Mean Squared Error
# hinged_mse: Computes the mean squared error using squared differences
# greater than hparams.transferred_similarity_max_diff
# hinged_mae: Computes the mean absolute error using absolute
# differences greater than hparams.transferred_similarity_max_diff.
transferred_similarity_loss='mpse',
# The maximum allowable difference between the source and target images.
# This value is used, in effect, to produce a hinge loss. Note that the
# range of values should be between 0 and 1.
transferred_similarity_max_diff=0.4,
################################
# Optimization Hyperparameters #
################################
learning_rate=0.001,
batch_size=32,
lr_decay_steps=20000,
lr_decay_rate=0.95,
# Recomendation from the DCGAN paper:
adam_beta1=0.5,
clip_gradient_norm=5.0,
# The number of times we run the discriminator train_op in a row.
discriminator_steps=1,
# The number of times we run the generator train_op in a row.
generator_steps=1)
if hparam_string:
tf.logging.info('Parsing command line hparams: %s', hparam_string)
hparams.parse(hparam_string)
tf.logging.info('Final parsed hparams: %s', hparams.values())
return hparams
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Evaluates the PIXELDA model.
-- Compiles the model for CPU.
$ bazel build -c opt third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
-- Compile the model for GPU.
$ bazel build -c opt --copt=-mavx --config=cuda \
third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
-- Runs the training.
$ ./bazel-bin/third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation/pixelda_eval \
--source_dataset=mnist \
--target_dataset=mnist_m \
--dataset_dir=/tmp/datasets/ \
--alsologtostderr
-- Visualize the results.
$ bash learning/brain/tensorboard/tensorboard.sh \
--port 2222 --logdir=/tmp/pixelda/
"""
from functools import partial
import math
# Dependency imports
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_model
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess
from domain_adaptation.pixel_domain_adaptation import pixelda_utils
from domain_adaptation.pixel_domain_adaptation import pixelda_losses
from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_string('checkpoint_dir', '/tmp/pixelda/',
'Directory where the model was written to.')
flags.DEFINE_string('eval_dir', '/tmp/pixelda/',
'Directory where the results are saved to.')
flags.DEFINE_integer('eval_interval_secs', 60,
'The frequency, in seconds, with which evaluation is run.')
flags.DEFINE_string('target_split_name', 'test',
'The name of the train/test split.')
flags.DEFINE_string('source_split_name', 'train', 'Split for source dataset.'
' Defaults to train.')
flags.DEFINE_string('source_dataset', 'mnist',
'The name of the source dataset.')
flags.DEFINE_string('target_dataset', 'mnist_m',
'The name of the target dataset.')
flags.DEFINE_string(
'dataset_dir',
'', # None,
'The directory where the datasets can be found.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_integer('num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
# HParams
flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values')
def run_eval(run_dir, checkpoint_dir, hparams):
"""Runs the eval loop.
Args:
run_dir: The directory where eval specific logs are placed
checkpoint_dir: The directory where the checkpoints are stored
hparams: The hyperparameters struct.
Raises:
ValueError: if hparams.arch is not recognized.
"""
for checkpoint_path in slim.evaluation.checkpoints_iterator(
checkpoint_dir, FLAGS.eval_interval_secs):
with tf.Graph().as_default():
#########################
# Preprocess the inputs #
#########################
target_dataset = dataset_factory.get_dataset(
FLAGS.target_dataset,
split_name=FLAGS.target_split_name,
dataset_dir=FLAGS.dataset_dir)
target_images, target_labels = dataset_factory.provide_batch(
FLAGS.target_dataset, FLAGS.target_split_name, FLAGS.dataset_dir,
FLAGS.num_readers, hparams.batch_size,
FLAGS.num_preprocessing_threads)
num_target_classes = target_dataset.num_classes
target_labels['class'] = tf.argmax(target_labels['classes'], 1)
del target_labels['classes']
if hparams.arch not in ['dcgan']:
source_dataset = dataset_factory.get_dataset(
FLAGS.source_dataset,
split_name=FLAGS.source_split_name,
dataset_dir=FLAGS.dataset_dir)
num_source_classes = source_dataset.num_classes
source_images, source_labels = dataset_factory.provide_batch(
FLAGS.source_dataset, FLAGS.source_split_name, FLAGS.dataset_dir,
FLAGS.num_readers, hparams.batch_size,
FLAGS.num_preprocessing_threads)
source_labels['class'] = tf.argmax(source_labels['classes'], 1)
del source_labels['classes']
if num_source_classes != num_target_classes:
raise ValueError(
'Input and output datasets must have same number of classes')
else:
source_images = None
source_labels = None
####################
# Define the model #
####################
end_points = pixelda_model.create_model(
hparams,
target_images,
source_images=source_images,
source_labels=source_labels,
is_training=False,
num_classes=num_target_classes)
#######################
# Metrics & Summaries #
#######################
names_to_values, names_to_updates = create_metrics(end_points,
source_labels,
target_labels, hparams)
pixelda_utils.summarize_model(end_points)
pixelda_utils.summarize_transferred_grid(
end_points['transferred_images'], source_images, name='Transferred')
if 'source_images_recon' in end_points:
pixelda_utils.summarize_transferred_grid(
end_points['source_images_recon'],
source_images,
name='Source Reconstruction')
pixelda_utils.summarize_images(target_images, 'Target')
for name, value in names_to_values.iteritems():
tf.summary.scalar(name, value)
# Use the entire split by default
num_examples = target_dataset.num_samples
num_batches = math.ceil(num_examples / float(hparams.batch_size))
global_step = slim.get_or_create_global_step()
result = slim.evaluation.evaluate_once(
master=FLAGS.master,
checkpoint_path=checkpoint_path,
logdir=run_dir,
num_evals=num_batches,
eval_op=names_to_updates.values(),
final_op=names_to_values)
def to_degrees(log_quaternion_loss):
"""Converts a log quaternion distance to an angle.
Args:
log_quaternion_loss: The log quaternion distance between two
unit quaternions (or a batch of pairs of quaternions).
Returns:
The angle in degrees of the implied angle-axis representation.
"""
return tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi
def create_metrics(end_points, source_labels, target_labels, hparams):
"""Create metrics for the model.
Args:
end_points: A dictionary of end point name to tensor
source_labels: Labels for source images. batch_size x 1
target_labels: Labels for target images. batch_size x 1
hparams: The hyperparameters struct.
Returns:
Tuple of (names_to_values, names_to_updates), dictionaries that map a metric
name to its value and update op, respectively
"""
###########################################
# Evaluate the Domain Prediction Accuracy #
###########################################
batch_size = hparams.batch_size
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
('eval/Domain_Accuracy-Transferred'):
tf.contrib.metrics.streaming_accuracy(
tf.to_int32(
tf.round(tf.sigmoid(end_points[
'transferred_domain_logits']))),
tf.zeros(batch_size, dtype=tf.int32)),
('eval/Domain_Accuracy-Target'):
tf.contrib.metrics.streaming_accuracy(
tf.to_int32(
tf.round(tf.sigmoid(end_points['target_domain_logits']))),
tf.ones(batch_size, dtype=tf.int32))
})
################################
# Evaluate the task classifier #
################################
if 'source_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Source'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['source_task_logits'], 1),
source_labels['class'])
if 'transferred_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Transferred'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['transferred_task_logits'], 1),
source_labels['class'])
if 'target_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Target'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['target_task_logits'], 1),
target_labels['class'])
##########################################################################
# Pose data-specific losses.
##########################################################################
if 'quaternion' in source_labels.keys():
params = {}
params['use_logging'] = False
params['batch_size'] = batch_size
angle_loss_source = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'source_quaternion'], source_labels['quaternion'], params))
angle_loss_transferred = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'transferred_quaternion'], source_labels['quaternion'], params))
angle_loss_target = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'target_quaternion'], target_labels['quaternion'], params))
metric_name = 'eval/Angle_Loss-Source'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_source)
metric_name = 'eval/Angle_Loss-Transferred'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_transferred)
metric_name = 'eval/Angle_Loss-Target'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_target)
return names_to_values, names_to_updates
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
hparams = create_hparams(FLAGS.hparams)
run_eval(
run_dir=FLAGS.eval_dir,
checkpoint_dir=FLAGS.checkpoint_dir,
hparams=hparams)
if __name__ == '__main__':
tf.app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment