Commit 356c98bd authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

Merge remote-tracking branch 'upstream/master' into detr-push-3

parents d31aba8a b9785623
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple example of contextual bandits simulation.
Code corresponding to:
Deep Bayesian Bandits Showdown: An Empirical Comparison of Bayesian Deep Networks
for Thompson Sampling, by Carlos Riquelme, George Tucker, and Jasper Snoek.
https://arxiv.org/abs/1802.09127
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl import app
from absl import flags
import numpy as np
import os
import tensorflow as tf
from bandits.algorithms.bootstrapped_bnn_sampling import BootstrappedBNNSampling
from bandits.core.contextual_bandit import run_contextual_bandit
from bandits.data.data_sampler import sample_adult_data
from bandits.data.data_sampler import sample_census_data
from bandits.data.data_sampler import sample_covertype_data
from bandits.data.data_sampler import sample_jester_data
from bandits.data.data_sampler import sample_mushroom_data
from bandits.data.data_sampler import sample_statlog_data
from bandits.data.data_sampler import sample_stock_data
from bandits.algorithms.fixed_policy_sampling import FixedPolicySampling
from bandits.algorithms.linear_full_posterior_sampling import LinearFullPosteriorSampling
from bandits.algorithms.neural_linear_sampling import NeuralLinearPosteriorSampling
from bandits.algorithms.parameter_noise_sampling import ParameterNoiseSampling
from bandits.algorithms.posterior_bnn_sampling import PosteriorBNNSampling
from bandits.data.synthetic_data_sampler import sample_linear_data
from bandits.data.synthetic_data_sampler import sample_sparse_linear_data
from bandits.data.synthetic_data_sampler import sample_wheel_bandit_data
from bandits.algorithms.uniform_sampling import UniformSampling
# Set up your file routes to the data files.
base_route = os.getcwd()
data_route = 'contextual_bandits/datasets'

FLAGS = flags.FLAGS
# Mirror log output to stderr by default so runs are visible on the console.
FLAGS.set_default('alsologtostderr', True)
flags.DEFINE_string('logdir', '/tmp/bandits/', 'Base directory to save output')
# One flag per dataset; each defaults to a path under ./contextual_bandits/datasets.
flags.DEFINE_string(
    'mushroom_data',
    os.path.join(base_route, data_route, 'mushroom.data'),
    'Directory where Mushroom data is stored.')
flags.DEFINE_string(
    'financial_data',
    os.path.join(base_route, data_route, 'raw_stock_contexts'),
    'Directory where Financial data is stored.')
flags.DEFINE_string(
    'jester_data',
    os.path.join(base_route, data_route, 'jester_data_40jokes_19181users.npy'),
    'Directory where Jester data is stored.')
flags.DEFINE_string(
    'statlog_data',
    os.path.join(base_route, data_route, 'shuttle.trn'),
    'Directory where Statlog data is stored.')
flags.DEFINE_string(
    'adult_data',
    os.path.join(base_route, data_route, 'adult.full'),
    'Directory where Adult data is stored.')
flags.DEFINE_string(
    'covertype_data',
    os.path.join(base_route, data_route, 'covtype.data'),
    'Directory where Covertype data is stored.')
flags.DEFINE_string(
    'census_data',
    os.path.join(base_route, data_route, 'USCensus1990.data.txt'),
    'Directory where Census data is stored.')
def sample_data(data_type, num_contexts=None):
  """Sample data from given 'data_type'.

  Args:
    data_type: Dataset from which to sample.
    num_contexts: Number of contexts to sample.

  Returns:
    dataset: Sampled matrix with rows: (context, reward_1, ..., reward_num_act).
    opt_rewards: Vector of expected optimal reward for each context.
    opt_actions: Vector of optimal action for each context.
    num_actions: Number of available actions.
    context_dim: Dimension of each context.

  Raises:
    ValueError: If `data_type` is not one of the supported dataset names.
  """
  if data_type == 'linear':
    # Create linear dataset
    num_actions = 8
    context_dim = 10
    noise_stds = [0.01 * (i + 1) for i in range(num_actions)]
    dataset, _, opt_linear = sample_linear_data(num_contexts, context_dim,
                                                num_actions, sigma=noise_stds)
    opt_rewards, opt_actions = opt_linear
  elif data_type == 'sparse_linear':
    # Create sparse linear dataset
    num_actions = 7
    context_dim = 10
    noise_stds = [0.01 * (i + 1) for i in range(num_actions)]
    num_nnz_dims = int(context_dim / 3.0)
    dataset, _, opt_sparse_linear = sample_sparse_linear_data(
        num_contexts, context_dim, num_actions, num_nnz_dims, sigma=noise_stds)
    opt_rewards, opt_actions = opt_sparse_linear
  elif data_type == 'mushroom':
    # Create mushroom dataset
    num_actions = 2
    context_dim = 117
    file_name = FLAGS.mushroom_data
    dataset, opt_mushroom = sample_mushroom_data(file_name, num_contexts)
    opt_rewards, opt_actions = opt_mushroom
  elif data_type == 'financial':
    num_actions = 8
    context_dim = 21
    # The financial dataset only has 3713 contexts available.
    num_contexts = min(3713, num_contexts)
    noise_stds = [0.01 * (i + 1) for i in range(num_actions)]
    file_name = FLAGS.financial_data
    dataset, opt_financial = sample_stock_data(file_name, context_dim,
                                               num_actions, num_contexts,
                                               noise_stds, shuffle_rows=True)
    opt_rewards, opt_actions = opt_financial
  elif data_type == 'jester':
    num_actions = 8
    context_dim = 32
    # Capped at the number of users in the Jester dataset.
    num_contexts = min(19181, num_contexts)
    file_name = FLAGS.jester_data
    dataset, opt_jester = sample_jester_data(file_name, context_dim,
                                             num_actions, num_contexts,
                                             shuffle_rows=True,
                                             shuffle_cols=True)
    opt_rewards, opt_actions = opt_jester
  elif data_type == 'statlog':
    file_name = FLAGS.statlog_data
    num_actions = 7
    num_contexts = min(43500, num_contexts)
    sampled_vals = sample_statlog_data(file_name, num_contexts,
                                       shuffle_rows=True)
    contexts, rewards, (opt_rewards, opt_actions) = sampled_vals
    dataset = np.hstack((contexts, rewards))
    context_dim = contexts.shape[1]
  elif data_type == 'adult':
    file_name = FLAGS.adult_data
    num_actions = 14
    num_contexts = min(45222, num_contexts)
    sampled_vals = sample_adult_data(file_name, num_contexts,
                                     shuffle_rows=True)
    contexts, rewards, (opt_rewards, opt_actions) = sampled_vals
    dataset = np.hstack((contexts, rewards))
    context_dim = contexts.shape[1]
  elif data_type == 'covertype':
    file_name = FLAGS.covertype_data
    num_actions = 7
    num_contexts = min(150000, num_contexts)
    sampled_vals = sample_covertype_data(file_name, num_contexts,
                                         shuffle_rows=True)
    contexts, rewards, (opt_rewards, opt_actions) = sampled_vals
    dataset = np.hstack((contexts, rewards))
    context_dim = contexts.shape[1]
  elif data_type == 'census':
    file_name = FLAGS.census_data
    num_actions = 9
    num_contexts = min(150000, num_contexts)
    sampled_vals = sample_census_data(file_name, num_contexts,
                                      shuffle_rows=True)
    contexts, rewards, (opt_rewards, opt_actions) = sampled_vals
    dataset = np.hstack((contexts, rewards))
    context_dim = contexts.shape[1]
  elif data_type == 'wheel':
    delta = 0.95
    num_actions = 5
    context_dim = 2
    mean_v = [1.0, 1.0, 1.0, 1.0, 1.2]
    std_v = [0.05, 0.05, 0.05, 0.05, 0.05]
    mu_large = 50
    std_large = 0.01
    dataset, opt_wheel = sample_wheel_bandit_data(num_contexts, delta,
                                                  mean_v, std_v,
                                                  mu_large, std_large)
    opt_rewards, opt_actions = opt_wheel
  else:
    # Previously an unknown data_type fell through to the return statement and
    # crashed with UnboundLocalError; fail fast with a clear message instead.
    raise ValueError('Dataset {} not supported.'.format(data_type))
  return dataset, opt_rewards, opt_actions, num_actions, context_dim
def display_results(algos, opt_rewards, opt_actions, h_rewards, t_init, name):
  """Prints a ranked summary of each algorithm's total reward on the run."""
  rule = '---------------------------------------------------'
  print(rule)
  print(rule)
  print('{} bandit completed after {} seconds.'.format(
      name, time.time() - t_init))
  print(rule)

  # Rank the algorithms by their cumulative reward, best first.
  totals = [(algo.name, np.sum(h_rewards[:, col]))
            for col, algo in enumerate(algos)]
  totals.sort(key=lambda pair: pair[1], reverse=True)

  for rank, (algo_name, total_reward) in enumerate(totals):
    print('{:3}) {:20}| \t \t total reward = {:10}.'.format(
        rank, algo_name, total_reward))

  print(rule)
  print('Optimal total reward = {}.'.format(np.sum(opt_rewards)))
  print('Frequency of optimal actions (action, frequency):')
  action_list = list(opt_actions)
  print([[action, action_list.count(action)] for action in set(action_list)])
  print(rule)
  print(rule)
def main(_):
  """Samples a dataset, runs every benchmarked bandit algorithm, prints results."""
  # Problem parameters
  num_contexts = 2000

  # Data type in {linear, sparse_linear, mushroom, financial, jester,
  # statlog, adult, covertype, census, wheel}
  data_type = 'mushroom'

  # Create dataset
  sampled_vals = sample_data(data_type, num_contexts)
  dataset, opt_rewards, opt_actions, num_actions, context_dim = sampled_vals

  # Define hyperparameters and algorithms
  hparams = tf.contrib.training.HParams(num_actions=num_actions)

  # Bayesian linear regression posterior (Normal-Inverse-Gamma prior a0/b0).
  hparams_linear = tf.contrib.training.HParams(num_actions=num_actions,
                                               context_dim=context_dim,
                                               a0=6,
                                               b0=6,
                                               lambda_prior=0.25,
                                               initial_pulls=2)

  # Plain neural network trained with RMSProp (also reused by BootRMS).
  hparams_rms = tf.contrib.training.HParams(num_actions=num_actions,
                                            context_dim=context_dim,
                                            init_scale=0.3,
                                            activation=tf.nn.relu,
                                            layer_sizes=[50],
                                            batch_size=512,
                                            activate_decay=True,
                                            initial_lr=0.1,
                                            max_grad_norm=5.0,
                                            show_training=False,
                                            freq_summary=1000,
                                            buffer_s=-1,
                                            initial_pulls=2,
                                            optimizer='RMS',
                                            reset_lr=True,
                                            lr_decay_rate=0.5,
                                            training_freq=50,
                                            training_epochs=100,
                                            p=0.95,
                                            q=3)

  # Same network but sampling with dropout kept active at decision time.
  hparams_dropout = tf.contrib.training.HParams(num_actions=num_actions,
                                                context_dim=context_dim,
                                                init_scale=0.3,
                                                activation=tf.nn.relu,
                                                layer_sizes=[50],
                                                batch_size=512,
                                                activate_decay=True,
                                                initial_lr=0.1,
                                                max_grad_norm=5.0,
                                                show_training=False,
                                                freq_summary=1000,
                                                buffer_s=-1,
                                                initial_pulls=2,
                                                optimizer='RMS',
                                                reset_lr=True,
                                                lr_decay_rate=0.5,
                                                training_freq=50,
                                                training_epochs=100,
                                                use_dropout=True,
                                                keep_prob=0.80)

  # Bayes-by-Backprop variational network.
  hparams_bbb = tf.contrib.training.HParams(num_actions=num_actions,
                                            context_dim=context_dim,
                                            init_scale=0.3,
                                            activation=tf.nn.relu,
                                            layer_sizes=[50],
                                            batch_size=512,
                                            activate_decay=True,
                                            initial_lr=0.1,
                                            max_grad_norm=5.0,
                                            show_training=False,
                                            freq_summary=1000,
                                            buffer_s=-1,
                                            initial_pulls=2,
                                            optimizer='RMS',
                                            use_sigma_exp_transform=True,
                                            cleared_times_trained=10,
                                            initial_training_steps=100,
                                            noise_sigma=0.1,
                                            reset_lr=False,
                                            training_freq=50,
                                            training_epochs=100)

  # Neural-linear: Bayesian linear regression on the network's last layer;
  # the linear part is updated every step (training_freq=1).
  hparams_nlinear = tf.contrib.training.HParams(num_actions=num_actions,
                                                context_dim=context_dim,
                                                init_scale=0.3,
                                                activation=tf.nn.relu,
                                                layer_sizes=[50],
                                                batch_size=512,
                                                activate_decay=True,
                                                initial_lr=0.1,
                                                max_grad_norm=5.0,
                                                show_training=False,
                                                freq_summary=1000,
                                                buffer_s=-1,
                                                initial_pulls=2,
                                                reset_lr=True,
                                                lr_decay_rate=0.5,
                                                training_freq=1,
                                                training_freq_network=50,
                                                training_epochs=100,
                                                a0=6,
                                                b0=6,
                                                lambda_prior=0.25)

  # Variant with a less frequent linear update (every 10 steps).
  hparams_nlinear2 = tf.contrib.training.HParams(num_actions=num_actions,
                                                 context_dim=context_dim,
                                                 init_scale=0.3,
                                                 activation=tf.nn.relu,
                                                 layer_sizes=[50],
                                                 batch_size=512,
                                                 activate_decay=True,
                                                 initial_lr=0.1,
                                                 max_grad_norm=5.0,
                                                 show_training=False,
                                                 freq_summary=1000,
                                                 buffer_s=-1,
                                                 initial_pulls=2,
                                                 reset_lr=True,
                                                 lr_decay_rate=0.5,
                                                 training_freq=10,
                                                 training_freq_network=50,
                                                 training_epochs=100,
                                                 a0=6,
                                                 b0=6,
                                                 lambda_prior=0.25)

  # Parameter-noise exploration.
  hparams_pnoise = tf.contrib.training.HParams(num_actions=num_actions,
                                               context_dim=context_dim,
                                               init_scale=0.3,
                                               activation=tf.nn.relu,
                                               layer_sizes=[50],
                                               batch_size=512,
                                               activate_decay=True,
                                               initial_lr=0.1,
                                               max_grad_norm=5.0,
                                               show_training=False,
                                               freq_summary=1000,
                                               buffer_s=-1,
                                               initial_pulls=2,
                                               optimizer='RMS',
                                               reset_lr=True,
                                               lr_decay_rate=0.5,
                                               training_freq=50,
                                               training_epochs=100,
                                               noise_std=0.05,
                                               eps=0.1,
                                               d_samples=300,
                                               )

  # Black-box alpha-divergence minimization.
  hparams_alpha_div = tf.contrib.training.HParams(num_actions=num_actions,
                                                  context_dim=context_dim,
                                                  init_scale=0.3,
                                                  activation=tf.nn.relu,
                                                  layer_sizes=[50],
                                                  batch_size=512,
                                                  activate_decay=True,
                                                  initial_lr=0.1,
                                                  max_grad_norm=5.0,
                                                  show_training=False,
                                                  freq_summary=1000,
                                                  buffer_s=-1,
                                                  initial_pulls=2,
                                                  optimizer='RMS',
                                                  use_sigma_exp_transform=True,
                                                  cleared_times_trained=10,
                                                  initial_training_steps=100,
                                                  noise_sigma=0.1,
                                                  reset_lr=False,
                                                  training_freq=50,
                                                  training_epochs=100,
                                                  alpha=1.0,
                                                  k=20,
                                                  prior_variance=0.1)

  # Multitask Gaussian process.
  hparams_gp = tf.contrib.training.HParams(num_actions=num_actions,
                                           num_outputs=num_actions,
                                           context_dim=context_dim,
                                           reset_lr=False,
                                           learn_embeddings=True,
                                           max_num_points=1000,
                                           show_training=False,
                                           freq_summary=1000,
                                           batch_size=512,
                                           keep_fixed_after_max_obs=True,
                                           training_freq=50,
                                           initial_pulls=2,
                                           training_epochs=100,
                                           lr=0.01,
                                           buffer_s=-1,
                                           initial_lr=0.001,
                                           lr_decay_rate=0.0,
                                           optimizer='RMS',
                                           task_latent_dim=5,
                                           activate_decay=False)

  # The full roster of algorithms to benchmark against each other.
  algos = [
      UniformSampling('Uniform Sampling', hparams),
      UniformSampling('Uniform Sampling 2', hparams),
      FixedPolicySampling('fixed1', [0.75, 0.25], hparams),
      FixedPolicySampling('fixed2', [0.25, 0.75], hparams),
      PosteriorBNNSampling('RMS', hparams_rms, 'RMSProp'),
      PosteriorBNNSampling('Dropout', hparams_dropout, 'RMSProp'),
      PosteriorBNNSampling('BBB', hparams_bbb, 'Variational'),
      NeuralLinearPosteriorSampling('NeuralLinear', hparams_nlinear),
      NeuralLinearPosteriorSampling('NeuralLinear2', hparams_nlinear2),
      LinearFullPosteriorSampling('LinFullPost', hparams_linear),
      BootstrappedBNNSampling('BootRMS', hparams_rms),
      ParameterNoiseSampling('ParamNoise', hparams_pnoise),
      PosteriorBNNSampling('BBAlphaDiv', hparams_alpha_div, 'AlphaDiv'),
      PosteriorBNNSampling('MultitaskGP', hparams_gp, 'GP'),
  ]

  # Run contextual bandit problem
  t_init = time.time()
  results = run_contextual_bandit(context_dim, num_actions, dataset, algos)
  _, h_rewards = results

  # Display results
  display_results(algos, opt_rewards, opt_actions, h_rewards, t_init, data_type)


if __name__ == '__main__':
  app.run(main)
![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
## Introduction
This is the code used for two domain adaptation papers.
The `domain_separation` directory contains code for the "Domain Separation
Networks" paper by Bousmalis K., Trigeorgis G., et al. which was presented at
NIPS 2016. The paper can be found here: https://arxiv.org/abs/1608.06019.
The `pixel_domain_adaptation` directory contains the code used for the
"Unsupervised Pixel-Level Domain Adaptation with Generative Adversarial
Networks" paper by Bousmalis K., et al. (presented at CVPR 2017). The paper can
be found here: https://arxiv.org/abs/1612.05424. PixelDA aims to perform domain
adaptation by transferring the visual style of the target domain (which has few
or no labels) to a source domain (which has many labels). This is accomplished
using a Generative Adversarial Network (GAN).
### Other implementations
* [Simplified-DSN](https://github.com/AmirHussein96/Simplified-DSN):
An unofficial implementation of the [Domain Separation Networks paper](https://arxiv.org/abs/1608.06019).
## Contact
The domain separation code was open-sourced
by [Konstantinos Bousmalis](https://github.com/bousmalis)
(konstantinos@google.com), while the pixel level domain adaptation code was
open-sourced by [David Dohan](https://github.com/dmrd) (ddohan@google.com).
## Installation
You will need to have the following installed on your machine before trying out the DSN code.
* TensorFlow 1.x: https://www.tensorflow.org/install/
* Bazel: https://bazel.build/
## Initial setup
In order to run the MNIST to MNIST-M experiments, you will need to set the
data directory:
```
$ export DSN_DATA_DIR=/your/dir
```
Add models and models/slim to your `$PYTHONPATH` (assumes $PWD is /models):
```
$ export PYTHONPATH=$PYTHONPATH:$PWD:$PWD/slim
```
## Getting the datasets
You can fetch the MNIST data by running
```
$ bazel run slim:download_and_convert_data -- --dataset_dir $DSN_DATA_DIR --dataset_name=mnist
```
The MNIST-M dataset is available online [here](http://bit.ly/2nrlUAJ). Once it is downloaded and extracted into your data directory, create TFRecord files by running:
```
$ bazel run domain_adaptation/datasets:download_and_convert_mnist_m -- --dataset_dir $DSN_DATA_DIR
```
# Running PixelDA from MNIST to MNIST-M
You can run PixelDA as follows (using Tensorboard to examine the results):
```
$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_train -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m
```
And evaluation as:
```
$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_eval -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m --target_split_name test
```
The MNIST-M results in the paper were run with the following hparams flag:
```
--hparams arch=resnet,domain_loss_weight=0.135603587834,num_training_examples=16000000,style_transfer_loss_weight=0.0113173311334,task_loss_in_g_weight=0.0100959947002,task_tower=mnist,task_tower_in_g_step=true
```
### A note on terminology/language of the code:
The components of the network can be grouped into two parts
which correspond to elements which are jointly optimized: The generator
component and the discriminator component.
The generator component takes either an image or noise vector and produces an
output image.
The discriminator component takes the generated images and the target images
and attempts to discriminate between them.
## Running DSN code for adapting MNIST to MNIST-M
First, you need to build the binaries with Bazel:
```
$ bazel build -c opt domain_adaptation/domain_separation/...
```
You can then train with the following command:
```
$ ./bazel-bin/domain_adaptation/domain_separation/dsn_train \
--similarity_loss=dann_loss \
--basic_tower=dann_mnist \
--source_dataset=mnist \
--target_dataset=mnist_m \
--learning_rate=0.0117249 \
--gamma_weight=0.251175 \
--weight_decay=1e-6 \
--layers_to_regularize=fc3 \
--nouse_separation \
--master="" \
--dataset_dir=${DSN_DATA_DIR} \
-v --use_logging
```
Evaluation can be invoked with the following command:
```
$ ./bazel-bin/domain_adaptation/domain_separation/dsn_eval \
-v --dataset mnist_m --split test --num_examples=9001 \
--dataset_dir=${DSN_DATA_DIR}
```
# Domain Adaptation Scenarios Datasets
# Domain Adaptation Scenarios Datasets

package(
    default_visibility = [
        ":internal",
    ],
)

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# Restrict default visibility to targets under //domain_adaptation.
package_group(
    name = "internal",
    packages = [
        "//domain_adaptation/...",
    ],
)

# Factory that maps dataset names to their tf-slim Dataset providers.
py_library(
    name = "dataset_factory",
    srcs = ["dataset_factory.py"],
    deps = [
        ":mnist_m",
        "//slim:mnist",
    ],
)

# Script that converts raw MNIST-M images into TFRecord files.
py_binary(
    name = "download_and_convert_mnist_m",
    srcs = ["download_and_convert_mnist_m.py"],
    deps = [
        "//slim:dataset_utils",
    ],
)

# Dataset provider for the converted MNIST-M TFRecords.
py_binary(
    name = "mnist_m",
    srcs = ["mnist_m.py"],
    deps = [
        "//slim:dataset_utils",
    ],
)
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A factory-pattern class which returns image/label pairs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from slim.datasets import mnist
from domain_adaptation.datasets import mnist_m
slim = tf.contrib.slim
def get_dataset(dataset_name,
                split_name,
                dataset_dir,
                file_pattern=None,
                reader=None):
  """Returns the tf-slim `Dataset` for a given dataset name and split.

  Args:
    dataset_name: String, the name of the dataset ('mnist' or 'mnist_m').
    split_name: A train/test split name.
    dataset_dir: The directory where the dataset files are stored.
    file_pattern: The file pattern to use for matching the dataset source files.
    reader: The subclass of tf.ReaderBase. If left as `None`, then the default
      reader defined by each dataset is used.

  Returns:
    A tf-slim `Dataset` class.

  Raises:
    ValueError: if `dataset_name` isn't recognized.
  """
  known_datasets = {'mnist': mnist, 'mnist_m': mnist_m}
  if dataset_name in known_datasets:
    return known_datasets[dataset_name].get_split(split_name, dataset_dir,
                                                  file_pattern, reader)
  raise ValueError('Name of dataset unknown %s.' % dataset_name)
def provide_batch(dataset_name, split_name, dataset_dir, num_readers,
                  batch_size, num_preprocessing_threads):
  """Provides a batch of images and corresponding labels.

  Args:
    dataset_name: String, the name of the dataset.
    split_name: A train/test split name.
    dataset_dir: The directory where the dataset files are stored.
    num_readers: The number of readers used by DatasetDataProvider.
    batch_size: The size of the batch requested.
    num_preprocessing_threads: The number of preprocessing threads for
      tf.train.batch.

  Returns:
    A batch of
      images: tensor of [batch_size, height, width, channels].
      labels: dictionary of labels.
  """
  dataset = get_dataset(dataset_name, split_name, dataset_dir)
  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=num_readers,
      common_queue_capacity=20 * batch_size,
      common_queue_min=10 * batch_size)
  [image, label] = provider.get(['image', 'label'])

  # Convert images to float32 and rescale from [0, 1] to [-1, 1].
  image = tf.image.convert_image_dtype(image, tf.float32)
  image -= 0.5
  image *= 2

  # Load the data; labels come back as a dict with one-hot class targets.
  labels = {}
  images, labels['classes'] = tf.train.batch(
      [image, label],
      batch_size=batch_size,
      num_threads=num_preprocessing_threads,
      capacity=5 * batch_size)
  labels['classes'] = slim.one_hot_encoding(labels['classes'],
                                            dataset.num_classes)

  # Convert mnist to RGB and 32x32 so that it can match mnist_m.
  if dataset_name == 'mnist':
    images = tf.image.grayscale_to_rgb(images)
    images = tf.image.resize_images(images, [32, 32])
  return images, labels
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Downloads and converts MNIST-M data to TFRecords of TF-Example protos.
This module downloads the MNIST-M data, uncompresses it, reads the files
that make up the MNIST-M data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contain a single image and label.
The script should take about a minute to run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
# Dependency imports
import numpy as np
from six.moves import urllib
import tensorflow as tf
from slim.datasets import dataset_utils
tf.app.flags.DEFINE_string(
    'dataset_dir', None,
    'The directory where the output TFRecords and temporary files are saved.')

FLAGS = tf.app.flags.FLAGS

# Output image geometry: MNIST-M images are 32x32 RGB.
_IMAGE_SIZE = 32
_NUM_CHANNELS = 3

# The number of images in the training set.
_NUM_TRAIN_SAMPLES = 59001

# The number of images to be kept from the training set for the validation set.
_NUM_VALIDATION = 1000

# The number of images in the test set.
_NUM_TEST_SAMPLES = 9001

# Seed for repeatability.
_RANDOM_SEED = 0
# The names of the classes.
_CLASS_NAMES = [
'zero',
'one',
'two',
'three',
'four',
'five',
'size',
'seven',
'eight',
'nine',
]
class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB PNG data.
    self._decode_png_data = tf.placeholder(dtype=tf.string)
    self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)

  def read_image_dims(self, sess, image_data):
    # Returns (height, width) of the PNG encoded in `image_data`.
    image = self.decode_png(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_png(self, sess, image_data):
    # Runs the decode op; the decoded result must be a 3-channel (RGB) image.
    image = sess.run(
        self._decode_png, feed_dict={self._decode_png_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image
def _convert_dataset(split_name, filenames, filename_to_class_id, dataset_dir):
  """Converts the given filenames to a TFRecord dataset.

  Args:
    split_name: The name of the dataset, one of 'train', 'valid' or 'test'.
    filenames: A list of absolute paths to png images.
    filename_to_class_id: A dictionary from filenames (strings) to class ids
      (integers).
    dataset_dir: The directory where the converted datasets are stored.

  Raises:
    ValueError: If `split_name` is not 'train', 'valid' or 'test'.
  """
  print('Converting the {} split.'.format(split_name))
  # Train and validation splits are both in the train directory.
  if split_name in ['train', 'valid']:
    png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train')
  elif split_name == 'test':
    png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test')
  else:
    # Previously an unknown split fell through and crashed with a NameError.
    raise ValueError('Unknown split name: {}'.format(split_name))
  with tf.Graph().as_default():
    image_reader = ImageReader()
    with tf.Session('') as sess:
      output_filename = _get_output_filename(dataset_dir, split_name)
      with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
        for filename in filenames:
          # Read the PNG bytes. 'rb' (was 'r') so the binary data is not run
          # through text decoding on Python 3, which corrupts/raises.
          image_data = tf.gfile.FastGFile(
              os.path.join(png_directory, filename), 'rb').read()
          height, width = image_reader.read_image_dims(sess, image_data)
          class_id = filename_to_class_id[filename]
          example = dataset_utils.image_to_tfexample(image_data, 'png', height,
                                                     width, class_id)
          tfrecord_writer.write(example.SerializeToString())
  sys.stdout.write('\n')
  sys.stdout.flush()
def _extract_labels(label_filename):
  """Extract the labels into a dict of filenames to int labels.

  Args:
    label_filename: The filename of the MNIST-M labels.

  Returns:
    A dictionary of filenames to int labels.
  """
  print('Extracting labels from: ', label_filename)
  label_file = tf.gfile.FastGFile(label_filename, 'r').readlines()
  # Each line has the form "<filename> <integer label>".
  label_lines = [line.rstrip('\n').split() for line in label_file]
  labels = {}
  for line in label_lines:
    assert len(line) == 2
    labels[line[0]] = int(line[1])
  return labels
def _get_output_filename(dataset_dir, split_name):
"""Creates the output filename.
Args:
dataset_dir: The directory where the temporary files are stored.
split_name: The name of the train/test split.
Returns:
An absolute file path.
"""
return '%s/mnist_m_%s.tfrecord' % (dataset_dir, split_name)
def _get_filenames(dataset_dir):
"""Returns a list of filenames and inferred class names.
Args:
dataset_dir: A directory containing a set PNG encoded MNIST-M images.
Returns:
A list of image file paths, relative to `dataset_dir`.
"""
photo_filenames = []
for filename in os.listdir(dataset_dir):
photo_filenames.append(filename)
return photo_filenames
def run(dataset_dir):
  """Runs the download and conversion operation.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  train_filename = _get_output_filename(dataset_dir, 'train')
  testing_filename = _get_output_filename(dataset_dir, 'test')

  # Idempotent: skip conversion entirely if the TFRecords already exist.
  if tf.gfile.Exists(train_filename) and tf.gfile.Exists(testing_filename):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  # TODO(konstantinos): Add download and cleanup functionality

  train_validation_filenames = _get_filenames(
      os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train'))
  test_filenames = _get_filenames(
      os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test'))

  # Divide into train and validation; seeded shuffle keeps the split
  # reproducible across runs.
  random.seed(_RANDOM_SEED)
  random.shuffle(train_validation_filenames)
  train_filenames = train_validation_filenames[_NUM_VALIDATION:]
  validation_filenames = train_validation_filenames[:_NUM_VALIDATION]

  # Train and validation share one labels file; test has its own.
  train_validation_filenames_to_class_ids = _extract_labels(
      os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train_labels.txt'))
  test_filenames_to_class_ids = _extract_labels(
      os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test_labels.txt'))

  # Convert the train, validation, and test sets.
  _convert_dataset('train', train_filenames,
                   train_validation_filenames_to_class_ids, dataset_dir)
  _convert_dataset('valid', validation_filenames,
                   train_validation_filenames_to_class_ids, dataset_dir)
  _convert_dataset('test', test_filenames, test_filenames_to_class_ids,
                   dataset_dir)

  # Finally, write the labels file:
  labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
  print('\nFinished converting the MNIST-M dataset!')
def main(_):
  # Entry point: requires --dataset_dir to be provided on the command line.
  assert FLAGS.dataset_dir
  run(FLAGS.dataset_dir)


if __name__ == '__main__':
  tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides data for the MNIST-M dataset.
The dataset scripts used to create the dataset can be found at:
tensorflow_models/domain_adaptation/datasets/download_and_convert_mnist_m.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
import tensorflow as tf
from slim.datasets import dataset_utils
slim = tf.contrib.slim
_FILE_PATTERN = 'mnist_m_%s.tfrecord'
_SPLITS_TO_SIZES = {'train': 58001, 'valid': 1000, 'test': 9001}
_NUM_CLASSES = 10
_ITEMS_TO_DESCRIPTIONS = {
'image': 'A [32 x 32 x 1] RGB image.',
'label': 'A single integer between 0 and 9',
}
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Returns a slim `Dataset` describing one split of MNIST-M.

  Args:
    split_name: A train/test split name; one of the keys of
      `_SPLITS_TO_SIZES`.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: Optional file pattern containing a '%s' placeholder for
      the split name; defaults to `_FILE_PATTERN`.
    reader: Optional TensorFlow reader class; defaults to
      `tf.TFRecordReader`. Allowing None here lets dataset_factory rely on
      the default.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in _SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  pattern = file_pattern or _FILE_PATTERN
  data_sources = os.path.join(dataset_dir, pattern % split_name)

  reader_cls = tf.TFRecordReader if reader is None else reader

  # How raw tf.Example features are parsed from the TFRecord...
  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
  }
  # ...and how they are exposed as decoded items.
  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(shape=[32, 32, 3], channels=3),
      'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
  }
  example_decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  # Attach human-readable class names when a labels file was written.
  label_names = None
  if dataset_utils.has_labels(dataset_dir):
    label_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=data_sources,
      reader=reader_cls,
      decoder=example_decoder,
      num_samples=_SPLITS_TO_SIZES[split_name],
      num_classes=_NUM_CLASSES,
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      labels_to_names=label_names)
# Domain Separation Networks
#
# Build rules for the DSN library, its custom gradient-reversal (GRL) op,
# and the train/eval binaries and tests.

# All targets default to package-internal visibility (see the "internal"
# package_group below).
package(
    default_visibility = [
        ":internal",
    ],
)

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# Visibility is restricted to the domain_adaptation source tree.
package_group(
    name = "internal",
    packages = [
        "//domain_adaptation/...",
    ],
)

# Tower / encoder / decoder network definitions.
py_library(
    name = "models",
    srcs = [
        "models.py",
    ],
    deps = [
        ":utils",
    ],
)

# Similarity, difference, and reconstruction losses; depends on the GRL
# custom-op wrappers for the domain-adversarial (DANN) loss.
py_library(
    name = "losses",
    srcs = [
        "losses.py",
    ],
    deps = [
        ":grl_op_grads_py",
        ":grl_op_shapes_py",
        ":grl_ops",
        ":utils",
    ],
)

py_test(
    name = "losses_test",
    srcs = [
        "losses_test.py",
    ],
    deps = [
        ":losses",
        ":utils",
    ],
)

# DSN model assembly (dsn.py).
py_library(
    name = "dsn",
    srcs = [
        "dsn.py",
    ],
    deps = [
        ":grl_op_grads_py",
        ":grl_op_shapes_py",
        ":grl_ops",
        ":losses",
        ":models",
        ":utils",
    ],
)

py_test(
    name = "dsn_test",
    srcs = [
        "dsn_test.py",
    ],
    deps = [
        ":dsn",
    ],
)

# Training and evaluation entry points.
py_binary(
    name = "dsn_train",
    srcs = [
        "dsn_train.py",
    ],
    deps = [
        ":dsn",
        ":models",
        "//domain_adaptation/datasets:dataset_factory",
    ],
)

py_binary(
    name = "dsn_eval",
    srcs = [
        "dsn_eval.py",
    ],
    deps = [
        ":dsn",
        ":models",
        "//domain_adaptation/datasets:dataset_factory",
    ],
)

py_test(
    name = "models_test",
    srcs = [
        "models_test.py",
    ],
    deps = [
        ":models",
        "//domain_adaptation/datasets:dataset_factory",
    ],
)

py_library(
    name = "utils",
    srcs = [
        "utils.py",
    ],
    deps = [
    ],
)

# Python wrappers for the GRL custom op: gradient registration, shape
# registration, and the op library itself (ships the compiled .so).
py_library(
    name = "grl_op_grads_py",
    srcs = [
        "grl_op_grads.py",
    ],
    deps = [
        ":grl_ops",
    ],
)

py_library(
    name = "grl_op_shapes_py",
    srcs = [
        "grl_op_shapes.py",
    ],
    deps = [
    ],
)

py_library(
    name = "grl_ops",
    srcs = ["grl_ops.py"],
    data = ["_grl_ops.so"],
)

py_test(
    name = "grl_ops_test",
    size = "small",
    srcs = ["grl_ops_test.py"],
    deps = [
        ":grl_op_grads_py",
        ":grl_op_shapes_py",
        ":grl_ops",
    ],
)
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to create a DSN model and add the different losses to it.
Specifically, in this file we define the:
- Shared Encoding Similarity Loss Module, with:
- The MMD Similarity method
- The Correlation Similarity method
- The Gradient Reversal (Domain-Adversarial) method
- Difference Loss Module
- Reconstruction Loss Module
- Task Loss Module
"""
from functools import partial
import tensorflow as tf
import losses
import models
import utils
slim = tf.contrib.slim
################################################################################
# HELPER FUNCTIONS
################################################################################
def dsn_loss_coefficient(params):
  """Returns the global_step-dependent weight that gates the DSN losses.

  Until the global step reaches `params['domain_separation_startpoint']`,
  the similarity, difference, and reconstruction losses are scaled by a
  tiny constant (1e-10), effectively disabling them; afterwards they
  receive full weight.

  Args:
    params: A dictionary of parameters. Expecting 'domain_separation_startpoint'.

  Returns:
    A scalar tensor: 1e-10 before the start point, 1.0 from then on.
  """
  start_step = params['domain_separation_startpoint']
  current_step = slim.get_or_create_global_step()
  before_start = tf.less(current_step, start_step)
  return tf.where(before_start, 1e-10, 1.0)
################################################################################
# MODEL CREATION
################################################################################
def create_model(source_images, source_labels, domain_selection_mask,
                 target_images, target_labels, similarity_loss, params,
                 basic_tower_name):
  """Creates a DSN model and registers all of its losses.

  Builds the shared tower for source and target, adds the task loss on the
  source domain, the chosen similarity loss between the shared encodings,
  and (optionally) the private/shared autoencoder losses.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
      hot for the number of classes.
    domain_selection_mask: a boolean tensor of size [batch_size, ] which denotes
      the labeled images that belong to the source domain.
    target_images: images from the target domain, a tensor of size
      [batch_size, height width, channels].
    target_labels: a dictionary with the name, tensor pairs.
    similarity_loss: The type of method to use for encouraging
      the codes from the shared encoder to be similar. 'none' disables
      domain adaptation entirely.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
      'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
      'decoder_name', 'encoder_name'
    basic_tower_name: the name of the tower to use for the shared encoder.

  Returns:
    None. All losses are registered as graph side effects.

  Raises:
    ValueError: if the arch is not one of the available architectures.
  """
  network = getattr(models, basic_tower_name)
  num_classes = source_labels['classes'].get_shape().as_list()[1]

  # Make sure we are using the appropriate number of classes.
  network = partial(network, num_classes=num_classes)

  # Add the classification/pose estimation loss to the source domain.
  # This also creates the 'towers' variables that are reused below.
  source_endpoints = add_task_loss(source_images, source_labels, network,
                                   params)

  if similarity_loss == 'none':
    # No domain adaptation, we can stop here.
    return

  # Run the target images through the same tower (reuse=True shares weights).
  with tf.variable_scope('towers', reuse=True):
    target_logits, target_endpoints = network(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Plot target accuracy of the train set.
  target_accuracy = utils.accuracy(
      tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))

  if 'quaternions' in target_labels:
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)

  tf.summary.scalar('eval/Target accuracy', target_accuracy)

  # The shared-encoding activations that the similarity loss operates on.
  source_shared = source_endpoints[params['layers_to_regularize']]
  target_shared = target_endpoints[params['layers_to_regularize']]

  # When using the semisupervised model we include labeled target data in the
  # source classifier. We do not want to include these target domain when
  # we use the similarity loss.
  indices = tf.range(0, source_shared.get_shape().as_list()[0])
  indices = tf.boolean_mask(indices, domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, indices),
                      tf.gather(target_shared, indices), params)

  if params['use_separation']:
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params,)
def add_similarity_loss(method_name,
                        source_samples,
                        target_samples,
                        params,
                        scope=None):
  """Adds a loss encouraging the shared encoding from each domain to be similar.

  Args:
    method_name: the name of the encoding similarity method to use. Valid
      options include `dann_loss', `mmd_loss' or `correlation_loss'.
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    params: a dictionary of parameters. Expecting 'gamma_weight'.
    scope: optional name scope for summary tags.

  Raises:
    ValueError: if `method_name` is not recognized.
  """
  # Scale by the gamma coefficient; dsn_loss_coefficient gates the loss
  # until the domain-separation start point is reached.
  weight = dsn_loss_coefficient(params) * params['gamma_weight']

  # Look up the loss with an explicit default so an unknown name raises the
  # documented ValueError instead of getattr's AttributeError.
  method = getattr(losses, method_name, None)
  if method is None:
    raise ValueError('similarity loss [%s] not recognized.' % method_name)
  method(source_samples, target_samples, weight, scope)
def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
  """Adds a reconstruction loss.

  Args:
    recon_loss_name: The name of the reconstruction loss.
    images: A `Tensor` of size [batch_size, height, width, 3].
    recons: A `Tensor` whose size matches `images`.
    weight: A scalar coefficient for the loss.
    domain: The name of the domain being reconstructed.

  Raises:
    ValueError: If `recon_loss_name` is not recognized.
  """
  # Dispatch table from loss name to the contrib loss implementation.
  loss_fns = {
      'sum_of_pairwise_squares': tf.contrib.losses.mean_pairwise_squared_error,
      'sum_of_squares': tf.contrib.losses.mean_squared_error,
  }
  if recon_loss_name not in loss_fns:
    raise ValueError('recon_loss_name value [%s] not recognized.' %
                     recon_loss_name)

  loss = loss_fns[recon_loss_name](recons, images, weight)

  # Guard the summary with a finiteness assertion on the loss value.
  assert_op = tf.Assert(tf.is_finite(loss), [loss])
  with tf.control_dependencies([assert_op]):
    tf.summary.scalar('losses/%s Recon Loss' % domain, loss)
def add_autoencoders(source_data, source_shared, target_data, target_shared,
                     params):
  """Adds the encoders/decoders for our domain separation model w/ incoherence.

  Builds a private encoder per domain, a shared decoder, and registers the
  private/shared difference losses and the reconstruction losses, plus
  image summaries of the reconstructions.

  Args:
    source_data: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_shared: a tensor with first dimension batch_size
    target_data: images from the target domain, a tensor of size
      [batch_size, height, width, channels]
    target_shared: a tensor with first dimension batch_size
    params: A dictionary of parameters. Expecting 'layers_to_regularize',
      'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
      'encoder_name', 'weight_decay'
  """

  def normalize_images(images):
    # Shift/scale into [0, 1] so the summary images render sensibly.
    images -= tf.reduce_min(images)
    return images / tf.reduce_max(images)

  def concat_operation(shared_repr, private_repr):
    # Shared and private codes are combined additively.
    return shared_repr + private_repr

  mu = dsn_loss_coefficient(params)

  # The layer to concatenate the networks at.
  concat_layer = params['layers_to_regularize']

  # The coefficient for modulating the private/shared difference loss.
  difference_loss_weight = params['beta_weight'] * mu

  # The reconstruction weight.
  recon_loss_weight = params['alpha_weight'] * mu

  # The reconstruction loss to use.
  recon_loss_name = params['recon_loss_name']

  # The decoder/encoder to use.
  decoder_name = params['decoder_name']
  encoder_name = params['encoder_name']

  _, height, width, _ = source_data.get_shape().as_list()
  code_size = source_shared.get_shape().as_list()[-1]
  weight_decay = params['weight_decay']

  encoder_fn = getattr(models, encoder_name)
  # Private encoders: one per domain, separately-scoped variables.
  with tf.variable_scope('source_encoder'):
    source_endpoints = encoder_fn(
        source_data, code_size, weight_decay=weight_decay)

  with tf.variable_scope('target_encoder'):
    target_endpoints = encoder_fn(
        target_data, code_size, weight_decay=weight_decay)

  decoder_fn = getattr(models, decoder_name)
  decoder = partial(
      decoder_fn,
      height=height,
      width=width,
      channels=source_data.get_shape().as_list()[-1],
      weight_decay=weight_decay)

  # One shared decoder for all reconstructions (reuse=True below).
  source_private = source_endpoints[concat_layer]
  target_private = target_endpoints[concat_layer]
  with tf.variable_scope('decoder'):
    source_recons = decoder(concat_operation(source_shared, source_private))

  with tf.variable_scope('decoder', reuse=True):
    source_private_recons = decoder(
        concat_operation(tf.zeros_like(source_private), source_private))
    source_shared_recons = decoder(
        concat_operation(source_shared, tf.zeros_like(source_shared)))

  with tf.variable_scope('decoder', reuse=True):
    target_recons = decoder(concat_operation(target_shared, target_private))
    target_shared_recons = decoder(
        concat_operation(target_shared, tf.zeros_like(target_shared)))
    target_private_recons = decoder(
        concat_operation(tf.zeros_like(target_private), target_private))

  losses.difference_loss(
      source_private,
      source_shared,
      weight=difference_loss_weight,
      name='Source')
  losses.difference_loss(
      target_private,
      target_shared,
      weight=difference_loss_weight,
      name='Target')

  add_reconstruction_loss(recon_loss_name, source_data, source_recons,
                          recon_loss_weight, 'source')
  add_reconstruction_loss(recon_loss_name, target_data, target_recons,
                          recon_loss_weight, 'target')

  # Add summaries.
  # list(map(...)): on Python 3, map() returns a lazy iterator, which is not
  # a valid `values` list for tf.concat; materialize it explicitly (this is
  # a no-op behavior change on Python 2).
  source_reconstructions = tf.concat(
      axis=2,
      values=list(map(normalize_images, [
          source_data, source_recons, source_shared_recons,
          source_private_recons
      ])))
  target_reconstructions = tf.concat(
      axis=2,
      values=list(map(normalize_images, [
          target_data, target_recons, target_shared_recons,
          target_private_recons
      ])))
  tf.summary.image(
      'Source Images:Recons:RGB',
      source_reconstructions[:, :, :, :3],
      max_outputs=10)
  tf.summary.image(
      'Target Images:Recons:RGB',
      target_reconstructions[:, :, :, :3],
      max_outputs=10)

  # A 4th channel, when present, is treated as depth and summarized
  # separately.
  if source_reconstructions.get_shape().as_list()[3] == 4:
    tf.summary.image(
        'Source Images:Recons:Depth',
        source_reconstructions[:, :, :, 3:4],
        max_outputs=10)
    tf.summary.image(
        'Target Images:Recons:Depth',
        target_reconstructions[:, :, :, 3:4],
        max_outputs=10)
def add_task_loss(source_images, source_labels, basic_tower, params):
  """Adds a classification and/or pose estimation loss to the model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: labels from the source domain, a tensor of size [batch_size].
      or a tuple of (quaternions, class_labels)
    basic_tower: a function that creates the single tower of the model.
    params: A dictionary of parameters. Expecting 'weight_decay', 'pose_weight'.

  Returns:
    The source endpoints.

  Raises:
    RuntimeError: if basic tower does not support pose estimation.
  """
  with tf.variable_scope('towers'):
    source_logits, source_endpoints = basic_tower(
        source_images, weight_decay=params['weight_decay'], prefix='Source')

  if 'quaternions' in source_labels:  # We have pose estimation as well
    if 'quaternion_pred' not in source_endpoints:
      raise RuntimeError('Please use a model for estimation e.g. pose_mini')

    loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                      source_endpoints['quaternion_pred'],
                                      params)

    # Fail fast (at summary evaluation) if the pose loss goes non-finite.
    # NOTE(review): tf.control_dependencies only guards ops created inside
    # the block; the `quaternion_loss = loss` alias creates no op, so the
    # assert is effectively attached to the histogram summary only.
    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      quaternion_loss = loss
      tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
    # Register the weighted pose loss in the slim losses collection.
    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

  # softmax_cross_entropy registers itself in the tf.losses collection.
  classification_loss = tf.losses.softmax_cross_entropy(
      source_labels['classes'], source_logits)

  tf.summary.scalar('losses/classification_loss', classification_loss)
  return source_endpoints
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
"""Evaluation for Domain Separation Networks (DSNs)."""
# pylint: enable=line-too-long
import math
import numpy as np
from six.moves import xrange
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
from domain_adaptation.domain_separation import losses
from domain_adaptation.domain_separation import models
slim = tf.contrib.slim

FLAGS = tf.app.flags.FLAGS

# --- Evaluation I/O, checkpoint, and cluster configuration. ---
tf.app.flags.DEFINE_integer('batch_size', 32,
                            'The number of images in each batch.')

tf.app.flags.DEFINE_string('master', '',
                           'BNS name of the TensorFlow master to use.')

tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
                           'Directory where the model was written to.')

tf.app.flags.DEFINE_string(
    'eval_dir', '/tmp/da/',
    'Directory where we should write the tf summaries to.')

# --- Dataset and split selection. ---
tf.app.flags.DEFINE_string('dataset_dir', None,
                           'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_string('dataset', 'mnist_m',
                           'Which dataset to test on: "mnist", "mnist_m".')

tf.app.flags.DEFINE_string('split', 'valid',
                           'Which portion to test on: "valid", "test".')

tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')

# --- Model architecture and metric options. ---
tf.app.flags.DEFINE_string('basic_tower', 'dann_mnist',
                           'The basic tower building block.')

tf.app.flags.DEFINE_bool('enable_precision_recall', False,
                         'If True, precision and recall for each class will '
                         'be added to the metrics.')

tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
def quaternion_metric(predictions, labels):
  """Returns a streaming-mean metric over the batched log-quaternion loss."""
  loss_params = {'batch_size': FLAGS.batch_size, 'use_logging': False}
  batch_cost = losses.log_quaternion_loss_batch(predictions, labels,
                                                loss_params)
  return slim.metrics.streaming_mean(batch_cost)
def angle_diff(true_q, pred_q):
  """Returns the per-row angular difference, in degrees, between quaternions.

  The absolute value of the dot product makes the measure insensitive to
  the q / -q sign ambiguity of quaternions.

  Args:
    true_q: array of shape [batch, 4] of ground-truth quaternions.
    pred_q: array of shape [batch, 4] of predicted quaternions.

  Returns:
    An array of shape [batch] with the rotation differences in degrees.
  """
  dot_products = np.abs((pred_q * true_q).sum(axis=1))
  degrees_scale = 2 * (180.0 / np.pi)
  return degrees_scale * np.arccos(dot_products)
def provide_batch_fn():
  """Returns the function used to provide batches of data (override hook)."""
  return dataset_factory.provide_batch
def main(_):
  """Builds the evaluation graph and runs the slim evaluation loop."""
  g = tf.Graph()
  with g.as_default():
    # Load the data.
    images, labels = provide_batch_fn()(
        FLAGS.dataset, FLAGS.split, FLAGS.dataset_dir, 4, FLAGS.batch_size, 4)
    num_classes = labels['classes'].get_shape().as_list()[1]

    tf.summary.image('eval_images', images, max_outputs=3)

    # Define the model:
    with tf.variable_scope('towers'):
      basic_tower = getattr(models, FLAGS.basic_tower)
      predictions, endpoints = basic_tower(
          images,
          num_classes=num_classes,
          is_training=False,
          batch_norm_params=None)
    metric_names_to_values = {}

    # Define the metrics:
    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
      quaternion_loss = quaternion_metric(labels['quaternions'],
                                          endpoints['quaternion_pred'])

      angle_errors, = tf.py_func(
          angle_diff, [labels['quaternions'], endpoints['quaternion_pred']],
          [tf.float32])

      metric_names_to_values[
          'Angular mean error'] = slim.metrics.streaming_mean(angle_errors)

      metric_names_to_values['Quaternion Loss'] = quaternion_loss

    accuracy = tf.contrib.metrics.streaming_accuracy(
        tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))

    predictions = tf.argmax(predictions, 1)
    labels = tf.argmax(labels['classes'], 1)
    metric_names_to_values['Accuracy'] = accuracy

    if FLAGS.enable_precision_recall:
      for i in xrange(num_classes):
        index_map = tf.one_hot(i, depth=num_classes)
        name = 'PR/Precision_{}'.format(i)
        metric_names_to_values[name] = slim.metrics.streaming_precision(
            tf.gather(index_map, predictions), tf.gather(index_map, labels))
        name = 'PR/Recall_{}'.format(i)
        metric_names_to_values[name] = slim.metrics.streaming_recall(
            tf.gather(index_map, predictions), tf.gather(index_map, labels))

    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
        metric_names_to_values)

    # Create the summary ops such that they also print out to std output.
    # .items() instead of the Python 2-only .iteritems() — the file already
    # targets py2/py3 via six.moves.
    summary_ops = []
    for metric_name, metric_value in names_to_values.items():
      op = tf.summary.scalar(metric_name, metric_value)
      op = tf.Print(op, [metric_value], metric_name)
      summary_ops.append(op)

    # This ensures that we make a single pass over all of the data.
    # int(): math.ceil returns a float on Python 3, and num_evals must be
    # an integer.
    num_batches = int(math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)))

    # Setup the global step.
    slim.get_or_create_global_step()
    slim.evaluation.evaluation_loop(
        FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_dir,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        # list(): dict views on Python 3 are not sequences.
        eval_op=list(names_to_updates.values()),
        summary_op=tf.summary.merge(summary_ops))


if __name__ == '__main__':
  tf.app.run()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DSN model assembly functions."""
import numpy as np
import tensorflow as tf
import dsn
class HelperFunctionsTest(tf.test.TestCase):
  """Tests the global_step-based schedule of dsn.dsn_loss_coefficient."""

  def testBasicDomainSeparationStartPoint(self):
    with self.test_session() as sess:
      # Test for when global_step < domain_separation_startpoint.
      step = tf.contrib.slim.get_or_create_global_step()
      sess.run(tf.global_variables_initializer())  # global_step = 0
      params = {'domain_separation_startpoint': 2}
      weight = dsn.dsn_loss_coefficient(params)
      weight_np = sess.run(weight)
      # Before the start point the coefficient is effectively zero.
      self.assertAlmostEqual(weight_np, 1e-10)

      step_op = tf.assign_add(step, 1)
      step_np = sess.run(step_op)  # global_step = 1
      weight = dsn.dsn_loss_coefficient(params)
      weight_np = sess.run(weight)
      # Still strictly below the start point.
      self.assertAlmostEqual(weight_np, 1e-10)

      # Test for when global_step >= domain_separation_startpoint.
      step_np = sess.run(step_op)  # global_step = 2
      tf.logging.info(step_np)
      weight = dsn.dsn_loss_coefficient(params)
      weight_np = sess.run(weight)
      self.assertAlmostEqual(weight_np, 1.0)
class DsnModelAssemblyTest(tf.test.TestCase):
  """Tests that dsn.create_model registers the expected loss collections."""

  def _testBuildDefaultModel(self):
    """Returns random images, one-hot labels, and a default parameter dict."""
    images = tf.to_float(np.random.rand(32, 28, 28, 1))
    labels = {}
    labels['classes'] = tf.one_hot(
        tf.to_int32(np.random.randint(0, 9, (32))), 10)

    params = {
        'use_separation': True,
        'layers_to_regularize': 'fc3',
        'weight_decay': 0.0,
        'ps_tasks': 1,
        'domain_separation_startpoint': 1,
        'alpha_weight': 1,
        'beta_weight': 1,
        'gamma_weight': 1,
        'recon_loss_name': 'sum_of_squares',
        'decoder_name': 'small_decoder',
        'encoder_name': 'default_encoder',
    }
    return images, labels, params

  def testBuildModelDann(self):
    images, labels, params = self._testBuildDefaultModel()
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'dann_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
      self.assertEqual(len(loss_tensors), 6)

  def testBuildModelDannSumOfPairwiseSquares(self):
    images, labels, params = self._testBuildDefaultModel()
    # Actually exercise the pairwise-squares reconstruction loss; without
    # this override the test was a byte-for-byte duplicate of
    # testBuildModelDann.
    params['recon_loss_name'] = 'sum_of_pairwise_squares'
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'dann_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
      self.assertEqual(len(loss_tensors), 6)

  def testBuildModelDannMultiPSTasks(self):
    images, labels, params = self._testBuildDefaultModel()
    params['ps_tasks'] = 10
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'dann_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
      self.assertEqual(len(loss_tensors), 6)

  def testBuildModelMmd(self):
    images, labels, params = self._testBuildDefaultModel()
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'mmd_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
      self.assertEqual(len(loss_tensors), 6)

  def testBuildModelCorr(self):
    images, labels, params = self._testBuildDefaultModel()
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'correlation_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
      self.assertEqual(len(loss_tensors), 6)

  def testBuildModelNoDomainAdaptation(self):
    """'none' similarity: only the task loss and no regularizers."""
    images, labels, params = self._testBuildDefaultModel()
    params['use_separation'] = False
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
                       params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
      self.assertEqual(len(loss_tensors), 1)
      self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 0)

  def testBuildModelNoAdaptationWeightDecay(self):
    """Weight decay alone should produce regularization losses."""
    images, labels, params = self._testBuildDefaultModel()
    params['use_separation'] = False
    params['weight_decay'] = 1e-5
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
                       params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
      self.assertEqual(len(loss_tensors), 1)
      self.assertTrue(len(tf.contrib.losses.get_regularization_losses()) >= 1)

  def testBuildModelNoSeparation(self):
    """Similarity loss without separation: task + similarity losses only."""
    images, labels, params = self._testBuildDefaultModel()
    params['use_separation'] = False
    with self.test_session():
      dsn.create_model(images, labels,
                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
                       'dann_loss', params, 'dann_mnist')
      loss_tensors = tf.contrib.losses.get_losses()
      self.assertEqual(len(loss_tensors), 2)
# Run all test cases when executed directly.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training for Domain Separation Networks (DSNs)."""
from __future__ import division
import tensorflow as tf
from domain_adaptation.datasets import dataset_factory
import dsn
slim = tf.contrib.slim

FLAGS = tf.app.flags.FLAGS

# --- Input pipeline and dataset selection. ---
tf.app.flags.DEFINE_integer('batch_size', 32,
                            'The number of images in each batch.')

tf.app.flags.DEFINE_string('source_dataset', 'pose_synthetic',
                           'Source dataset to train on.')

tf.app.flags.DEFINE_string('target_dataset', 'pose_real',
                           'Target dataset to train on.')

tf.app.flags.DEFINE_string('target_labeled_dataset', 'none',
                           'Target dataset to train on.')

tf.app.flags.DEFINE_string('dataset_dir', None,
                           'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_string('master', '',
                           'BNS name of the TensorFlow master to use.')

tf.app.flags.DEFINE_string('train_log_dir', '/tmp/da/',
                           'Directory where to write event logs.')

# --- Loss weights and optimization hyperparameters. ---
tf.app.flags.DEFINE_string(
    'layers_to_regularize', 'fc3',
    'Comma-separated list of layer names to use MMD regularization on.')

tf.app.flags.DEFINE_float('learning_rate', .01, 'The learning rate')

tf.app.flags.DEFINE_float('alpha_weight', 1e-6,
                          'The coefficient for scaling the reconstruction '
                          'loss.')

tf.app.flags.DEFINE_float(
    'beta_weight', 1e-6,
    'The coefficient for scaling the private/shared difference loss.')

tf.app.flags.DEFINE_float(
    'gamma_weight', 1e-6,
    'The coefficient for scaling the shared encoding similarity loss.')

tf.app.flags.DEFINE_float('pose_weight', 0.125,
                          'The coefficient for scaling the pose loss.')

tf.app.flags.DEFINE_float(
    'weight_decay', 1e-6,
    'The coefficient for the L2 regularization applied for all weights.')

# --- Checkpointing / training schedule. ---
tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 60,
    'The frequency with which summaries are saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'save_interval_secs', 60,
    'The frequency with which the model is saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'max_number_of_steps', None,
    'The maximum number of gradient steps. Use None to train indefinitely.')

tf.app.flags.DEFINE_integer(
    'domain_separation_startpoint', 1,
    'The global step to add the domain separation losses.')

tf.app.flags.DEFINE_integer(
    'bipartite_assignment_top_k', 3,
    'The number of top-k matches to use in bipartite matching adaptation.')

tf.app.flags.DEFINE_float('decay_rate', 0.95, 'Learning rate decay factor.')

tf.app.flags.DEFINE_integer('decay_steps', 20000, 'Learning rate decay steps.')

tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum value.')

tf.app.flags.DEFINE_bool('use_separation', False,
                         'Use our domain separation model.')

tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')

# --- Distributed-training configuration. ---
tf.app.flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')

tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4,
                            'The number of threads used to create the batches.')

tf.app.flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')

tf.app.flags.DEFINE_string('decoder_name', 'small_decoder',
                           'The decoder to use.')
tf.app.flags.DEFINE_string('encoder_name', 'default_encoder',
                           'The encoder to use.')

################################################################################
# Flags that control the architecture and losses
################################################################################
tf.app.flags.DEFINE_string(
    'similarity_loss', 'grl',
    'The method to use for encouraging the common encoder codes to be '
    'similar, one of "grl", "mmd", "corr".')

tf.app.flags.DEFINE_string('recon_loss_name', 'sum_of_pairwise_squares',
                           'The name of the reconstruction loss.')

tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
                           'The basic tower building block.')
def provide_batch_fn():
  """Returns the function used to provide batches of data (override hook)."""
  return dataset_factory.provide_batch
def main(_):
  """Builds the DSN graph and trains it until `max_number_of_steps`.

  Loads source and target batches, optionally mixes labeled target samples
  into the source stream (semi-supervised mode), builds the domain separation
  network via dsn.create_model, and trains with momentum SGD under an
  exponentially decaying learning rate using slim.learning.train.

  Args:
    _: Unused positional argument supplied by tf.app.run().
  """
  # Hyper-parameters forwarded to dsn.create_model; all values come from the
  # command-line flags defined above.
  model_params = {
      'use_separation': FLAGS.use_separation,
      'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
      'layers_to_regularize': FLAGS.layers_to_regularize,
      'alpha_weight': FLAGS.alpha_weight,
      'beta_weight': FLAGS.beta_weight,
      'gamma_weight': FLAGS.gamma_weight,
      'pose_weight': FLAGS.pose_weight,
      'recon_loss_name': FLAGS.recon_loss_name,
      'decoder_name': FLAGS.decoder_name,
      'encoder_name': FLAGS.encoder_name,
      'weight_decay': FLAGS.weight_decay,
      'batch_size': FLAGS.batch_size,
      'use_logging': FLAGS.use_logging,
      'ps_tasks': FLAGS.ps_tasks,
      'task': FLAGS.task,
  }
  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Load the data.
      source_images, source_labels = provide_batch_fn()(
          FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)
      target_images, target_labels = provide_batch_fn()(
          FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)

      # In the unsupervised case all the samples in the labeled
      # domain are from the source domain.
      domain_selection_mask = tf.fill((source_images.get_shape().as_list()[0],),
                                      True)

      # When using the semisupervised model we include labeled target data in
      # the source labelled data.
      if FLAGS.target_labeled_dataset != 'none':
        # 1000 is the maximum number of labelled target samples that exists in
        # the datasets.
        # NOTE(review): this call passes only 3 arguments while the two calls
        # above pass 6 -- confirm the batch provider accepts this short form.
        target_semi_images, target_semi_labels = provide_batch_fn()(
            FLAGS.target_labeled_dataset, 'train', FLAGS.batch_size)

        # Calculate the proportion of source domain samples in the semi-
        # supervised setting, so that the proportion is set accordingly in the
        # batches.
        proportion = float(source_labels['num_train_samples']) / (
            source_labels['num_train_samples'] +
            target_semi_labels['num_train_samples'])
        rnd_tensor = tf.random_uniform(
            (target_semi_images.get_shape().as_list()[0],))
        domain_selection_mask = rnd_tensor < proportion
        source_images = tf.where(domain_selection_mask, source_images,
                                 target_semi_images)
        source_class_labels = tf.where(domain_selection_mask,
                                       source_labels['classes'],
                                       target_semi_labels['classes'])

        # Record whether pose (quaternion) labels are present BEFORE
        # source_labels is rebuilt below; testing the rebuilt dict would
        # always be False.
        has_pose_labels = 'quaternions' in source_labels
        if has_pose_labels:
          source_pose_labels = tf.where(domain_selection_mask,
                                        source_labels['quaternions'],
                                        target_semi_labels['quaternions'])
          (source_images, source_class_labels, source_pose_labels,
           domain_selection_mask) = tf.train.shuffle_batch(
               [
                   source_images, source_class_labels, source_pose_labels,
                   domain_selection_mask
               ],
               FLAGS.batch_size,
               50000,
               5000,
               num_threads=1,
               enqueue_many=True)
        else:
          (source_images, source_class_labels,
           domain_selection_mask) = tf.train.shuffle_batch(
               [source_images, source_class_labels, domain_selection_mask],
               FLAGS.batch_size,
               50000,
               5000,
               num_threads=1,
               enqueue_many=True)
        source_labels = {}
        source_labels['classes'] = source_class_labels
        # BUGFIX: the original tested `'quaternions' in source_labels` here,
        # immediately after source_labels was reset to {'classes': ...}, so
        # the condition was always False and pose labels were silently
        # dropped in the semi-supervised pose setting. Use the flag captured
        # before the reset instead.
        if has_pose_labels:
          source_labels['quaternions'] = source_pose_labels

      slim.get_or_create_global_step()
      tf.summary.image('source_images', source_images, max_outputs=3)
      tf.summary.image('target_images', target_images, max_outputs=3)

      dsn.create_model(
          source_images,
          source_labels,
          domain_selection_mask,
          target_images,
          target_labels,
          FLAGS.similarity_loss,
          model_params,
          basic_tower_name=FLAGS.basic_tower)

      # Configure the optimization scheme:
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.get_or_create_global_step(),
          FLAGS.decay_steps,
          FLAGS.decay_rate,
          staircase=True,
          name='learning_rate')

      tf.summary.scalar('learning_rate', learning_rate)
      tf.summary.scalar('total_loss', tf.losses.get_total_loss())

      opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      tf.logging.set_verbosity(tf.logging.INFO)
      # Run training.
      # NOTE(review): the summary above uses tf.losses.get_total_loss() while
      # the train op uses slim.losses.get_total_loss(); confirm they read the
      # same loss collection in this TF/slim version.
      loss_tensor = slim.learning.create_train_op(
          slim.losses.get_total_loss(),
          opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True)
      slim.learning.train(
          train_op=loss_tensor,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.max_number_of_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
  # tf.app.run() parses the command-line flags and then invokes main().
  tf.app.run()
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in grl_ops.py."""
import tensorflow as tf
@tf.RegisterGradient("GradientReversal")
def _GradientReversalGrad(_, grad):
  """Gradient function for the `GradientReversal` op.

  The forward op is the identity, so backpropagation through it simply
  flips the sign of the incoming gradient — the mechanism behind domain
  adversarial ("gradient reversal layer") training.

  Args:
    _: The `gradient_reversal` `Operation` being differentiated; unused here.
    grad: Gradient with respect to the output of the `gradient_reversal` op.

  Returns:
    Gradient with respect to the input of `gradient_reversal`: the negation
    of `grad`.
  """
  return tf.negative(grad)
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains the implementations of the ops registered in
// grl_ops.cc.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
// The gradient reversal op is used in domain adversarial training. It behaves
// as the identity op during forward propagation, and multiplies its input by -1
// during backward propagation.
// Forward-pass kernel for the "GradientReversal" op: a pure identity.
// The sign flip happens only in the gradient, which is registered on the
// Python side (see grl_op_grads.py).
class GradientReversalOp : public OpKernel {
 public:
  explicit GradientReversalOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  // Gradient reversal op behaves as the identity op during forward
  // propagation. Compute() function copied from the IdentityOp::Compute()
  // function here: third_party/tensorflow/core/kernels/identity_op.h.
  void Compute(OpKernelContext* context) override {
    // Ref-typed inputs (e.g. variables) must be forwarded as refs;
    // plain tensors are passed through without copying the buffer.
    if (IsRefType(context->input_dtype(0))) {
      context->forward_ref_input_to_ref_output(0, 0);
    } else {
      context->set_output(0, context->input(0));
    }
  }
};

// Register the kernel for CPU only; no GPU implementation is provided here.
REGISTER_KERNEL_BUILDER(Name("GradientReversal").Device(DEVICE_CPU),
                        GradientReversalOp);
} // namespace tensorflow
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shape inference for operators defined in grl_ops.cc."""
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Contains custom ops.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
// This custom op is used by adversarial training.
// Registers the "GradientReversal" op interface: one float input, one float
// output of the same shape (UnchangedShape). The kernel is an identity; the
// negation described in the doc text applies to the gradient registered in
// Python (grl_op_grads.py).
REGISTER_OP("GradientReversal")
    .Input("input: float")
    .Output("output: float")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
This op copies the input to the output during forward propagation, and
negates the input during backward propagation.
input: Tensor.
output: Tensor, copied from input.
)doc");
} // namespace tensorflow
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment